diff options
Diffstat (limited to '')
116 files changed, 17559 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..b2713c7 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +parallel=True +source=pgcli diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..bacb65c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +# editorconfig.org +# Get your text editor plugin at: +# http://editorconfig.org/#download +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 4 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true + +[travis.yml] +indent_size = 2 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/.git-blame-ignore-revs diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..b5cdbec --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,9 @@ +## Description +<!--- Describe your problem as fully as you can. --> + +## Your environment +<!-- This gives us some more context to work with. --> + +- [ ] Please provide your OS and version information. +- [ ] Please provide your CLI version. +- [ ] What is the output of ``pip freeze`` command. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..35e8486 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,12 @@ +## Description +<!--- Describe your changes in detail. --> + + + +## Checklist +<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. --> +- [ ] I've added this contribution to the `changelog.rst`. +- [ ] I've added my name to the `AUTHORS` file (or it's already there). +<!-- We would appreciate if you comply with our code style guidelines. --> +- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code. +- [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..68a69ac --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,102 @@ +name: pgcli + +on: + push: + branches: + - main + pull_request: + paths-ignore: + - '**.rst' + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + + services: + postgres: + image: postgres:9.6 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install pgbouncer + run: | + sudo apt install pgbouncer -y + + sudo chmod 666 /etc/pgbouncer/*.* + + cat <<EOF > /etc/pgbouncer/userlist.txt + "postgres" "postgres" + EOF + + cat <<EOF > /etc/pgbouncer/pgbouncer.ini + [databases] + * = host=localhost port=5432 + [pgbouncer] + listen_port = 6432 + listen_addr = localhost + auth_type = trust + auth_file = /etc/pgbouncer/userlist.txt + logfile = pgbouncer.log + pidfile = pgbouncer.pid + admin_users = postgres + EOF + + sudo systemctl stop pgbouncer + + pgbouncer -d /etc/pgbouncer/pgbouncer.ini + + psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' + + - name: Install beta version of pendulum + run: pip install pendulum==3.0.0b1 + if: matrix.python-version == '3.12' + + - name: Install requirements + run: | + pip install -U pip setuptools + pip install --no-cache-dir ".[sshtunnel]" + pip install -r requirements-dev.txt + pip install keyrings.alt>=3.1 + + - name: Run unit tests + run: coverage run --source pgcli -m pytest + + - name: Run integration tests + env: + PGUSER: postgres + PGPASSWORD: postgres + + run: behave tests/features --no-capture + + - name: Check changelog for ReST compliance + run: rst2html.py --halt=warning changelog.rst >/dev/null + + - name: Run Black + run: black --check . + if: matrix.python-version == '3.8' + + - name: Coverage + run: | + coverage combine + coverage report + codecov diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..c9232c7 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: "29 13 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ python ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b993cb9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,74 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +pyvenv/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +.pytest_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# PyCharm +.idea/ +*.iml + +# Vagrant +.vagrant/ + +# Generated Packages +*.deb +*.rpm + +.vscode/ +venv/ + +.ropeproject/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8462cc2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black @@ -0,0 +1,136 @@ +Many thanks to the following contributors. + +Project Lead: +------------- + * Irina Truong + +Core Devs: +---------- + * Amjith Ramanujam + * Darik Gamble + * Stuart Quin + * Joakim Koljonen + * Daniel Rocco + * Karl-Aksel Puulmann + * Dick Marinus + +Contributors: +------------- + * Brett + * Étienne BERSAC (bersace) + * Daniel Schwarz + * inkn + * Jonathan Slenders + * xalley + * TamasNo1 + * François Pietka + * Michael Kaminsky + * Alexander Kukushkin + * Ludovic Gasc (GMLudo) + * Marc Abramowitz + * Nick Hahner + * Jay Zeng + * Dimitar Roustchev + * Dhaivat Pandit + * Matheus Rosa + * Ali Kargın + * Nathan Jhaveri + * David Celis + * Sven-Hendrik Haase + * Çağatay Yüksel + * Tiago Ribeiro + * Vignesh Anand + * Charlie Arnold + * dwalmsley + * Artur Dryomov + * rrampage + * while0pass + * Eric Workman + * xa + * Hans Roman + * Guewen Baconnier + * Dionysis Grigoropoulos + * Jacob Magnusson + * Johannes Hoff + * vinotheassassin + * Jacek Wielemborek + * Fabien Meghazi + * Manuel Barkhau + * Sergii V + * Emanuele Gaifas + * Owen Stephens + * Russell Davies + * AlexTes + * Hraban Luyat + * Jackson Popkin + * Gustavo Castro + * Alexander Schmolck + * Donnell Muse + * Andrew Speed + * Dmitry B + * Isank + * Marcin Sztolcman + * Bojan Delić + * Chris Vaughn + * Frederic Aoustin + * Pierre Giraud + * Andrew Kuchling + * Dan Clark + * Catherine Devlin + * Jason Ribeiro + * Rishi Ramraj + * Matthieu Guilbert + * Alexandr Korsak + * Saif Hakim + * Artur Balabanov + * Kenny Do + * Max Rothman + * Daniel Egger + * Ignacio Campabadal + * Mikhail Elovskikh (wronglink) + * Marcin Cieślak (saper) + * easteregg (verfriemelt-dot-org) + * Scott Brenstuhl (808sAndBR) + * Nathan Verzemnieks + * raylu + * Zhaolong Zhu + * Zane C. Bowers-Hadley + * Telmo "Trooper" (telmotrooper) + * Alexander Zawadzki + * Pablo A. Bianchi (pabloab) + * Sebastian Janko (sebojanko) + * Pedro Ferrari (petobens) + * Martin Matejek (mmtj) + * Jonas Jelten + * BrownShibaDog + * George Thomas(thegeorgeous) + * Yoni Nakache(lazydba247) + * Gantsev Denis + * Stephano Paraskeva + * Panos Mavrogiorgos (pmav99) + * Igor Kim (igorkim) + * Anthony DeBarros (anthonydb) + * Seungyong Kwak (GUIEEN) + * Tom Caruso (tomplex) + * Jan Brun Rasmussen (janbrunrasmussen) + * Kevin Marsh (kevinmarsh) + * Eero Ruohola (ruohola) + * Miroslav Šedivý (eumiro) + * Eric R Young (ERYoung11) + * Paweł Sacawa (psacawa) + * Bruno Inec (sweenu) + * Daniele Varrazzo + * Daniel Kukula (dkuku) + * Kian-Meng Ang (kianmeng) + * Liu Zhao (astroshot) + * Rigo Neri (rigoneri) + * Anna Glasgall (annathyst) + * Andy Schoenberger (andyscho) + * Damien Baty (dbaty) + * blag + * Rob Berry (rob-b) + * Sharon Yogev (sharonyogev) + +Creator: +-------- +Amjith Ramanujam diff --git a/DEVELOP.rst b/DEVELOP.rst new file mode 100644 index 0000000..aed2cf8 --- /dev/null +++ b/DEVELOP.rst @@ -0,0 +1,220 @@ +Development Guide +----------------- +This is a guide for developers who would like to contribute to this project. + +GitHub Workflow +--------------- + +If you're interested in contributing to pgcli, first of all my heart felt +thanks. `Fork the project <https://github.com/dbcli/pgcli>`_ on github. Then +clone your fork into your computer (``git clone <url-for-your-fork>``). Make +the changes and create the commits in your local machine. Then push those +changes to your fork. Then click on the pull request icon on github and create +a new pull request. Add a description about the change and send it along. I +promise to review the pull request in a reasonable window of time and get back +to you. + +In order to keep your fork up to date with any changes from mainline, add a new +git remote to your local copy called 'upstream' and point it to the main pgcli +repo. + +:: + + $ git remote add upstream git@github.com:dbcli/pgcli.git + +Once the 'upstream' end point is added you can then periodically do a ``git +pull upstream master`` to update your local copy and then do a ``git push +origin master`` to keep your own fork up to date. + +Check Github's `Understanding the GitHub flow guide +<https://guides.github.com/introduction/flow/>`_ for a more detailed +explanation of this process. + +Local Setup +----------- + +The installation instructions in the README file are intended for users of +pgcli. If you're developing pgcli, you'll need to install it in a slightly +different way so you can see the effects of your changes right away without +having to go through the install cycle every time you change the code. + +It is highly recommended to use virtualenv for development. If you don't know +what a virtualenv is, `this guide <http://docs.python-guide.org/en/latest/dev/virtualenvs/#virtual-environments>`_ +will help you get started. + +Create a virtualenv (let's call it pgcli-dev). Activate it: + +:: + + source ./pgcli-dev/bin/activate + + or + + .\pgcli-dev\scripts\activate (for Windows) + +Once the virtualenv is activated, `cd` into the local clone of pgcli folder +and install pgcli using pip as follows: + +:: + + $ pip install --editable . + + or + + $ pip install -e . + +This will install the necessary dependencies as well as install pgcli from the +working folder into the virtualenv. By installing it using `pip install -e` +we've linked the pgcli installation with the working copy. Any changes made +to the code are immediately available in the installed version of pgcli. This +makes it easy to change something in the code, launch pgcli and check the +effects of your changes. + +Adding PostgreSQL Special (Meta) Commands +----------------------------------------- + +If you want to work on adding new meta-commands (such as `\dp`, `\ds`, `dy`), +you need to contribute to `pgspecial <https://github.com/dbcli/pgspecial/>`_ +project. + +Visual Studio Code Debugging +----------------------------- +To set up Visual Studio Code to debug pgcli requires a launch.json file. + +Within the project, create a file: .vscode\\launch.json like below. + +:: + + { + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Module", + "type": "python", + "request": "launch", + "module": "pgcli.main", + "justMyCode": false, + "console": "externalTerminal", + "env": { + "PGUSER": "postgres", + "PGPASS": "password", + "PGHOST": "localhost", + "PGPORT": "5432" + } + } + ] + } + +Building RPM and DEB packages +----------------------------- + +You will need Vagrant 1.7.2 or higher. In the project root there is a +Vagrantfile that is setup to do multi-vm provisioning. If you're setting things +up for the first time, then do: + +:: + + $ version=x.y.z vagrant up debian + $ version=x.y.z vagrant up centos + +If you already have those VMs setup and you're merely creating a new version of +DEB or RPM package, then you can do: + +:: + + $ version=x.y.z vagrant provision + +That will create a .deb file and a .rpm file. + +The deb package can be installed as follows: + +:: + + $ sudo dpkg -i pgcli*.deb # if dependencies are available. + + or + + $ sudo apt-get install -f pgcli*.deb # if dependencies are not available. + + +The rpm package can be installed as follows: + +:: + + $ sudo yum install pgcli*.rpm + +Running the integration tests +----------------------------- + +Integration tests use `behave package <https://behave.readthedocs.io/>`_ and +pytest. +Configuration settings for this package are provided via a ``behave.ini`` file +in the ``tests`` directory. An example:: + + [behave] + stderr_capture = false + + [behave.userdata] + pg_test_user = dbuser + pg_test_host = db.example.com + pg_test_port = 30000 + +First, install the requirements for testing: + +:: + $ pip install -U pip setuptools + $ pip install --no-cache-dir ".[sshtunnel]" + $ pip install -r requirements-dev.txt + +Ensure that the database user has permissions to create and drop test databases +by checking your ``pg_hba.conf`` file. The default user should be ``postgres`` +at ``localhost``. Make sure the authentication method is set to ``trust``. If +you made any changes to your ``pg_hba.conf`` make sure to restart the postgres +service for the changes to take effect. + +:: + + # ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE + $ sudo service postgresql restart + +After that, tests in the ``/pgcli/tests`` directory can be run with: +(Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility.) + +:: + + # on directory /pgcli/tests + $ behave + +And on the ``/pgcli`` directory: + +:: + + # on directory /pgcli + $ py.test + +To see stdout/stderr, use the following command: + +:: + + $ behave --no-capture + +Troubleshooting the integration tests +------------------------------------- + +- Make sure postgres instance on localhost is running +- Check your ``pg_hba.conf`` file to verify local connections are enabled +- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information. +- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_. + +Coding Style +------------ + +``pgcli`` uses `black <https://github.com/ambv/black>`_ to format the source code. Make sure to install black. + +Releases +-------- + +If you're the person responsible for releasing `pgcli`, `this guide <https://github.com/dbcli/pgcli/blob/main/RELEASES.md>`_ is for you. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..32d341a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,6 @@ +FROM python:3.8 + +COPY . /app +RUN cd /app && pip install -e . + +CMD pgcli diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..83226b7 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,26 @@ +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1c5f697 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE.txt AUTHORS changelog.rst +recursive-include tests *.py *.txt *.feature *.ini diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..56695cc --- /dev/null +++ b/README.rst @@ -0,0 +1,392 @@ +We stand with Ukraine +--------------------- + +Ukrainian people are fighting for their country. A lot of civilians, women and children, are suffering. Hundreds were killed and injured, and thousands were displaced. + +This is an image from my home town, Kharkiv. This place is right in the old city center. + +.. image:: screenshots/kharkiv-destroyed.jpg + +Picture by @fomenko_ph (Telegram). + +Please consider donating or volunteering. + +* https://bank.gov.ua/en/ +* https://savelife.in.ua/en/donate/ +* https://www.comebackalive.in.ua/donate +* https://www.globalgiving.org/projects/ukraine-crisis-relief-fund/ +* https://www.savethechildren.org/us/where-we-work/ukraine +* https://www.facebook.com/donate/1137971146948461/ +* https://donate.wck.org/give/393234#!/donation/checkout +* https://atlantaforukraine.com/ + + +A REPL for Postgres +------------------- + +|Build Status| |CodeCov| |PyPI| |netlify| + +This is a postgres client that does auto-completion and syntax highlighting. + +Home Page: http://pgcli.com + +MySQL Equivalent: http://mycli.net + +.. image:: screenshots/pgcli.gif +.. image:: screenshots/image01.png + +Quick Start +----------- + +If you already know how to install python packages, then you can simply do: + +:: + + $ pip install -U pgcli + + or + + $ sudo apt-get install pgcli # Only on Debian based Linux (e.g. Ubuntu, Mint, etc) + $ brew install pgcli # Only on macOS + +If you don't know how to install python packages, please check the +`detailed instructions`_. + +.. _`detailed instructions`: https://github.com/dbcli/pgcli#detailed-installation-instructions + +Usage +----- + +:: + + $ pgcli [database_name] + + or + + $ pgcli postgresql://[user[:password]@][netloc][:port][/dbname][?extra=value[&other=other-value]] + +Examples: + +:: + + $ pgcli local_database + + $ pgcli postgres://amjith:pa$$w0rd@example.com:5432/app_db?sslmode=verify-ca&sslrootcert=/myrootcert + +For more details: + +:: + + $ pgcli --help + + Usage: pgcli [OPTIONS] [DBNAME] [USERNAME] + + Options: + -h, --host TEXT Host address of the postgres database. + -p, --port INTEGER Port number at which the postgres instance is + listening. + -U, --username TEXT Username to connect to the postgres database. + -u, --user TEXT Username to connect to the postgres database. + -W, --password Force password prompt. + -w, --no-password Never prompt for password. + --single-connection Do not use a separate connection for completions. + -v, --version Version of pgcli. + -d, --dbname TEXT database name to connect to. + --pgclirc FILE Location of pgclirc file. + -D, --dsn TEXT Use DSN configured into the [alias_dsn] section + of pgclirc file. + --list-dsn list of DSN configured into the [alias_dsn] + section of pgclirc file. + --row-limit INTEGER Set threshold for row limit prompt. Use 0 to + disable prompt. + --less-chatty Skip intro on startup and goodbye on exit. + --prompt TEXT Prompt format (Default: "\u@\h:\d> "). + --prompt-dsn TEXT Prompt format for connections using DSN aliases + (Default: "\u@\h:\d> "). + -l, --list list available databases, then exit. + --auto-vertical-output Automatically switch to vertical output mode if + the result is wider than the terminal width. + --warn [all|moderate|off] Warn before running a destructive query. + --help Show this message and exit. + +``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``). + +The SSL-related environment variables are also supported, so if you need to connect a postgres database via ssl connection, you can set set environment like this: + +:: + + export PGSSLMODE="verify-full" + export PGSSLCERT="/your-path-to-certs/client.crt" + export PGSSLKEY="/your-path-to-keys/client.key" + export PGSSLROOTCERT="/your-path-to-ca/ca.crt" + pgcli -h localhost -p 5432 -U username postgres + +.. _environment variables: https://www.postgresql.org/docs/current/libpq-envars.html + +Features +-------- + +The `pgcli` is written using prompt_toolkit_. + +* Auto-completes as you type for SQL keywords as well as tables and + columns in the database. +* Syntax highlighting using Pygments. +* Smart-completion (enabled by default) will suggest context-sensitive + completion. + + - ``SELECT * FROM <tab>`` will only show table names. + - ``SELECT * FROM users WHERE <tab>`` will only show column names. + +* Primitive support for ``psql`` back-slash commands. +* Pretty prints tabular data. + +.. _prompt_toolkit: https://github.com/jonathanslenders/python-prompt-toolkit +.. _tabulate: https://pypi.python.org/pypi/tabulate + +Config +------ +A config file is automatically created at ``~/.config/pgcli/config`` at first launch. +See the file itself for a description of all available options. + +Contributions: +-------------- + +If you're interested in contributing to this project, first of all I would like +to extend my heartfelt gratitude. I've written a small doc to describe how to +get this running in a development setup. + +https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst + +Please feel free to reach out to us if you need help. +* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_ +* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong <http://twitter.com/irinatruong>`_ + +Detailed Installation Instructions: +----------------------------------- + +macOS: +====== + +The easiest way to install pgcli is using Homebrew. + +:: + + $ brew install pgcli + +Done! + +Alternatively, you can install ``pgcli`` as a python package using a package +manager called called ``pip``. You will need postgres installed on your system +for this to work. + +In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. + +:: + + $ which pip + +If it is installed then you can do: + +:: + + $ pip install pgcli + +If that fails due to permission issues, you might need to run the command with +sudo permissions. + +:: + + $ sudo pip install pgcli + +If pip is not installed check if easy_install is available on the system. + +:: + + $ which easy_install + + $ sudo easy_install pgcli + +Linux: +====== + +In depth getting started guide for ``pip`` - https://pip.pypa.io/en/latest/installing.html. + +Check if pip is already available in your system. + +:: + + $ which pip + +If it doesn't exist, use your linux package manager to install `pip`. This +might look something like: + +:: + + $ sudo apt-get install python-pip # Debian, Ubuntu, Mint etc + + or + + $ sudo yum install python-pip # RHEL, Centos, Fedora etc + +``pgcli`` requires python-dev, libpq-dev and libevent-dev packages. You can +install these via your operating system package manager. + + +:: + + $ sudo apt-get install python-dev libpq-dev libevent-dev + + or + + $ sudo yum install python-devel postgresql-devel + +Then you can install pgcli: + +:: + + $ sudo pip install pgcli + + +Docker +====== + +Pgcli can be run from within Docker. This can be useful to try pgcli without +installing it, or any dependencies, system-wide. + +To build the image: + +:: + + $ docker build -t pgcli . + +To create a container from the image: + +:: + + $ docker run --rm -ti pgcli pgcli <ARGS> + +To access postgresql databases listening on localhost, make sure to run the +docker in "host net mode". E.g. to access a database called "foo" on the +postgresql server running on localhost:5432 (the standard port): + +:: + + $ docker run --rm -ti --net host pgcli pgcli -h localhost foo + +To connect to a locally running instance over a unix socket, bind the socket to +the docker container: + +:: + + $ docker run --rm -ti -v /var/run/postgres:/var/run/postgres pgcli pgcli foo + + +IPython +======= + +Pgcli can be run from within `IPython <https://ipython.org>`_ console. When working on a query, +it may be useful to drop into a pgcli session without leaving the IPython console, iterate on a +query, then quit pgcli to find the query results in your IPython workspace. + +Assuming you have IPython installed: + +:: + + $ pip install ipython-sql + +After that, run ipython and load the ``pgcli.magic`` extension: + +:: + + $ ipython + + In [1]: %load_ext pgcli.magic + + +Connect to a database and construct a query: + +:: + + In [2]: %pgcli postgres://someone@localhost:5432/world + Connected: someone@world + someone@localhost:world> select * from city c where countrycode = 'USA' and population > 1000000; + +------+--------------+---------------+--------------+--------------+ + | id | name | countrycode | district | population | + |------+--------------+---------------+--------------+--------------| + | 3793 | New York | USA | New York | 8008278 | + | 3794 | Los Angeles | USA | California | 3694820 | + | 3795 | Chicago | USA | Illinois | 2896016 | + | 3796 | Houston | USA | Texas | 1953631 | + | 3797 | Philadelphia | USA | Pennsylvania | 1517550 | + | 3798 | Phoenix | USA | Arizona | 1321045 | + | 3799 | San Diego | USA | California | 1223400 | + | 3800 | Dallas | USA | Texas | 1188580 | + | 3801 | San Antonio | USA | Texas | 1144646 | + +------+--------------+---------------+--------------+--------------+ + SELECT 9 + Time: 0.003s + + +Exit out of pgcli session with ``Ctrl + D`` and find the query results: + +:: + + someone@localhost:world> + Goodbye! + 9 rows affected. + Out[2]: + [(3793, u'New York', u'USA', u'New York', 8008278), + (3794, u'Los Angeles', u'USA', u'California', 3694820), + (3795, u'Chicago', u'USA', u'Illinois', 2896016), + (3796, u'Houston', u'USA', u'Texas', 1953631), + (3797, u'Philadelphia', u'USA', u'Pennsylvania', 1517550), + (3798, u'Phoenix', u'USA', u'Arizona', 1321045), + (3799, u'San Diego', u'USA', u'California', 1223400), + (3800, u'Dallas', u'USA', u'Texas', 1188580), + (3801, u'San Antonio', u'USA', u'Texas', 1144646)] + +The results are available in special local variable ``_``, and can be assigned to a variable of your +choice: + +:: + + In [3]: my_result = _ + +Pgcli dropped support for Python<3.8 as of 4.0.0. If you need it, install ``pgcli <= 4.0.0``. + +Thanks: +------- + +A special thanks to `Jonathan Slenders <https://twitter.com/jonathan_s>`_ for +creating `Python Prompt Toolkit <http://github.com/jonathanslenders/python-prompt-toolkit>`_, +which is quite literally the backbone library, that made this app possible. +Jonathan has also provided valuable feedback and support during the development +of this app. + +`Click <http://click.pocoo.org/>`_ is used for command line option parsing +and printing error messages. + +Thanks to `psycopg <https://www.psycopg.org/>`_ for providing a rock solid +interface to Postgres database. + +Thanks to all the beta testers and contributors for your time and patience. :) + + +.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main + :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml + +.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg + :target: https://codecov.io/gh/dbcli/pgcli + :alt: Code coverage report + +.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat + :target: https://landscape.io/github/dbcli/pgcli/master + :alt: Code Health + +.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg + :target: https://pypi.python.org/pypi/pgcli/ + :alt: Latest Version + +.. |netlify| image:: https://api.netlify.com/api/v1/badges/3a0a14dd-776d-445d-804c-3dd74fe31c4e/deploy-status + :target: https://app.netlify.com/sites/pgcli/deploys + :alt: Netlify diff --git a/RELEASES.md b/RELEASES.md new file mode 100644 index 0000000..526c260 --- /dev/null +++ b/RELEASES.md @@ -0,0 +1,24 @@ +Releasing pgcli +--------------- + +You have been made the maintainer of `pgcli`? Congratulations! We have a release script to help you: + +```sh +> python release.py --help +Usage: release.py [options] + +Options: + -h, --help show this help message and exit + -c, --confirm-steps Confirm every step. If the step is not confirmed, it + will be skipped. + -d, --dry-run Print out, but not actually run any steps. +``` + +The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps. + +To release a new version of the package: + +* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/pgcli/pull/1325)). +* Pull `main` and bump the version number inside `pgcli/__init__.py`. Do not check in - the release script will do that. +* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`. +* Finally, run the release script: `python release.py`. @@ -0,0 +1,12 @@ +# vi: ft=vimwiki +* [ ] Add coverage. +* [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks. +* [ ] Add a few more special commands. (\l pattern, \dp, \ds, \dy, \z etc) +* [ ] Refactor pgspecial.py to a class. +* [ ] Show/hide docs for a statement using a keybinding. +* [ ] Check how to add the name of the table before printing the table. +* [ ] Add a new trigger for M-/ that does naive completion. +* [ ] New Feature List - Write the current version to config file. At launch if the version has changed, display the changelog between the two versions. +* [ ] Add a test for 'select * from custom.abc where custom.abc.' should suggest columns from abc. +* [ ] pgexecute columns(), tables() etc can be just cursors instead of fetchall() +* [ ] Add colorschemes in config file. diff --git a/Vagrantfile b/Vagrantfile new file mode 100644 index 0000000..297e70a --- /dev/null +++ b/Vagrantfile @@ -0,0 +1,138 @@ +# -*- mode: ruby -*- +# vi: set ft=ruby : +# +# + +Vagrant.configure(2) do |config| + + config.vm.synced_folder ".", "/pgcli" + + pgcli_version = ENV['version'] + pgcli_description = "Postgres CLI with autocompletion and syntax highlighting" + + config.vm.define "debian" do |debian| + debian.vm.box = "bento/debian-10.8" + debian.vm.provision "shell", inline: <<-SHELL + echo "-> Building DEB on `lsb_release -d`" + sudo apt-get update + sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems + sudo apt install -y python3-pip + sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3 + sudo apt-get install -y ruby-dev + sudo apt-get install -y git + sudo apt-get install -y rpm librpmbuild8 + + sudo gem install fpm + + echo "-> Cleaning up old workspace" + sudo rm -rf build + mkdir -p build/usr/share + virtualenv build/usr/share/pgcli + build/usr/share/pgcli/bin/pip install /pgcli + + echo "-> Cleaning Virtualenv" + cd build/usr/share/pgcli + virtualenv-tools --update-path /usr/share/pgcli > /dev/null + cd /home/vagrant/ + + echo "-> Removing compiled files" + find build -iname '*.pyc' -delete + find build -iname '*.pyo' -delete + + echo "-> Creating PgCLI deb" + sudo fpm -t deb -s dir -C build -n pgcli -v #{pgcli_version} \ + -a all \ + -d libpq-dev \ + -d python-dev \ + -p /pgcli/ \ + --after-install /pgcli/post-install \ + --after-remove /pgcli/post-remove \ + --url https://github.com/dbcli/pgcli \ + --description "#{pgcli_description}" \ + --license 'BSD' + + SHELL + end + + +# This is considerably more messy than the debian section. I had to go off-standard to update +# some packages to get this to work. + + config.vm.define "centos" do |centos| + + centos.vm.box = "bento/centos-7.9" + centos.vm.box_version = "202012.21.0" + centos.vm.provision "shell", inline: <<-SHELL + #!/bin/bash + echo "-> Building RPM on `hostnamectl | grep "Operating System"`" + export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin + echo "PATH -> " $PATH + +##### +### get base updates + + sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel + +###### +### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git + + echo "-> Get FPM installed" + # import the necessary GPG keys + gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + # install RVM + sudo curl -sSL https://get.rvm.io | sudo bash -s stable + sudo usermod -aG rvm vagrant + sudo usermod -aG rvm root + sudo /usr/local/rvm/bin/rvm alias create default 2.6.3 + source /etc/profile.d/rvm.sh + + # install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git. + sudo yum install -y ruby-devel + sudo /usr/local/rvm/bin/rvm install 2.6.3 + + # + # yes,this gives an error about generating doc but we don't need the doc. + + /usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm + +###### + + sudo pip3 install virtualenv virtualenv-tools3 + echo "-> Cleaning up old workspace" + rm -rf build + mkdir -p build/usr/share + virtualenv build/usr/share/pgcli + build/usr/share/pgcli/bin/pip install /pgcli + + echo "-> Cleaning Virtualenv" + cd build/usr/share/pgcli + virtualenv-tools --update-path /usr/share/pgcli > /dev/null + cd /home/vagrant/ + + echo "-> Removing compiled files" + find build -iname '*.pyc' -delete + find build -iname '*.pyo' -delete + + cd /home/vagrant + echo "-> Creating PgCLI RPM" + /usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \ + -a all \ + -d postgresql-devel \ + -d python-devel \ + -p /pgcli/ \ + --after-install /pgcli/post-install \ + --after-remove /pgcli/post-remove \ + --url https://github.com/dbcli/pgcli \ + --description "#{pgcli_description}" \ + --license 'BSD' + + + SHELL + + + end + + +end + diff --git a/changelog.rst b/changelog.rst new file mode 100644 index 0000000..7d08839 --- /dev/null +++ b/changelog.rst @@ -0,0 +1,1203 @@ +================== +4.0.1 (2023-11-30) +================== + +Internal: +--------- +* Allow stable version of pendulum. + +================== +4.0.0 (2023-11-27) +================== + +Features: +--------- + +* Ask for confirmation when quitting cli while a transaction is ongoing. +* New `destructive_statements_require_transaction` config option to refuse to execute a + destructive SQL statement if outside a transaction. This option is off by default. +* Changed the `destructive_warning` config to be a list of commands that are considered + destructive. This would allow you to be warned on `create`, `grant`, or `insert` queries. +* Destructive warnings will now include the alias dsn connection string name if provided (-D option). +* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication +* Have config option to retry queries on operational errors like connections being lost. + Also prevents getting stuck in a retry loop. +* Config option to not restart connection when cancelling a `destructive_warning` query. By default, + it will now not restart. +* Config option to always run with a single connection. +* Add comment explaining default LESS environment variable behavior and change example pager setting. +* Added `\echo` & `\qecho` special commands. ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)). + +Bug fixes: +---------- + +* Fix `\ev` not producing a correctly quoted "schema"."view" +* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)). +* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)). +* Fix sql-insert format emits NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)). +* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)). +* Allow specifying an `alias_map_file` in the config that will use + predetermined table aliases instead of generating aliases programmatically on + the fly +* Fixed SQL error when there is a comment on the first line: ([issue 1403](https://github.com/dbcli/pgcli/issues/1403)) +* Fix wrong usage of prompt instead of confirm when confirm execution of destructive query + +Internal: +--------- + +* Drop support for Python 3.7 and add 3.12. + +3.5.0 (2022/09/15): +=================== + +Features: +--------- + +* New formatter is added to export query result to sql format (such as sql-insert, sql-update) like mycli. + +Bug fixes: +---------- + +* Fix exception when retrieving password from keyring ([issue 1338](https://github.com/dbcli/pgcli/issues/1338)). +* Fix using comments with special commands ([issue 1362](https://github.com/dbcli/pgcli/issues/1362)). +* Small improvements to the Windows developer experience +* Fix submitting queries in safe multiline mode ([1360](https://github.com/dbcli/pgcli/issues/1360)). + +Internal: +--------- + +* Port to psycopg3 (https://github.com/psycopg/psycopg). +* Fix typos + +3.4.1 (2022/03/19) +================== + +Bug fixes: +---------- + +* Fix the bug with Redshift not displaying word count in status ([related issue](https://github.com/dbcli/pgcli/issues/1320)). +* Show the error status for CSV output format. + + +3.4.0 (2022/02/21) +================== + +Features: +--------- + +* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)). + +3.3.1 (2022/01/18) +================== + +Bug fixes: +---------- + +* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307. +* Upgrade cli_helpers to 2.2.1 + +3.3.0 (2022/01/11) +================== + +Features: +--------- + +* Add `max_field_width` setting to config, to enable more control over field truncation ([related issue](https://github.com/dbcli/pgcli/issues/1250)). +* Re-run last query via bare `\watch`. (Thanks: `Saif Hakim`_) + +Bug fixes: +---------- + +* Pin the version of pygments to prevent breaking change + +3.2.0 +===== + +Release date: 2021/08/23 + +Features: +--------- + +* Consider `update` queries destructive and issue a warning. Change + `destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239) +* Skip initial comment in .pg_session even if it doesn't start with '#' +* Include functions from schemas in search_path. (`Amjith Ramanujam`_) +* Easy way to show explain output under F5 + +Bug fixes: +---------- + +* Fix issue where `syntax_style` config value would not have any effect. (#1212) +* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6 +* Fix comments being lost in config when saving a named query. (#1240) +* Fix IPython magic for ipython-sql >= 0.4.0 +* Fix pager not being used when output format is set to csv. (#1238) +* Add function literals random, generate_series, generate_subscripts +* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly +* Fix pgcli crashing with virtual `pgbouncer` database. (#1093) + +3.1.0 +===== + +Features: +--------- + +* Make the output more compact by removing the empty newline. (Thanks: `laixintao`_) +* Add support for using [pspg](https://github.com/okbob/pspg) as a pager (#1102) +* Update python version in Dockerfile +* Support setting color for null, string, number, keyword value +* Support Prompt Toolkit 2 +* Support sqlparse 0.4.x +* Update functions, datatypes literals for auto-suggestion field +* Add suggestion for schema in function auto-complete + +Bug fixes: +---------- + +* Minor typo fixes in `pgclirc`. (Thanks: `anthonydb`_) +* Fix for list index out of range when executing commands from a file (#1193). (Thanks: `Irina Truong`_) +* Move from `humanize` to `pendulum` for displaying query durations (#1015) +* More explicit error message when connecting using DSN alias and it is not found. + +3.0.0 +===== + +Features: +--------- + +* Add `__main__.py` file to execute pgcli as a package directly (#1123). +* Add support for ANSI escape sequences for coloring the prompt (#1122). +* Add support for partitioned tables (relkind "p"). +* Add support for `pg_service.conf` files +* Add config option show_bottom_toolbar. + +Bug fixes: +---------- + +* Fix warning raised for using `is not` to compare string literal +* Close open connection in completion_refresher thread + +Internal: +--------- + +* Drop Python2.7, 3.4, 3.5 support. (Thanks: `laixintao`_) +* Support Python3.8. (Thanks: `laixintao`_) +* Fix dead link in development guide. (Thanks: `BrownShibaDog`_) +* Upgrade python-prompt-toolkit to v3.0. (Thanks: `laixintao`_) + + +2.2.0: +====== + +Features: +--------- + +* Add `\\G` as a terminator to sql statements that will show the results in expanded mode. This feature is copied from mycli. (Thanks: `Amjith Ramanujam`_) +* Removed limit prompt and added automatic row limit on queries with no LIMIT clause (#1079) (Thanks: `Sebastian Janko`_) +* Function argument completions now take account of table aliases (#1048). (Thanks: `Owen Stephens`_) + +Bug fixes: +---------- + +* Error connecting to PostgreSQL 12beta1 (#1058). (Thanks: `Irina Truong`_ and `Amjith Ramanujam`_) +* Empty query caused error message (#1019) (Thanks: `Sebastian Janko`_) +* History navigation bindings in multiline queries (#1004) (Thanks: `Pedro Ferrari`_) +* Can't connect to pgbouncer database (#1093). (Thanks: `Irina Truong`_) +* Fix broken multi-line history search (#1031). (Thanks: `Owen Stephens`_) +* Fix slow typing/movement when multi-line query ends in a semicolon (#994). (Thanks: `Owen Stephens`_) +* Fix for PQconninfo not available in libpq < 9.3 (#1110). (Thanks: `Irina Truong`_) + +Internal: +--------- + +* Add optional but default squash merge request to PULL_REQUEST_TEMPLATE + +2.1.1 +===== + +Bug fixes: +---------- +* Escape switches to VI navigation mode when not canceling completion popup. (Thanks: `Nathan Verzemnieks`_) +* Allow application_name to be overridden. (Thanks: `raylu`_) +* Fix for "no attribute KeyringLocked" (#1040). (Thanks: `Irina Truong`_) +* Pgcli no longer works with password containing spaces (#1043). (Thanks: `Irina Truong`_) +* Load keyring only when keyring is enabled in the config file (#1041). (Thanks: `Zhaolong Zhu`_) +* No longer depend on sqlparse as being less than 0.3.0 with the release of sqlparse 0.3.0. (Thanks: `VVelox`_) +* Fix the broken support for pgservice . (Thanks: `Xavier Francisco`_) +* Connecting using socket is broken in current master. (#1053). (Thanks: `Irina Truong`_) +* Allow usage of newer versions of psycopg2 (Thanks: `Telmo "Trooper"`_) +* Update README in alignment with the usage of newer versions of psycopg2 (Thanks: `Alexander Zawadzki`_) + +Internal: +--------- + +* Add python 3.7 to travis build matrix. (Thanks: `Irina Truong`_) +* Apply `black` to code. (Thanks: `Irina Truong`_) + +2.1.0 +===== + +Features: +--------- + +* Keybindings for closing the autocomplete list. (Thanks: `easteregg`_) +* Reconnect automatically when server closes connection. (Thanks: `Scott Brenstuhl`_) + +Bug fixes: +---------- +* Avoid error message on the server side if hstore extension is not installed in the current database (#991). (Thanks: `Marcin Cieślak`_) +* All pexpect submodules have been moved into the pexpect package as of version 3.0. Use pexpect.TIMEOUT (Thanks: `Marcin Cieślak`_) +* Resizing pgcli terminal kills the connection to postgres in python 2.7 (Thanks: `Amjith Ramanujam`_) +* Fix crash retrieving server version with ``--single-connection``. (Thanks: `Irina Truong`_) +* Cannot quit application without reconnecting to database (#1014). (Thanks: `Irina Truong`_) +* Password authentication failed for user "postgres" when using non-default password (#1020). (Thanks: `Irina Truong`_) + +Internal: +--------- + +* (Fixup) Clean up and add behave logging. (Thanks: `Marcin Cieślak`_, `Dick Marinus`_) +* Override VISUAL environment variable for behave tests. (Thanks: `Marcin Cieślak`_) +* Remove build dir before running sdist, remove stray files from wheel distribution. (Thanks: `Dick Marinus`_) +* Fix unit tests, unhashable formatted text since new python prompttoolkit version. (Thanks: `Dick Marinus`_) + +2.0.2: +====== + +Features: +--------- + +* Allows passing the ``-u`` flag to specify a username. (Thanks: `Ignacio Campabadal`_) +* Fix for lag in v2 (#979). (Thanks: `Irina Truong`_) +* Support for multihost connection string that is convenient if you have postgres cluster. (Thanks: `Mikhail Elovskikh`_) + +Internal: +--------- + +* Added tests for special command completion. (Thanks: `Amjith Ramanujam`_) + +2.0.1: +====== + +Bug fixes: +---------- + +* Tab press on an empty line increases the indentation instead of triggering + the auto-complete pop-up. (Thanks: `Artur Balabanov`_) +* Fix for loading/saving named queries from provided config file (#938). (Thanks: `Daniel Egger`_) +* Set default port in `connect_uri` when none is given. (Thanks: `Daniel Egger`_) +* Fix for error listing databases (#951). (Thanks: `Irina Truong`_) +* Enable Ctrl-Z to suspend the app (Thanks: `Amjith Ramanujam`_). +* Fix StopIteration exception raised at runtime for Python 3.7 (Thanks: `Amjith Ramanujam`_). + +Internal: +--------- + +* Clean up and add behave logging. (Thanks: `Dick Marinus`_) +* Require prompt_toolkit>=2.0.6. (Thanks: `Dick Marinus`_) +* Improve development guide. (Thanks: `Ignacio Campabadal`_) + +2.0.0: +====== + +* Update to ``prompt-toolkit`` 2.0. (Thanks: `Jonathan Slenders`_, `Dick Marinus`_, `Irina Truong`_) + +1.11.0 +====== + +Features: +--------- + +* Respect `\pset pager on` and use pager when output is longer than terminal height (Thanks: `Max Rothman`_) + +1.10.3 +====== + +Bug fixes: +---------- + +* Adapt the query used to get functions metadata to PG11 (#919). (Thanks: `Lele Gaifax`_). +* Fix for error retrieving version in Redshift (#922). (Thanks: `Irina Truong`_) +* Fix for keyring not disabled properly (#920). (Thanks: `Irina Truong`_) + +1.10.2 +====== + +Features: +--------- + +* Make `keyring` optional (Thanks: `Dick Marinus`_) + +1.10.1 +====== + +Bug fixes: +---------- + +* Fix for missing keyring. (Thanks: `Kenny Do`_) +* Fix for "-l" Flag Throws Error (#909). (Thanks: `Irina Truong`_) + +1.10.0 +====== + +Features: +--------- +* Add quit commands to the completion menu. (Thanks: `Jason Ribeiro`_) +* Add table formats to ``\T`` completion. (Thanks: `Jason Ribeiro`_) +* Support `\\ev``, ``\ef`` (#754). (Thanks: `Catherine Devlin`_) +* Add ``application_name`` to help identify pgcli connection to database (issue #868) (Thanks: `François Pietka`_) +* Add `--user` option, duplicate of `--username`, the same cli option like `psql` (Thanks: `Alexandr Korsak`_) + +Internal changes: +----------------- + +* Mark tests requiring a running database server as dbtest (Thanks: `Dick Marinus`_) +* Add an is_special command flag to MetaQuery (Thanks: `Rishi Ramraj`_) +* Ported Destructive Warning from mycli. +* Refactor Destructive Warning behave tests (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Disable pager when using \watch (#837). (Thanks: `Jason Ribeiro`_) +* Don't offer to reconnect when we can't change a param in realtime (#807). (Thanks: `Amjith Ramanujam`_ and `Saif Hakim`_) +* Make keyring optional. (Thanks: `Dick Marinus`_) +* Fix ipython magic connection (#891). (Thanks: `Irina Truong`_) +* Fix not enough values to unpack. (Thanks: `Matthieu Guilbert`_) +* Fix unbound local error when destructive_warning is false. (Thanks: `Matthieu Guilbert`_) +* Render tab characters as 4 spaces instead of `^I`. (Thanks: `Artur Balabanov`_) + +1.9.1: +====== + +Features: +--------- + +* Change ``\h`` format string in prompt to only return the first part of the hostname, + up to the first '.' character. Add ``\H`` that returns the entire hostname (#858). + (Thanks: `Andrew Kuchling`_) +* Add Color of table by parameter. The color of table is function of syntax style + +Internal changes: +----------------- + +* Add tests, AUTHORS and changelog.rst to release. (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Fix broken pgcli --list command line option (#850). (Thanks: `Dmitry B`_) + +1.9.0 +===== + +Features: +--------- + +* manage pager by \pset pager and add enable_pager to the config file (Thanks: `Frederic Aoustin`_). +* Add support for `\T` command to change format output. (Thanks: `Frederic Aoustin`_). +* Add option list-dsn (Thanks: `Frederic Aoustin`_). + + +Internal changes: +----------------- + +* Removed support for Python 3.3. (Thanks: `Irina Truong`_) + +1.8.2 +===== + +Features: +--------- + +* Use other prompt (prompt_dsn) when connecting using --dsn parameter. (Thanks: `Marcin Sztolcman`_) +* Include username into password prompt. (Thanks: `Bojan Delić`_) + +Internal changes: +----------------- +* Use temporary dir as config location in tests. (Thanks: `Dmitry B`_) +* Fix errors in the ``tee`` test (#795 and #797). (Thanks: `Irina Truong`_) +* Increase timeout for quitting pgcli. (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- +* Do NOT quote the database names in the completion menu (Thanks: `Amjith Ramanujam`_) +* Fix error in ``unix_socket_directories`` (#805). (Thanks: `Irina Truong`_) +* Fix the --list command line option tries to connect to 'personal' DB (#816). (Thanks: `Isank`_) + +1.8.1 +===== + +Internal changes: +----------------- +* Remove shebang and git execute permission from pgcli/main.py. (Thanks: `Dick Marinus`_) +* Require cli_helpers 0.2.3 (fix #791). (Thanks: `Dick Marinus`_) + +1.8.0 +===== + +Features: +--------- + +* Add fish-style auto-suggestion from history. (Thanks: `Amjith Ramanujam`_) +* Improved formatting of arrays in output (Thanks: `Joakim Koljonen`_) +* Don't quote identifiers that are non-reserved keywords. (Thanks: `Joakim Koljonen`_) +* Remove the ``...`` in the continuation prompt and use empty space instead. (Thanks: `Amjith Ramanujam`_) +* Add \conninfo and handle more parameters with \c (issue #716) (Thanks: `François Pietka`_) + +Internal changes: +----------------- +* Preliminary work for a future change in outputting results that uses less memory. (Thanks: `Dick Marinus`_) +* Remove import workaround for OrderedDict, required for python < 2.7. (Thanks: `Andrew Speed`_) +* Use less memory when formatting results for display (Thanks: `Dick Marinus`_). +* Port auto_vertical feature test from mycli to pgcli. (Thanks: `Dick Marinus`_) +* Drop wcwidth dependency (Thanks: `Dick Marinus`_) + +Bug Fixes: +---------- + +* Fix the way we get host when using DSN (issue #765) (Thanks: `François Pietka`_) +* Add missing keyword COLUMN after DROP (issue #769) (Thanks: `François Pietka`_) +* Don't include arguments in function suggestions for backslash commands (Thanks: `Joakim Koljonen`_) +* Optionally use POSTGRES_USER, POSTGRES_HOST POSTGRES_PASSWORD from environment (Thanks: `Dick Marinus`_) + +1.7.0 +===== + +* Refresh completions after `COMMIT` or `ROLLBACK`. (Thanks: `Irina Truong`_) +* Fixed DSN aliases not being read from custom pgclirc (issue #717). (Thanks: `Irina Truong`_). +* Use dbcli's Homebrew tap for installing pgcli on macOS (issue #718) (Thanks: `Thomas Roten`_). +* Only set `LESS` environment variable if it's unset. (Thanks: `Irina Truong`_) +* Quote schema in `SET SCHEMA` statement (issue #469) (Thanks: `Irina Truong`_) +* Include arguments in function suggestions (Thanks: `Joakim Koljonen`_) +* Use CLI Helpers for pretty printing query results (Thanks: `Thomas Roten`_). +* Skip serial columns when expanding * for `INSERT INTO foo(*` (Thanks: `Joakim Koljonen`_). +* Command line option to list databases (issue #206) (Thanks: `François Pietka`_) + +1.6.0 +===== + +Features: +--------- +* Add time option for prompt (Thanks: `Gustavo Castro`_) +* Suggest objects from all schemas (not just those in search_path) (Thanks: `Joakim Koljonen`_) +* Casing for column headers (Thanks: `Joakim Koljonen`_) +* Allow configurable character to be used for multi-line query continuations. (Thanks: `Owen Stephens`_) +* Completions after ORDER BY and DISTINCT now take account of table aliases. (Thanks: `Owen Stephens`_) +* Narrow keyword candidates based on previous keyword. (Thanks: `Étienne Bersac`_) +* Opening an external editor will edit the last-run query. (Thanks: `Thomas Roten`_) +* Support query options in postgres URIs such as ?sslcert=foo.pem (Thanks: `Alexander Schmolck`_) + +Bug fixes: +---------- +* Fixed external editor bug (issue #668). (Thanks: `Irina Truong`_). +* Standardize command line option names. (Thanks: `Russell Davies`_) +* Improve handling of ``lock_not_available`` error (issue #700). (Thanks: `Jackson Popkin <https://github.com/jdpopkin>`_) +* Fixed user option precedence (issue #697). (Thanks: `Irina Truong`_). + +Internal changes: +----------------- +* Run pep8 checks in travis (Thanks: `Irina Truong`_). +* Add pager wrapper for behave tests (Thanks: `Dick Marinus`_). +* Behave quit pgcli nicely (Thanks: `Dick Marinus`_). +* Behave test source command (Thanks: `Dick Marinus`_). +* Behave fix clean up. (Thanks: `Dick Marinus`_). +* Test using behave the tee command (Thanks: `Dick Marinus`_). +* Behave remove boiler plate code (Thanks: `Dick Marinus`_). +* Behave fix pgspecial update (Thanks: `Dick Marinus`_). +* Add behave to tox (Thanks: `Dick Marinus`_). + +1.5.1 +===== + +Features: +--------- +* Better suggestions when editing functions (Thanks: `Joakim Koljonen`_) +* Command line option for ``--less-chatty``. (Thanks: `tk`_) +* Added ``MATERIALIZED VIEW`` keywords. (Thanks: `Joakim Koljonen`_). + +Bug fixes: +---------- + +* Support unicode chars in expanded mode. (Thanks: `Amjith Ramanujam`_) +* Fixed "set_session cannot be used inside a transaction" when using dsn. (Thanks: `Irina Truong`_). + +1.5.0 +===== + +Features: +--------- +* Upgraded pgspecial to 1.7.0. (See `pgspecial changelog <https://github.com/dbcli/pgspecial/blob/master/changelog.rst>`_ for list of fixes) +* Add a new config setting to allow expandable mode (Thanks: `Jonathan Boudreau <https://github.com/AGhost-7>`_) +* Make pgcli prompt width short when the prompt is too long (Thanks: `Jonathan Virga <https://github.com/jnth>`_) +* Add additional completion for ``ALTER`` keyword (Thanks: `Darik Gamble`_) +* Make the menu size configurable. (Thanks `Darik Gamble`_) + +Bug Fixes: +---------- +* Handle more connection failure cases. (Thanks: `Amjith Ramanujam`_) +* Fix the connection failure issues with latest psycopg2. (Thanks: `Amjith Ramanujam`_) + +Internal Changes: +----------------- + +* Add testing for Python 3.5 and 3.6. (Thanks: `Amjith Ramanujam`_) + +1.4.0 +===== + +Features: +--------- + +* Search table suggestions using initialisms. (Thanks: `Joakim Koljonen`_). +* Support for table-qualifying column suggestions. (Thanks: `Joakim Koljonen`_). +* Display transaction status in the toolbar. (Thanks: `Joakim Koljonen`_). +* Display vi mode in the toolbar. (Thanks: `Joakim Koljonen`_). +* Added --prompt option. (Thanks: `Irina Truong`_). + +Bug Fixes: +---------- + +* Fix scoping for columns from CTEs. (Thanks: `Joakim Koljonen`_) +* Fix crash after `with`. (Thanks: `Joakim Koljonen`_). +* Fix issue #603 (`\i` raises a TypeError). (Thanks: `Lele Gaifax`_). + + +Internal Changes: +----------------- + +* Set default data_formatting to nothing. (Thanks: `Amjith Ramanujam`_). +* Increased minimum prompt_toolkit requirement to 1.0.9. (Thanks: `Irina Truong`_). + + +1.3.1 +===== + +Bug Fixes: +---------- +* Fix a crashing bug due to sqlparse upgrade. (Thanks: `Darik Gamble`_) + + +1.3.0 +===== + +IMPORTANT: Python 2.6 is not officially supported anymore. + +Features: +--------- +* Add delimiters to displayed numbers. This can be configured via the config file. (Thanks: `Sergii`_). +* Fix broken 'SHOW ALL' in redshift. (Thanks: `Manuel Barkhau`_). +* Support configuring keyword casing preferences. (Thanks: `Darik Gamble`_). +* Add a new multi_line_mode option in config file. The values can be `psql` or `safe`. (Thanks: `Joakim Koljonen`_) + Setting ``multi_line_mode = safe`` will make sure that a query will only be executed when Alt+Enter is pressed. + +Bug Fixes: +---------- +* Fix crash bug with leading parenthesis. (Thanks: `Joakim Koljonen`_). +* Remove cumulative addition of timing data. (Thanks: `Amjith Ramanujam`_). +* Handle unrecognized keywords gracefully. (Thanks: `Darik Gamble`_) +* Use raw strings in regex specifiers. This preemptively fixes a crash in Python 3.6. (Thanks `Lele Gaifax`_) + +Internal Changes: +----------------- +* Set sqlparse version dependency to >0.2.0, <0.3.0. (Thanks: `Amjith Ramanujam`_). +* XDG_CONFIG_HOME support for config file location. (Thanks: `Fabien Meghazi`_). +* Remove Python 2.6 from travis test suite. (Thanks: `Amjith Ramanujam`_) + +1.2.0 +===== + +Features: +--------- + +* Add more specifiers to pgcli prompt. (Thanks: `Julien Rouhaud`_). + ``\p`` for port info ``\#`` for super user and ``\i`` for pid. +* Add `\watch` command to periodically execute a command. (Thanks: `Stuart Quin`_). + ``> SELECT * FROM django_migrations; \watch 1 /* Runs the command every second */`` +* Add command-line option --single-connection to prevent pgcli from using multiple connections. (Thanks: `Joakim Koljonen`_). +* Add priority to the suggestions to sort based on relevance. (Thanks: `Joakim Koljonen`_). +* Configurable null format via the config file. (Thanks: `Adrian Dries`_). +* Add support for CTE aware auto-completion. (Thanks: `Darik Gamble`_). +* Add host and user information to default pgcli prompt. (Thanks: `Lim H`_). +* Better scoping for tables in insert statements to improve suggestions. (Thanks: `Joakim Koljonen`_). + +Bug Fixes: +---------- + +* Do not install setproctitle on cygwin. (Thanks: `Janus Troelsen`_). +* Work around sqlparse crashing after AS keyword. (Thanks: `Joakim Koljonen`_). +* Fix a crashing bug with named queries. (Thanks: `Joakim Koljonen`_). +* Replace timestampz alias since AWS Redshift does not support it. (Thanks: `Tahir Butt`_). +* Prevent pgcli from hanging indefinitely when Postgres instance is not running. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Upgrade to sqlparse-0.2.0. (Thanks: `Tiziano Müller`_). +* Upgrade to pgspecial 1.6.0. (Thanks: `Stuart Quin`_). + + +1.1.0 +===== + +Features: +--------- + +* Add support for ``\db`` command. (Thanks: `Irina Truong`_) + +Bugs: +----- + +* Fix the crash at startup while parsing the postgres url with port number. (Thanks: `Eric Wald`_) +* Fix the crash with Redshift databases. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Upgrade pgspecial to 1.5.0 and above. + +1.0.0 +===== + +Features: +--------- + +* Upgrade to prompt-toolkit 1.0.0. (Thanks: `Jonathan Slenders`_). +* Add support for `\o` command to redirect query output to a file. (Thanks: `Tim Sanders`_). +* Add `\i` path completion. (Thanks: `Anthony Lai`_). +* Connect to a dsn saved in config file. (Thanks: `Rodrigo Ramírez Norambuena`_). +* Upgrade sqlparse requirement to version 0.1.19. (Thanks: `Fernando L. Canizo`_). +* Add timestamptz to DATE custom extension. (Thanks: `Fernando Mora`_). +* Ensure target dir exists when copying config. (Thanks: `David Szotten`_). +* Handle dates that fall in the B.C. range. (Thanks: `Stuart Quin`_). +* Pager is selected from config file or else from environment variable. (Thanks: `Fernando Mora`_). +* Add support for Amazon Redshift. (Thanks: `Timothy Cleaver`_). +* Add support for Postgres 8.x. (Thanks: `Timothy Cleaver`_ and `Darik Gamble`_) +* Don't error when completing parameter-less functions. (Thanks: `David Szotten`_). +* Concat and return all available notices. (Thanks: `Stuart Quin`_). +* Handle unicode in record type. (Thanks: `Amjith Ramanujam`_). +* Added humanized time display. Connect #396. (Thanks: `Irina Truong`_). +* Add EXPLAIN keyword to the completion list. (Thanks: `Amjith Ramanujam`_). +* Added sdist upload to release script. (Thanks: `Irina Truong`_). +* Sort completions based on most recently used. (Thanks: `Darik Gamble`) +* Expand '*' into column list during completion. This can be triggered by hitting `<tab>` after the '*' character in the sql while typing. (Thanks: `Joakim Koljonen`_) +* Add a limit to the warning about too many rows. This is controlled by a new config value in ~/.config/pgcli/config. (Thanks: `Anže Pečar`_) +* Improved argument list in function parameter completions. (Thanks: `Joakim Koljonen`_) +* Column suggestions after the COLUMN keyword. (Thanks: `Darik Gamble`_) +* Filter out trigger implemented functions from the suggestion list. (Thanks: `Daniel Rocco`_) +* State of the art JOIN clause completions that suggest entire conditions. (Thanks: `Joakim Koljonen`_) +* Suggest fully formed JOIN clauses based on Foreign Key relations. (Thanks: `Joakim Koljonen`_) +* Add support for `\dx` meta command to list the installed extensions. (Thanks: `Darik Gamble`_) +* Add support for `\copy` command. (Thanks: `Catherine Devlin`_) + +Bugs: +----- + +* Fix bug where config writing would leave a '~' dir. (Thanks: `James Munson`_). +* Fix auto-completion breaking for table names with caps. (Thanks: `Anthony Lai`_). +* Fix lexical ordering bug. (Thanks: `Anthony Lai`_). +* Use lexical order to break ties when fuzzy matching. (Thanks: `Daniel Rocco`_). +* Fix the bug in auto-expand mode when there are no rows to display. (Thanks: `Amjith Ramanujam`_). +* Fix broken `\i` after #395. (Thanks: `David Szotten`_). +* Fix multi-way joins in auto-completion. (Thanks: `Darik Gamble`_) +* Display null values as <null> in expanded output. (Thanks: `Amjith Ramanujam`_). +* Robust support for Postgres version less than 9.x. (Thanks: `Darik Gamble`_) + +Internal Changes: +----------------- + +* Update config file location in README. (Thanks: `Ari Summer`_). +* Explicitly add wcwidth as a dependency. (Thanks: `Amjith Ramanujam`_). +* Add tests for the format_output. (Thanks: `Amjith Ramanujam`_). +* Lots of tests for pgcompleter. (Thanks: `Darik Gamble`_). +* Update pgspecial dependency to 1.4.0. + + +0.20.1 +====== + +Bug Fixes: +---------- +* Fixed logging in Windows by switching the location of log and history file based on OS. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_). + +0.20.0 +====== + +Features: +--------- +* Perform auto-completion refresh in background. (Thanks: Amjith, `Darik Gamble`_, `Irina Truong`_). + When the auto-completion entries are refreshed, the update now happens in a + background thread. This means large databases with thousands of tables are + handled without blocking. +* Add ``CONCURRENTLY`` to keyword completion. (Thanks: `Johannes Hoff`_). +* Add support for ``\h`` command. (Thanks: `Stuart Quin`_). + This is a huge deal. Users can now get help on an SQL command by typing: + ``\h COMMAND_NAME`` in the pgcli prompt. +* Add support for ``\x auto``. (Thanks: `Stuart Quin`_). + ``\\x auto`` will automatically switch to expanded mode if the output is wider + than the display window. +* Don't hide functions from pg_catalog. (Thanks: `Darik Gamble`_). +* Suggest set-returning functions as tables. (Thanks: `Darik Gamble`_). + Functions that return table like results will now be suggested in places of tables. +* Suggest fields from functions used as tables. (Thanks: `Darik Gamble`_). +* Using ``pgspecial`` as a separate module. (Thanks: `Irina Truong`_). +* Make "enter" key behave as "tab" key when the completion menu is displayed. (Thanks: `Matheus Rosa`_). +* Support different error-handling options when running multiple queries. (Thanks: `Darik Gamble`_). + When ``on_error = STOP`` in the config file, pgcli will abort execution if one of the queries results in an error. +* Hide the password displayed in the process name in ``ps``. (Thanks: `Stuart Quin`_) + +Bug Fixes: +---------- +* Fix the ordering bug in `\\d+` display, this bug was displaying the wrong table name in the reference. (Thanks: `Tamas Boros`_). +* Only show expanded layout if valid list of headers provided. (Thanks: `Stuart Quin`_). +* Fix suggestions in compound join clauses. (Thanks: `Darik Gamble`_). +* Fix completion refresh in multiple query scenario. (Thanks: `Darik Gamble`_). +* Fix the broken timing information. +* Fix the removal of whitespaces in the output. (Thanks: `Jacek Wielemborek`_) +* Fix PyPI badge. (Thanks: `Artur Dryomov`_). + +Improvements: +------------- +* Move config file to `~/.config/pgcli/config` instead of `~/.pgclirc` (Thanks: `inkn`_). +* Move literal definitions to standalone JSON files. (Thanks: `Darik Gamble`_). + +Internal Changes: +----------------- +* Improvements to integration tests to make it more robust. (Thanks: `Irina Truong`_). + +0.19.2 +====== + +Features: +--------- + +* Autocompletion for database name in \c and \connect. (Thanks: `Darik Gamble`_). +* Improved multiline query support by correctly handling open quotes. (Thanks: `Darik Gamble`_). +* Added \pager command. +* Enhanced \i to run multiple queries and display the results for each of them +* Added keywords to suggestions after WHERE clause. +* Enabled autocompletion in named queries. (Thanks: `Irina Truong`_). +* Path to .pgclirc can be specified in command line. (Thanks: `Irina Truong`_). +* Added support for pg_service_conf file. (Thanks: `Irina Truong`_). +* Added custom styles. (Contributor: `Darik Gamble`_). + +Internal Changes: +----------------- + +* More completer test cases. (Thanks: `Darik Gamble`_). +* Updated sqlparse version from 0.1.14 to 0.1.16. (Thanks: `Darik Gamble`_). +* Upgraded to prompt_toolkit 0.46. (Thanks: `Jonathan Slenders`_). + +BugFixes: +--------- +* Fixed the completer crashing on invalid SQL. (Thanks: `Darik Gamble`_). +* Fixed unicode issues, updated tests and fixed broken tests. + +0.19.1 +====== + +BugFixes: +--------- + +* Fix an autocompletion bug that was crashing the completion engine when unknown keyword is entered. (Thanks: `Darik Gamble`_) + +0.19.0 +====== + +Features: +--------- + +* Wider completion menus can be enabled via the config file. (Thanks: `Jonathan Slenders`_) + + Open the config file (~/.pgclirc) and check if you have + ``wider_completion_menu`` option available. If not add it in and set it to + ``True``. + +* Completion menu now has metadata information such as schema, table, column, view, etc., next to the suggestions. (Thanks: `Darik Gamble`_) +* Customizable history file location via config file. (Thanks: `Çağatay Yüksel`_) + + Add this line to your config file (~/.pgclirc) to customize where to store the history file. + +:: + + history_file = /path/to/history/file + +* Add support for running queries from a file using ``\i`` special command. (Thanks: `Michael Kaminsky`_) + +BugFixes: +--------- + +* Always use utf-8 for database encoding regardless of the default encoding used by the database. +* Fix for None dereference on ``\d schemaname.`` with sequence. (Thanks: `Nathan Jhaveri`_) +* Fix a crashing bug in the autocompletion engine for some ``JOIN`` queries. +* Handle KeyboardInterrupt in pager and not quit pgcli as a consequence. + +Internal Changes: +----------------- + +* Added more behaviorial tests (Thanks: `Irina Truong`_) +* Added code coverage to the tests. (Thanks: `Irina Truong`_) +* Run behaviorial tests as part of TravisCI (Thanks: `Irina Truong`_) +* Upgraded prompt_toolkit version to 0.45 (Thanks: `Jonathan Slenders`_) +* Update the minimum required version of click to 4.1. + +0.18.0 +====== + +Features: +--------- + +* Add fuzzy matching for the table names and column names. + + Matching very long table/column names are now easier with fuzzy matching. The + fuzzy match works like the fuzzy open in SublimeText or Vim's Ctrl-P plugin. + + eg: Typing ``djmv`` will match `django_migration_views` since it is able to + match parts of the input to the full table name. + +* Change the timing information to seconds. + + The ``Command Time`` and ``Format Time`` are now displayed in seconds instead + of a unitless number displayed in scientific notation. + +* Support for named queries (favorite queries). (Thanks: `Brett Atoms`_) + + Frequently typed queries can now be saved and recalled using a name using + newly added special commands (``\n[+]``, ``\ns``, ``\nd``). + + eg: + +:: + + # Save a query + pgcli> \ns simple select * from foo + saved + + # List all saved queries + pgcli> \n+ + + # Execute a saved query + pgcli> \n simple + + # Delete a saved query + pgcli> \nd simple + +* Pasting queries into the pgcli repl is orders of magnitude faster. (Thanks: `Jonathan Slenders`_) + +* Add support for PGPASSWORD environment variable to pass the password for the + postgres database. (Thanks: `Irina Truong`_) + +* Add the ability to manually refresh autocompletions by typing ``\#`` or + ``\refresh``. This is useful if the database was updated by an external means + and you'd like to refresh the auto-completions to pick up the new change. + +Bug Fixes: +---------- + +* Fix an error when running ``\d table_name`` when running on a table with rules. (Thanks: `Ali Kargın`_) +* Fix a pgcli crash when entering non-ascii characters in Windows. (Thanks: `Darik Gamble`_, `Jonathan Slenders`_) +* Faster rendering of expanded mode output by making the horizontal separator a fixed length string. +* Completion suggestions for the ``\c`` command are not auto-escaped by default. + +Internal Changes: +----------------- + +* Complete refactor of handling the back-slash commands. +* Upgrade prompt_toolkit to 0.42. (Thanks: `Jonathan Slenders`_) +* Change the config file management to use ConfigObj.(Thanks: `Brett Atoms`_) +* Add integration tests using ``behave``. (Thanks: `Irina Truong`_) + +0.17.0 +====== + +Features: +--------- + +* Add support for auto-completing view names. (Thanks: `Darik Gamble`_) +* Add support for building RPM and DEB packages. (Thanks: dp_) +* Add subsequence matching for completion. (Thanks: `Daniel Rocco`_) + Previously completions only matched a table name if it started with the + partially typed word. Now completions will match even if the partially typed + word is in the middle of a suggestion. + eg: When you type 'mig', 'django_migrations' will be suggested. +* Completion for built-in tables and temporary tables are suggested after entering a prefix of ``pg_``. (Thanks: `Darik Gamble`_) +* Add place holder doc strings for special commands that are planned for implementation. (Thanks: `Irina Truong`_) +* Updated version of prompt_toolkit, now matching braces are highlighted. (Thanks: `Jonathan Slenders`_) +* Added support of ``\\e`` command. Queries can be edited in an external editor. (Thanks: `Irina Truong`_) + eg: When you type ``SELECT * FROM \e`` it will be opened in an external editor. +* Add special command ``\dT`` to show datatypes. (Thanks: `Darik Gamble`_) +* Add auto-completion support for datatypes in CREATE, SELECT etc. (Thanks: `Darik Gamble`_) +* Improve the auto-completion in WHERE clause with logical operators. (Thanks: `Darik Gamble`_) +* + +Bug Fixes: +---------- + +* Fix the table formatting while printing multi-byte characters (Chinese, Japanese etc). (Thanks: `蔡佳男`_) +* Fix a crash when pg_catalog was present in search path. (Thanks: `Darik Gamble`_) +* Fixed a bug that broke `\\e` when prompt_tookit was updated. (Thanks: `François Pietka`_) +* Fix the display of triggers as shown in the ``\d`` output. (Thanks: `Dimitar Roustchev`_) +* Fix broken auto-completion for INNER JOIN, LEFT JOIN etc. (Thanks: `Darik Gamble`_) +* Fix incorrect super() calls in pgbuffer, pgtoolbar and pgcompleter. No change in functionality but protects against future problems. (Thanks: `Daniel Rocco`_) +* Add missing schema completion for CREATE and DROP statements. (Thanks: `Darik Gamble`_) +* Minor fixes around cursor cleanup. + +0.16.3 +====== + +Bug Fixes: +---------- +* Add more SQL keywords for auto-complete suggestion. +* Messages raised as part of stored procedures are no longer ignored. +* Use postgres flavored syntax highlighting instead of generic ANSI SQL. + +0.16.2 +====== + +Bug Fixes: +---------- +* Fix a bug where the schema qualifier was ignored by the auto-completion. + As a result the suggestions for tables vs functions are cleaner. (Thanks: `Darik Gamble`_) +* Remove scientific notation when formatting large numbers. (Thanks: `Daniel Rocco`_) +* Add the FUNCTION keyword to auto-completion. +* Display NULL values as <null> instead of empty strings. +* Fix the completion refresh when ``\connect`` is executed. + +0.16.1 +====== + +Bug Fixes: +---------- +* Fix unicode issues with hstore. +* Fix a silent error when database is changed using \\c. + +0.16.0 +====== + +Features: +--------- +* Add \ds special command to show sequences. +* Add Vi mode for keybindings. This can be enabled by adding 'vi = True' in ~/.pgclirc. (Thanks: `Jay Zeng`_) +* Add a -v/--version flag to pgcli. +* Add completion for TEMPLATE keyword and smart-completion for + 'CREATE DATABASE blah WITH TEMPLATE <tab>'. (Thanks: `Daniel Rocco`_) +* Add custom decoders to json/jsonb to emulate the behavior of psql. This + removes the unicode prefix (eg: u'Éowyn') in the output. (Thanks: `Daniel Rocco`_) +* Add \df special command to show functions. (Thanks: `Darik Gamble`_) +* Make suggestions for special commands smarter. eg: \dn - only suggests schemas. (Thanks: `Darik Gamble`_) +* Print out the version and other meta info about pgcli at startup. + +Bug Fixes: +---------- +* Fix a rare crash caused by adding new schemas to a database. (Thanks: `Darik Gamble`_) +* Make \dt command honor the explicit schema specified in the arg. (Thanks: `Darik Gamble`_) +* Print BIGSERIAL type as Integer instead of Float. +* Show completions for special commands at the beginning of a statement. (Thanks: `Daniel Rocco`_) +* Allow special commands to work in a multi-statement case where multiple sql + statements are separated by semi-colon in the same line. + +0.15.4 +====== +* Dummy version to replace accidental PyPI entry deletion. + +0.15.3 +====== +* Override the LESS options completely instead of appending to it. + +0.15.2 +====== +* Revert back to using psycopg2 as the postgres adapter. psycopg2cffi fails for some tests in Python 3. + +0.15.0 +====== + +Features: +--------- +* Add syntax color styles to config. +* Add auto-completion for COPY statements. +* Change Postgres adapter to psycopg2cffi, to make it PyPy compatible. + Now pgcli can be run by PyPy. + +Bug Fixes: +---------- +* Treat boolean values as strings instead of ints. +* Make \di, \dv and \dt to be schema aware. (Thanks: `Darik Gamble`_) +* Make column name display unicode compatible. + +0.14.0 +====== + +Features: +--------- +* Add alias completion support to ON keyword. (Thanks: `Irina Truong`_) +* Add LIMIT keyword to completion. +* Auto-completion for Postgres schemas. (Thanks: `Darik Gamble`_) +* Better unicode handling for datatypes, dbname and roles. +* Add \timing command to time the sql commands. + This can be set via config file (~/.pgclirc) using `timing = True`. +* Add different table styles for displaying output. + This can be changed via config file (~/.pgclirc) using `table_format = fancy_grid`. +* Add confirmation before printing results that have more than 1000 rows. + +Bug Fixes: +---------- + +* Performance improvements to expanded view display (\x). +* Cast bytea files to text while displaying. (Thanks: `Daniel Rocco`_) +* Added a list of reserved words that should be auto-escaped. +* Auto-completion is now case-insensitive. +* Fix the broken completion for multiple sql statements. (Thanks: `Darik Gamble`_) + +0.13.0 +====== + +Features: +--------- + +* Add -d/--dbname option to the commandline. + eg: pgcli -d database +* Add the username as an argument after the database. + eg: pgcli dbname user + +Bug Fixes: +---------- +* Fix the crash when \c fails. +* Fix the error thrown by \d when triggers are present. +* Fix broken behavior on \?. (Thanks: `Darik Gamble`_) + +0.12.0 +====== + +Features: +--------- + +* Upgrade to prompt_toolkit version 0.26 (Thanks: https://github.com/macobo) + * Adds Ctrl-left/right to move the cursor one word left/right respectively. + * Internal API changes. +* IPython integration through `ipython-sql`_ (Thanks: `Darik Gamble`_) + * Add an ipython magic extension to embed pgcli inside ipython. + * Results from a pgcli query are sent back to ipython. +* Multiple sql statements in the same line separated by semi-colon. (Thanks: https://github.com/macobo) + +.. _`ipython-sql`: https://github.com/catherinedevlin/ipython-sql + +Bug Fixes: +---------- + +* Fix 'message' attribute not found exception in Python 3. (Thanks: https://github.com/GMLudo) +* Use the database username as the database name instead of defaulting to OS username. (Thanks: https://github.com/fpietka) +* Auto-completion for auto-escaped column/table names. +* Fix i-reverse-search to work in prompt_toolkit version 0.26. + +0.11.0 +====== + +Features: +--------- + +* Add \dn command. (Thanks: https://github.com/CyberDem0n) +* Add \x command. (Thanks: https://github.com/stuartquin) +* Auto-escape special column/table names. (Thanks: https://github.com/qwesda) +* Cancel a command using Ctrl+C. (Thanks: https://github.com/macobo) +* Faster startup by reading all columns and tables in a single query. (Thanks: https://github.com/macobo) +* Improved psql compliance with env vars and password prompting. (Thanks: `Darik Gamble`_) +* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: https://github.com/pabloab). + +Bug Fixes: +---------- +* Fix the broken behavior of \d+. (Thanks: https://github.com/macobo) +* Fix a crash during auto-completion. (Thanks: https://github.com/Erethon) +* Avoid losing pre_run_callables on error in editing. (Thanks: https://github.com/catherinedevlin) + +Improvements: +------------- +* Faster test runs on TravisCI. (Thanks: https://github.com/macobo) +* Integration tests with Postgres!! (Thanks: https://github.com/macobo) + +.. _`Amjith Ramanujam`: https://blog.amjith.com +.. _`Andrew Kuchling`: https://github.com/akuchling +.. _`Darik Gamble`: https://github.com/darikg +.. _`Daniel Rocco`: https://github.com/drocco007 +.. _`Jay Zeng`: https://github.com/jayzeng +.. _`蔡佳男`: https://github.com/xalley +.. _dp: https://github.com/ceocoder +.. _`Jonathan Slenders`: https://github.com/jonathanslenders +.. _`Dimitar Roustchev`: https://github.com/droustchev +.. _`François Pietka`: https://github.com/fpietka +.. _`Ali Kargın`: https://github.com/sancopanco +.. _`Brett Atoms`: https://github.com/brettatoms +.. _`Nathan Jhaveri`: https://github.com/nathanjhaveri +.. _`Çağatay Yüksel`: https://github.com/cagatay +.. _`Michael Kaminsky`: https://github.com/mikekaminsky +.. _`inkn`: inkn +.. _`Johannes Hoff`: Johannes Hoff +.. _`Matheus Rosa`: Matheus Rosa +.. _`Artur Dryomov`: https://github.com/ming13 +.. _`Stuart Quin`: https://github.com/stuartquin +.. _`Tamas Boros`: https://github.com/TamasNo1 +.. _`Jacek Wielemborek`: https://github.com/d33tah +.. _`Rodrigo Ramírez Norambuena`: https://github.com/roramirez +.. _`Anthony Lai`: https://github.com/ajlai +.. _`Ari Summer`: Ari Summer +.. _`David Szotten`: David Szotten +.. _`Fernando L. Canizo`: Fernando L. Canizo +.. _`Tim Sanders`: https://github.com/Gollum999 +.. _`Irina Truong`: https://github.com/j-bennet +.. _`James Munson`: https://github.com/jmunson +.. _`Fernando Mora`: https://github.com/fernandomora +.. _`Timothy Cleaver`: Timothy Cleaver +.. _`gtxx`: gtxx +.. _`Joakim Koljonen`: https://github.com/koljonen +.. _`Anže Pečar`: https://github.com/Smotko +.. _`Catherine Devlin`: https://github.com/catherinedevlin +.. _`Eric Wald`: https://github.com/eswald +.. _`avdd`: https://github.com/avdd +.. _`Adrian Dries`: Adrian Dries +.. _`Julien Rouhaud`: https://github.com/rjuju +.. _`Lim H`: Lim H +.. _`Tahir Butt`: Tahir Butt +.. _`Tiziano Müller`: https://github.com/dev-zero +.. _`Janus Troelsen`: https://github.com/ysangkok +.. _`Fabien Meghazi`: https://github.com/amigrave +.. _`Manuel Barkhau`: https://github.com/mbarkhau +.. _`Sergii`: https://github.com/foxyterkel +.. _`Lele Gaifax`: https://github.com/lelit +.. _`tk`: https://github.com/kanet77 +.. _`Owen Stephens`: https://github.com/owst +.. _`Russell Davies`: https://github.com/russelldavies +.. _`Dick Marinus`: https://github.com/meeuw +.. _`Étienne Bersac`: https://github.com/bersace +.. _`Thomas Roten`: https://github.com/tsroten +.. _`Gustavo Castro`: https://github.com/gustavo-castro +.. _`Alexander Schmolck`: https://github.com/aschmolck +.. _`Andrew Speed`: https://github.com/AndrewSpeed +.. _`Dmitry B`: https://github.com/oxitnik +.. _`Marcin Sztolcman`: https://github.com/msztolcman +.. _`Isank`: https://github.com/isank +.. _`Bojan Delić`: https://github.com/delicb +.. _`Frederic Aoustin`: https://github.com/fraoustin +.. _`Jason Ribeiro`: https://github.com/jrib +.. _`Rishi Ramraj`: https://github.com/RishiRamraj +.. _`Matthieu Guilbert`: https://github.com/gma2th +.. _`Alexandr Korsak`: https://github.com/oivoodoo +.. _`Saif Hakim`: https://github.com/saifelse +.. _`Artur Balabanov`: https://github.com/arturbalabanov +.. _`Kenny Do`: https://github.com/kennydo +.. _`Max Rothman`: https://github.com/maxrothman +.. _`Daniel Egger`: https://github.com/DanEEStar +.. _`Ignacio Campabadal`: https://github.com/igncampa +.. _`Mikhail Elovskikh`: https://github.com/wronglink +.. _`Marcin Cieślak`: https://github.com/saper +.. _`Scott Brenstuhl`: https://github.com/808sAndBR +.. _`easteregg`: https://github.com/verfriemelt-dot-org +.. _`Nathan Verzemnieks`: https://github.com/njvrzm +.. _`raylu`: https://github.com/raylu +.. _`Zhaolong Zhu`: https://github.com/zzl0 +.. _`Xavier Francisco`: https://github.com/Qu4tro +.. _`VVelox`: https://github.com/VVelox +.. _`Telmo "Trooper"`: https://github.com/telmotrooper +.. _`Alexander Zawadzki`: https://github.com/zadacka +.. _`Sebastian Janko`: https://github.com/sebojanko +.. _`Pedro Ferrari`: https://github.com/petobens +.. _`BrownShibaDog`: https://github.com/BrownShibaDog +.. _`thegeorgeous`: https://github.com/thegeorgeous +.. _`laixintao`: https://github.com/laixintao +.. _`anthonydb`: https://github.com/anthonydb +.. _`Daniel Kukula`: https://github.com/dkuku diff --git a/pgcli-completion.bash b/pgcli-completion.bash new file mode 100644 index 0000000..3549b56 --- /dev/null +++ b/pgcli-completion.bash @@ -0,0 +1,61 @@ +_pg_databases() +{ + # -w was introduced in 8.4, https://launchpad.net/bugs/164772 + # "Access privileges" in output may contain linefeeds, hence the NF > 1 + COMPREPLY=( $( compgen -W "$( psql -AtqwlF $'\t' 2>/dev/null | \ + awk 'NF > 1 { print $1 }' )" -- "$cur" ) ) +} + +_pg_users() +{ + # -w was introduced in 8.4, https://launchpad.net/bugs/164772 + COMPREPLY=( $( compgen -W "$( psql -Atqwc 'select usename from pg_user' \ + template1 2>/dev/null )" -- "$cur" ) ) + [[ ${#COMPREPLY[@]} -eq 0 ]] && COMPREPLY=( $( compgen -u -- "$cur" ) ) +} + +_pgcli() +{ + local cur prev words cword + _init_completion -s || return + + case $prev in + -h|--host) + _known_hosts_real "$cur" + return 0 + ;; + -U|--user) + _pg_users + return 0 + ;; + -d|--dbname) + _pg_databases + return 0 + ;; + --help|-v|--version|-p|--port|-R|--row-limit) + # all other arguments are noop with these + return 0 + ;; + esac + + case "$cur" in + --*) + # return list of available options + COMPREPLY=( $( compgen -W '--host --port --user --password --no-password + --single-connection --version --dbname --pgclirc --dsn + --row-limit --help' -- "$cur" ) ) + [[ $COMPREPLY == *= ]] && compopt -o nospace + return 0 + ;; + -) + # only complete long options + compopt -o nospace + COMPREPLY=( -- ) + return 0 + ;; + *) + # return list of available databases + _pg_databases + esac +} && +complete -F _pgcli pgcli diff --git a/pgcli/__init__.py b/pgcli/__init__.py new file mode 100644 index 0000000..76ad18b --- /dev/null +++ b/pgcli/__init__.py @@ -0,0 +1 @@ +__version__ = "4.0.1" diff --git a/pgcli/__main__.py b/pgcli/__main__.py new file mode 100644 index 0000000..ddf1662 --- /dev/null +++ b/pgcli/__main__.py @@ -0,0 +1,9 @@ +""" +pgcli package main entry point +""" + +from .main import cli + + +if __name__ == "__main__": + cli() diff --git a/pgcli/auth.py b/pgcli/auth.py new file mode 100644 index 0000000..2f1e552 --- /dev/null +++ b/pgcli/auth.py @@ -0,0 +1,60 @@ +import click +from textwrap import dedent + + +keyring = None # keyring will be loaded later + + +keyring_error_message = dedent( + """\ + {} + {} + To remove this message do one of the following: + - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ + - uninstall keyring: pip uninstall keyring + - disable keyring in our configuration: add keyring = False to [main]""" +) + + +def keyring_initialize(keyring_enabled, *, logger): + """Initialize keyring only if explicitly enabled""" + global keyring + + if keyring_enabled: + # Try best to load keyring (issue #1041). + import importlib + + try: + keyring = importlib.import_module("keyring") + except ( + ModuleNotFoundError + ) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + logger.warning("import keyring failed: %r.", e) + + +def keyring_get_password(key): + """Attempt to get password from keyring""" + # Find password from store + passwd = "" + try: + passwd = keyring.get_password("pgcli", key) or "" + except Exception as e: + click.secho( + keyring_error_message.format( + "Load your password from keyring returned:", str(e) + ), + err=True, + fg="red", + ) + return passwd + + +def keyring_set_password(key, passwd): + try: + keyring.set_password("pgcli", key, passwd) + except Exception as e: + click.secho( + keyring_error_message.format("Set password in keyring returned:", str(e)), + err=True, + fg="red", + ) diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py new file mode 100644 index 0000000..c887cb6 --- /dev/null +++ b/pgcli/completion_refresher.py @@ -0,0 +1,152 @@ +import threading +import os +from collections import OrderedDict + +from .pgcompleter import PGCompleter + + +class CompletionRefresher: + refreshers = OrderedDict() + + def __init__(self): + self._completer_thread = None + self._restart_refresh = threading.Event() + + def refresh(self, executor, special, callbacks, history=None, settings=None): + """ + Creates a PGCompleter object and populates it with the relevant + completion suggestions in a background thread. + + executor - PGExecute object, used to extract the credentials to connect + to the database. + special - PGSpecial object used for creating a new completion object. + settings - dict of settings for completer object + callbacks - A function or a list of functions to call after the thread + has completed the refresh. The newly created completion + object will be passed in as an argument to each callback. + """ + if executor.is_virtual_database(): + # do nothing + return [(None, None, None, "Auto-completion refresh can't be started.")] + + if self.is_refreshing(): + self._restart_refresh.set() + return [(None, None, None, "Auto-completion refresh restarted.")] + else: + self._completer_thread = threading.Thread( + target=self._bg_refresh, + args=(executor, special, callbacks, history, settings), + name="completion_refresh", + ) + self._completer_thread.daemon = True + self._completer_thread.start() + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def is_refreshing(self): + return self._completer_thread and self._completer_thread.is_alive() + + def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None): + settings = settings or {} + completer = PGCompleter( + smart_completion=True, pgspecial=special, settings=settings + ) + + if settings.get("single_connection"): + executor = pgexecute + else: + # Create a new pgexecute method to populate the completions. + executor = pgexecute.copy() + # If callbacks is a single function then push it into a list. + if callable(callbacks): + callbacks = [callbacks] + + while 1: + for refresher in self.refreshers.values(): + refresher(completer, executor) + if self._restart_refresh.is_set(): + self._restart_refresh.clear() + break + else: + # Break out of while loop if the for loop finishes natually + # without hitting the break statement. + break + + # Start over the refresh from the beginning if the for loop hit the + # break statement. + continue + + # Load history into pgcompleter so it can learn user preferences + n_recent = 100 + if history: + for recent in history.get_strings()[-n_recent:]: + completer.extend_query_history(recent, is_init=True) + + for callback in callbacks: + callback(completer) + + if not settings.get("single_connection") and executor.conn: + # close connection established with pgexecute.copy() + executor.conn.close() + + +def refresher(name, refreshers=CompletionRefresher.refreshers): + """Decorator to populate the dictionary of refreshers with the current + function. + """ + + def wrapper(wrapped): + refreshers[name] = wrapped + return wrapped + + return wrapper + + +@refresher("schemata") +def refresh_schemata(completer, executor): + completer.set_search_path(executor.search_path()) + completer.extend_schemata(executor.schemata()) + + +@refresher("tables") +def refresh_tables(completer, executor): + completer.extend_relations(executor.tables(), kind="tables") + completer.extend_columns(executor.table_columns(), kind="tables") + completer.extend_foreignkeys(executor.foreignkeys()) + + +@refresher("views") +def refresh_views(completer, executor): + completer.extend_relations(executor.views(), kind="views") + completer.extend_columns(executor.view_columns(), kind="views") + + +@refresher("types") +def refresh_types(completer, executor): + completer.extend_datatypes(executor.datatypes()) + + +@refresher("databases") +def refresh_databases(completer, executor): + completer.extend_database_names(executor.databases()) + + +@refresher("casing") +def refresh_casing(completer, executor): + casing_file = completer.casing_file + if not casing_file: + return + generate_casing_file = completer.generate_casing_file + if generate_casing_file and not os.path.isfile(casing_file): + casing_prefs = "\n".join(executor.casing()) + with open(casing_file, "w") as f: + f.write(casing_prefs) + if os.path.isfile(casing_file): + with open(casing_file) as f: + completer.extend_casing([line.strip() for line in f]) + + +@refresher("functions") +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) diff --git a/pgcli/config.py b/pgcli/config.py new file mode 100644 index 0000000..22f08dc --- /dev/null +++ b/pgcli/config.py @@ -0,0 +1,99 @@ +import errno +import shutil +import os +import platform +from os.path import expanduser, exists, dirname +import re +from typing import TextIO +from configobj import ConfigObj + + +def config_location(): + if "XDG_CONFIG_HOME" in os.environ: + return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) + elif platform.system() == "Windows": + return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\" + else: + return expanduser("~/.config/pgcli/") + + +def load_config(usr_cfg, def_cfg=None): + # avoid config merges when possible. For writing, we need an umerged config instance. + # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171 + if def_cfg: + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + else: + cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8") + cfg.filename = expanduser(usr_cfg) + return cfg + + +def ensure_dir_exists(path): + parent_dir = expanduser(dirname(path)) + os.makedirs(parent_dir, exist_ok=True) + + +def write_default_config(source, destination, overwrite=False): + destination = expanduser(destination) + if not overwrite and exists(destination): + return + + ensure_dir_exists(destination) + + shutil.copyfile(source, destination) + + +def upgrade_config(config, def_config): + cfg = load_config(config, def_config) + cfg.write() + + +def get_config_filename(pgclirc_file=None): + return pgclirc_file or "%sconfig" % config_location() + + +def get_config(pgclirc_file=None): + from pgcli import __file__ as package_root + + package_root = os.path.dirname(package_root) + + pgclirc_file = get_config_filename(pgclirc_file) + + default_config = os.path.join(package_root, "pgclirc") + write_default_config(default_config, pgclirc_file) + + return load_config(pgclirc_file, default_config) + + +def get_casing_file(config): + casing_file = config["main"]["casing_file"] + if casing_file == "default": + casing_file = config_location() + "casing" + return casing_file + + +def skip_initial_comment(f_stream: TextIO) -> int: + """ + Initial comment in ~/.pg_service.conf is not always marked with '#' + which crashes the parser. This function takes a file object and + "rewinds" it to the beginning of the first section, + from where on it can be parsed safely + + :return: number of skipped lines + """ + section_regex = r"\s*\[" + pos = f_stream.tell() + lines_skipped = 0 + while True: + line = f_stream.readline() + if line == "": + break + if re.match(section_regex, line) is not None: + f_stream.seek(pos) + break + else: + pos += len(line) + lines_skipped += 1 + return lines_skipped diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py new file mode 100644 index 0000000..ce45b4f --- /dev/null +++ b/pgcli/explain_output_formatter.py @@ -0,0 +1,19 @@ +from pgcli.pyev import Visualizer +import json + + +"""Explain response output adapter""" + + +class ExplainOutputFormatter: + def __init__(self, max_width): + self.max_width = max_width + + def format_output(self, cur, headers, **output_kwargs): + # explain query results should always contain 1 row each + [(data,)] = list(cur) + explain_list = json.loads(data) + visualizer = Visualizer(self.max_width) + for explain in explain_list: + visualizer.load(explain) + yield visualizer.get_list() diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py new file mode 100644 index 0000000..9c016f7 --- /dev/null +++ b/pgcli/key_bindings.py @@ -0,0 +1,133 @@ +import logging +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.filters import ( + completion_is_selected, + is_searching, + has_completions, + has_selection, + vi_mode, +) + +from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode + +_logger = logging.getLogger(__name__) + + +def pgcli_bindings(pgcli): + """Custom key bindings for pgcli.""" + kb = KeyBindings() + + tab_insert_text = " " * 4 + + @kb.add("f2") + def _(event): + """Enable/Disable SmartCompletion Mode.""" + _logger.debug("Detected F2 key.") + pgcli.completer.smart_completion = not pgcli.completer.smart_completion + + @kb.add("f3") + def _(event): + """Enable/Disable Multiline Mode.""" + _logger.debug("Detected F3 key.") + pgcli.multi_line = not pgcli.multi_line + + @kb.add("f4") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F4 key.") + pgcli.vi_mode = not pgcli.vi_mode + event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS + + @kb.add("f5") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F5 key.") + pgcli.explain_mode = not pgcli.explain_mode + + @kb.add("tab") + def _(event): + """Force autocompletion at cursor on non-empty lines.""" + + _logger.debug("Detected <Tab> key.") + + buff = event.app.current_buffer + doc = buff.document + + if doc.on_first_line or doc.current_line.strip(): + if buff.complete_state: + buff.complete_next() + else: + buff.start_completion(select_first=True) + else: + buff.insert_text(tab_insert_text, fire_event=False) + + @kb.add("escape", filter=has_completions) + def _(event): + """Force closing of autocompletion.""" + _logger.debug("Detected <Esc> key.") + + event.current_buffer.complete_state = None + event.app.current_buffer.complete_state = None + + @kb.add("c-space") + def _(event): + """ + Initialize autocompletion at cursor. + + If the autocompletion menu is not showing, display it with the + appropriate completions for the context. + + If the menu is showing, select the next completion. + """ + _logger.debug("Detected <C-Space> key.") + + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=False) + + @kb.add("enter", filter=completion_is_selected) + def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + + """ + _logger.debug("Detected enter key during completion selection.") + + event.current_buffer.complete_state = None + event.app.current_buffer.complete_state = None + + # When using multi_line input mode the buffer is not handled on Enter (a new line is + # inserted instead), so we force the handling if we're not in a completion or + # history search, and one of several conditions are True + @kb.add( + "enter", + filter=~(completion_is_selected | is_searching) + & buffer_should_be_handled(pgcli), + ) + def _(event): + _logger.debug("Detected enter key.") + event.current_buffer.validate_and_handle() + + @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli)) + def _(event): + """Introduces a line break regardless of multi-line mode or not.""" + _logger.debug("Detected alt-enter key.") + event.app.current_buffer.insert_text("\n") + + @kb.add("c-p", filter=~has_selection) + def _(event): + """Move up in history.""" + event.current_buffer.history_backward(count=event.arg) + + @kb.add("c-n", filter=~has_selection) + def _(event): + """Move down in history.""" + event.current_buffer.history_forward(count=event.arg) + + return kb diff --git a/pgcli/magic.py b/pgcli/magic.py new file mode 100644 index 0000000..09902a2 --- /dev/null +++ b/pgcli/magic.py @@ -0,0 +1,71 @@ +from .main import PGCli +import sql.parse +import sql.connection +import logging + +_logger = logging.getLogger(__name__) + + +def load_ipython_extension(ipython): + """This is called via the ipython command '%load_ext pgcli.magic'""" + + # first, load the sql magic if it isn't already loaded + if not ipython.find_line_magic("sql"): + ipython.run_line_magic("load_ext", "sql") + + # register our own magic + ipython.register_magic_function(pgcli_line_magic, "line", "pgcli") + + +def pgcli_line_magic(line): + _logger.debug("pgcli magic called: %r", line) + parsed = sql.parse.parse(line, {}) + # "get" was renamed to "set" in ipython-sql: + # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43 + if hasattr(sql.connection.Connection, "get"): + conn = sql.connection.Connection.get(parsed["connection"]) + else: + try: + conn = sql.connection.Connection.set(parsed["connection"]) + # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql + except TypeError: + conn = sql.connection.Connection.set(parsed["connection"], False) + + try: + # A corresponding pgcli object already exists + pgcli = conn._pgcli + _logger.debug("Reusing existing pgcli") + except AttributeError: + # I can't figure out how to get the underylying psycopg2 connection + # from the sqlalchemy connection, so just grab the url and make a + # new connection + pgcli = PGCli() + u = conn.session.engine.url + _logger.debug("New pgcli: %r", str(u)) + + pgcli.connect_uri(str(u._replace(drivername="postgres"))) + conn._pgcli = pgcli + + # For convenience, print the connection alias + print(f"Connected: {conn.name}") + + try: + pgcli.run_cli() + except SystemExit: + pass + + if not pgcli.query_history: + return + + q = pgcli.query_history[-1] + + if not q.successful: + _logger.debug("Unsuccessful query - ignoring") + return + + if q.meta_changed or q.db_changed or q.path_changed: + _logger.debug("Dangerous query detected -- ignoring") + return + + ipython = get_ipython() + return ipython.run_cell_magic("sql", line, q.query) diff --git a/pgcli/main.py b/pgcli/main.py new file mode 100644 index 0000000..f95c800 --- /dev/null +++ b/pgcli/main.py @@ -0,0 +1,1728 @@ +from configobj import ConfigObj, ParseError +from pgspecial.namedqueries import NamedQueries +from .config import skip_initial_comment + +import atexit +import os +import re +import sys +import traceback +import logging +import threading +import shutil +import functools +import pendulum +import datetime as dt +import itertools +import platform +from time import time, sleep +from typing import Optional + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.utils import strip_ansi +from .explain_output_formatter import ExplainOutputFormatter +import click + +try: + import setproctitle +except ImportError: + setproctitle = None +from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.layout.processors import ( + ConditionalProcessor, + HighlightMatchingBracketProcessor, + TabsProcessor, +) +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from pygments.lexers.sql import PostgresLexer + +from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT +import pgspecial as special + +from . import auth +from .pgcompleter import PGCompleter +from .pgtoolbar import create_toolbar_tokens_func +from .pgstyle import style_factory, style_factory_output +from .pgexecute import PGExecute +from .completion_refresher import CompletionRefresher +from .config import ( + get_casing_file, + load_config, + config_location, + ensure_dir_exists, + get_config, + get_config_filename, +) +from .key_bindings import pgcli_bindings +from .packages.formatter.sqlformatter import register_new_formatter +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.parseutils import is_destructive +from .packages.parseutils import parse_destructive_warning +from .__init__ import __version__ + +click.disable_unicode_literals_warning = True + +from urllib.parse import urlparse + +from getpass import getuser + +from psycopg import OperationalError, InterfaceError +from psycopg.conninfo import make_conninfo, conninfo_to_dict + +from collections import namedtuple + +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + + +# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output +COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") +DEFAULT_MAX_FIELD_WIDTH = 500 + +# Query tuples are used for maintaining history +MetaQuery = namedtuple( + "Query", + [ + "query", # The entire text of the command + "successful", # True If all subqueries were successful + "total_time", # Time elapsed executing the query and formatting results + "execution_time", # Time elapsed executing the query + "meta_changed", # True if any subquery executed create/alter/drop + "db_changed", # True if any subquery changed the database + "path_changed", # True if any subquery changed the search path + "mutated", # True if any subquery executed insert/update/delete + "is_special", # True if the query is a special command + ], +) +MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False) + +OutputSettings = namedtuple( + "OutputSettings", + "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width", +) +OutputSettings.__new__.__defaults__ = ( + None, + None, + None, + "<null>", + False, + None, + lambda x: x, + None, + DEFAULT_MAX_FIELD_WIDTH, +) + + +class PgCliQuitError(Exception): + pass + + +class PGCli: + default_prompt = "\\u@\\h:\\d> " + max_len_prompt = 30 + + def set_default_pager(self, config): + configured_pager = config["main"].get("pager") + os_environ_pager = os.environ.get("PAGER") + + if configured_pager: + self.logger.info( + 'Default pager found in config file: "%s"', configured_pager + ) + os.environ["PAGER"] = configured_pager + elif os_environ_pager: + self.logger.info( + 'Default pager found in PAGER environment variable: "%s"', + os_environ_pager, + ) + os.environ["PAGER"] = os_environ_pager + else: + self.logger.info( + "No default pager found in environment. Using os default pager" + ) + + # Set default set of less recommended options, if they are not already set. + # They are ignored if pager is different than less. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-SRXF" + + def __init__( + self, + force_passwd_prompt=False, + never_passwd_prompt=False, + pgexecute=None, + pgclirc_file=None, + row_limit=None, + single_connection=False, + less_chatty=None, + prompt=None, + prompt_dsn=None, + auto_vertical_output=False, + warn=None, + ssh_tunnel_url: Optional[str] = None, + ): + self.force_passwd_prompt = force_passwd_prompt + self.never_passwd_prompt = never_passwd_prompt + self.pgexecute = pgexecute + self.dsn_alias = None + self.watch_command = None + + # Load config. + c = self.config = get_config(pgclirc_file) + + # at this point, config should be written to pgclirc_file if it did not exist. Read it. + self.config_writer = load_config(get_config_filename(pgclirc_file)) + + # make sure to use self.config_writer, not self.config + NamedQueries.instance = NamedQueries.from_config(self.config_writer) + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + self.set_default_pager(c) + self.output_file = None + self.pgspecial = PGSpecial() + + self.explain_mode = False + self.multi_line = c["main"].as_bool("multi_line") + self.multiline_mode = c["main"].get("multi_line_mode", "psql") + self.vi_mode = c["main"].as_bool("vi") + self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.auto_retry_closed_connection = c["main"].as_bool( + "auto_retry_closed_connection" + ) + self.expanded_output = c["main"].as_bool("expand") + self.pgspecial.timing_enabled = c["main"].as_bool("timing") + if row_limit is not None: + self.row_limit = row_limit + else: + self.row_limit = c["main"].as_int("row_limit") + + # if not specified, set to DEFAULT_MAX_FIELD_WIDTH + # if specified but empty, set to None to disable truncation + # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 + max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH) + if max_field_width and max_field_width.lower() != "none": + max_field_width = max(3, abs(int(max_field_width))) + else: + max_field_width = None + self.max_field_width = max_field_width + + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") + self.multiline_continuation_char = c["main"]["multiline_continuation_char"] + self.table_format = c["main"]["table_format"] + self.syntax_style = c["main"]["syntax_style"] + self.cli_style = c["colors"] + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + self.destructive_warning = parse_destructive_warning( + warn or c["main"].as_list("destructive_warning") + ) + self.destructive_warning_restarts_connection = c["main"].as_bool( + "destructive_warning_restarts_connection" + ) + self.destructive_statements_require_transaction = c["main"].as_bool( + "destructive_statements_require_transaction" + ) + + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.null_string = c["main"].get("null_string", "<null>") + self.prompt_format = ( + prompt + if prompt is not None + else c["main"].get("prompt", self.default_prompt) + ) + self.prompt_dsn_format = prompt_dsn + self.on_error = c["main"]["on_error"].upper() + self.decimal_format = c["data_formats"]["decimal"] + self.float_format = c["data_formats"]["float"] + auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger) + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") + + self.pgspecial.pset_pager( + self.config["main"].as_bool("enable_pager") and "on" or "off" + ) + + self.style_output = style_factory_output(self.syntax_style, c["colors"]) + + self.now = dt.datetime.today() + + self.completion_refresher = CompletionRefresher() + + self.query_history = [] + + # Initialize completer + smart_completion = c["main"].as_bool("smart_completion") + keyword_casing = c["main"]["keyword_casing"] + single_connection = single_connection or c["main"].as_bool( + "always_use_single_connection" + ) + self.settings = { + "casing_file": get_casing_file(c), + "generate_casing_file": c["main"].as_bool("generate_casing_file"), + "generate_aliases": c["main"].as_bool("generate_aliases"), + "asterisk_column_order": c["main"]["asterisk_column_order"], + "qualify_columns": c["main"]["qualify_columns"], + "case_column_headers": c["main"].as_bool("case_column_headers"), + "search_path_filter": c["main"].as_bool("search_path_filter"), + "single_connection": single_connection, + "less_chatty": less_chatty, + "keyword_casing": keyword_casing, + "alias_map_file": c["main"]["alias_map_file"] or None, + } + + completer = PGCompleter( + smart_completion, pgspecial=self.pgspecial, settings=self.settings + ) + self.completer = completer + self._completer_lock = threading.Lock() + self.register_special_commands() + + self.prompt_app = None + + self.ssh_tunnel_config = c.get("ssh tunnels") + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + + # formatter setup + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + register_new_formatter(self.formatter) + + def quit(self): + raise PgCliQuitError + + def register_special_commands(self): + self.pgspecial.register( + self.change_db, + "\\c", + "\\c[onnect] database_name", + "Change to a new database.", + aliases=("use", "\\connect", "USE"), + ) + + refresh_callback = lambda: self.refresh_completions(persist_priorities="all") + + self.pgspecial.register( + self.quit, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + self.pgspecial.register( + self.quit, + "quit", + "quit", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=("exit",), + ) + self.pgspecial.register( + refresh_callback, + "\\#", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + refresh_callback, + "\\refresh", + "\\refresh", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." + ) + self.pgspecial.register( + self.write_to_file, + "\\o", + "\\o [filename]", + "Send all query results to file.", + ) + self.pgspecial.register( + self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + ) + self.pgspecial.register( + self.change_table_format, + "\\T", + "\\T [format]", + "Change the table format used to output results", + ) + + self.pgspecial.register( + self.echo, + "\\echo", + "\\echo [string]", + "Echo a string to stdout", + ) + + self.pgspecial.register( + self.echo, + "\\qecho", + "\\qecho [string]", + "Echo a string to the query output channel.", + ) + + def echo(self, pattern, **_): + return [(None, None, None, pattern)] + + def change_table_format(self, pattern, **_): + try: + if pattern not in TabularOutputFormatter().supported_formats: + raise ValueError() + self.table_format = pattern + yield (None, None, None, f"Changed table format to {pattern}") + except ValueError: + msg = f"Table format {pattern} not recognized. Allowed formats:" + for table_type in TabularOutputFormatter().supported_formats: + msg += f"\n\t{table_type}" + msg += "\nCurrently set to: %s" % self.table_format + yield (None, None, None, msg) + + def info_connection(self, **_): + if self.pgexecute.host.startswith("/"): + host = 'socket "%s"' % self.pgexecute.host + else: + host = 'host "%s"' % self.pgexecute.host + + yield ( + None, + None, + None, + 'You are connected to database "%s" as user ' + '"%s" on %s at port "%s".' + % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + ) + + def change_db(self, pattern, **_): + if pattern: + # Get all the parameters in pattern, handling double quotes if any. + infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern) + # Now removing quotes. + list(map(lambda s: s.strip('"'), infos)) + + infos.extend([None] * (4 - len(infos))) + db, user, host, port = infos + try: + self.pgexecute.connect( + database=db, + user=user, + host=host, + port=port, + **self.pgexecute.extra_args, + ) + except OperationalError as e: + click.secho(str(e), err=True, fg="red") + click.echo("Previous connection kept") + else: + self.pgexecute.connect() + + yield ( + None, + None, + None, + 'You are now connected to database "%s" as ' + 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), + ) + + def execute_from_file(self, pattern, **_): + if not pattern: + message = "\\i: missing required argument" + return [(None, None, None, message, "", False, True)] + try: + with open(os.path.expanduser(pattern), encoding="utf-8") as f: + query = f.read() + except OSError as e: + return [(None, None, None, str(e), "", False, True)] + + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(query, self.destructive_warning) + ): + message = "Destructive statements must be run within a transaction. Command execution stopped." + return [(None, None, None, message)] + destroy = confirm_destructive_query( + query, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] + + on_error_resume = self.on_error == "RESUME" + return self.pgexecute.run( + query, + self.pgspecial, + on_error_resume=on_error_resume, + explain_mode=self.explain_mode, + ) + + def write_to_file(self, pattern, **_): + if not pattern: + self.output_file = None + message = "File output disabled" + return [(None, None, None, message, "", True, True)] + filename = os.path.abspath(os.path.expanduser(pattern)) + if not os.path.isfile(filename): + try: + open(filename, "w").close() + except OSError as e: + self.output_file = None + message = str(e) + "\nFile output disabled" + return [(None, None, None, message, "", False, True)] + self.output_file = filename + message = 'Writing to file "%s"' % self.output_file + return [(None, None, None, message, "", True, True)] + + def initialize_logging(self): + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" + ensure_dir_exists(log_file) + log_level = self.config["main"]["log_level"] + + # Disable logging if value is NONE by switching to a no-op handler. + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + else: + handler = logging.FileHandler(os.path.expanduser(log_file)) + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "NONE": logging.CRITICAL, + } + + log_level = level_map[log_level.upper()] + + formatter = logging.Formatter( + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger("pgcli") + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + root_logger.debug("Initializing pgcli logging.") + root_logger.debug("Log file %r.", log_file) + + pgspecial_logger = logging.getLogger("pgspecial") + pgspecial_logger.addHandler(handler) + pgspecial_logger.setLevel(log_level) + + def connect_dsn(self, dsn, **kwargs): + self.connect(dsn=dsn, **kwargs) + + def connect_service(self, service, user): + service_config, file = parse_service_info(service) + if service_config is None: + click.secho( + f"service '{service}' was not found in {file}", err=True, fg="red" + ) + exit(1) + self.connect( + database=service_config.get("dbname"), + host=service_config.get("host"), + user=user or service_config.get("user"), + port=service_config.get("port"), + passwd=service_config.get("password"), + ) + + def connect_uri(self, uri): + kwargs = conninfo_to_dict(uri) + remap = {"dbname": "database", "password": "passwd"} + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) + + def connect( + self, database="", host="", user="", port="", passwd="", dsn="", **kwargs + ): + # Connect to the database. + + if not user: + user = getuser() + + if not database: + database = user + + kwargs.setdefault("application_name", "pgcli") + + # If password prompt is not forced but no password is provided, try + # getting it from environment variable. + if not self.force_passwd_prompt and not passwd: + passwd = os.environ.get("PGPASSWORD", "") + + # Prompt for a password immediately if requested via the -W flag. This + # avoids wasting time trying to connect to the database and catching a + # no-password exception. + # If we successfully parsed a password from a URI, there's no need to + # prompt for it, even with the -W flag + if self.force_passwd_prompt and not passwd: + passwd = click.prompt( + "Password for %s" % user, hide_input=True, show_default=False, type=str + ) + + key = f"{user}@{host}" + + if not passwd and auth.keyring: + passwd = auth.keyring_get_password(key) + + def should_ask_for_password(exc): + # Prompt for a password after 1st attempt to connect + # fails. Don't prompt if the -w flag is supplied + if self.never_passwd_prompt: + return False + error_msg = exc.args[0] + if "no password supplied" in error_msg: + return True + if "password authentication failed" in error_msg: + return True + return False + + if dsn: + parsed_dsn = conninfo_to_dict(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + if self.ssh_tunnel_config and not self.ssh_tunnel_url: + for db_host_regex, tunnel_url in self.ssh_tunnel_config.items(): + if re.search(db_host_regex, host): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_conninfo(dsn, host=host, port=port) + + # Attempt to connect to the database. + # Note that passwd may be empty on the first attempt. If connection + # fails because of a missing or incorrect password, but we're allowed to + # prompt for a password (no -w flag), prompt for a passwd and try again. + try: + try: + pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + except (OperationalError, InterfaceError) as e: + if should_ask_for_password(e): + passwd = click.prompt( + "Password for %s" % user, + hide_input=True, + show_default=False, + type=str, + ) + pgexecute = PGExecute( + database, user, passwd, host, port, dsn, **kwargs + ) + else: + raise e + if passwd and auth.keyring: + auth.keyring_set_password(key, passwd) + + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + + self.pgexecute = pgexecute + + def handle_editor_command(self, text): + r""" + Editor command is any query that is prefixed or suffixed + by a '\e'. The reason for a while loop is because a user + might edit a query multiple times. + For eg: + "select * from \e"<enter> to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + """ + editor_command = special.editor_command(text) + while editor_command: + if editor_command == "\\e": + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + else: # \ev or \ef + filename = None + spec = text.split()[1] + if editor_command == "\\ev": + query = self.pgexecute.view_definition(spec) + elif editor_command == "\\ef": + query = self.pgexecute.function_definition(spec) + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + editor_command = special.editor_command(text) + return text + + def execute_command(self, text, handle_closed_connection=True): + logger = self.logger + + query = MetaQuery(query=text, successful=False) + + try: + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(text, self.destructive_warning) + ): + click.secho( + "Destructive statements must be run within a transaction." + ) + raise KeyboardInterrupt + destroy = confirm_destructive_query( + text, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + click.secho("Wise choice!") + raise KeyboardInterrupt + elif destroy: + click.secho("Your call!") + + output, query = self._evaluate_command(text) + except KeyboardInterrupt: + if self.destructive_warning_restarts_connection: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query and restarted connection, sql: %r", text) + click.secho( + "cancelled query and restarted connection", err=True, fg="red" + ) + else: + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") + except NotImplementedError: + click.secho("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + if handle_closed_connection: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError): + raise + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + else: + try: + if self.output_file and not text.startswith( + ("\\o ", "\\? ", "\\echo ") + ): + try: + with open(self.output_file, "a", encoding="utf-8") as f: + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") + else: + if output: + self.echo_via_pager("\n".join(output)) + except KeyboardInterrupt: + pass + + if self.pgspecial.timing_enabled: + # Only add humanized time display if > 1 second + if query.total_time > 1: + print( + "Time: %0.03fs (%s), executed in: %0.03fs (%s)" + % ( + query.total_time, + pendulum.Duration(seconds=query.total_time).in_words(), + query.execution_time, + pendulum.Duration(seconds=query.execution_time).in_words(), + ) + ) + else: + print("Time: %0.03fs" % query.total_time) + + # Check if we need to update completions, in order of most + # to least drastic changes + if query.db_changed: + with self._completer_lock: + self.completer.reset_completions() + self.refresh_completions(persist_priorities="keywords") + elif query.meta_changed: + self.refresh_completions(persist_priorities="all") + elif query.path_changed: + logger.debug("Refreshing search path") + with self._completer_lock: + self.completer.set_search_path(self.pgexecute.search_path()) + logger.debug("Search path: %r", self.completer.search_path) + return query + + def _check_ongoing_transaction_and_allow_quitting(self): + """Return whether we can really quit, possibly by asking the + user to confirm so if there is an ongoing transaction. + """ + if not self.pgexecute.valid_transaction(): + return True + while 1: + try: + choice = click.prompt( + "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.", + default="a", + ) + except click.Abort: + # Print newline if user aborts with `^C`, otherwise + # pgcli's prompt will be printed on the same line + # (just after the confirmation prompt). + click.echo(None, err=False) + choice = "a" + choice = choice.lower() + if choice == "a": + return False # do not quit + if choice == "c": + query = self.execute_command("commit") + return query.successful # quit only if query is successful + if choice == "r": + query = self.execute_command("rollback") + return query.successful # quit only if query is successful + + def run_cli(self): + logger = self.logger + + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" + history = FileHistory(os.path.expanduser(history_file)) + self.refresh_completions(history=history, persist_priorities="none") + + self.prompt_app = self._build_cli(history) + + if not self.less_chatty: + print("Server: PostgreSQL", self.pgexecute.server_version) + print("Version:", __version__) + print("Home: http://pgcli.com") + + try: + while True: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + continue + except EOFError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + continue + + try: + self.handle_watch_command(text) + except PgCliQuitError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise + + self.now = dt.datetime.today() + + # Allow PGCompleter to learn user's preferred keywords, etc. + with self._completer_lock: + self.completer.extend_query_history(text) + + except (PgCliQuitError, EOFError): + if not self.less_chatty: + print("Goodbye!") + + def handle_watch_command(self, text): + # Initialize default metaquery in case execution fails + self.watch_command, timing = special.get_watch_command(text) + + # If we run \watch without a command, apply it to the last query run. + if self.watch_command is not None and not self.watch_command.strip(): + try: + self.watch_command = self.query_history[-1].query + except IndexError: + click.secho( + "\\watch cannot be used with an empty query", err=True, fg="red" + ) + self.watch_command = None + + # If there's a command to \watch, run it in a loop. + if self.watch_command: + while self.watch_command: + try: + query = self.execute_command(self.watch_command) + click.echo(f"Waiting for {timing} seconds before repeating") + sleep(timing) + except KeyboardInterrupt: + self.watch_command = None + + # Otherwise, execute it as a regular command. + else: + query = self.execute_command(text) + + self.query_history.append(query) + + def _build_cli(self, history): + key_bindings = pgcli_bindings(self) + + def get_message(): + if self.dsn_alias and self.prompt_dsn_format is not None: + prompt_format = self.prompt_dsn_format + else: + prompt_format = self.prompt_format + + prompt = self.get_prompt(prompt_format) + + if ( + prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, line_number, is_soft_wrap): + continuation = self.multiline_continuation_char * (width - 1) + " " + return [("class:continuation", continuation)] + + get_toolbar_tokens = create_toolbar_tokens_func(self) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + prompt_app = PromptSession( + lexer=PygmentsLexer(PostgresLexer), + reserve_space_for_menu=self.min_num_menu_lines, + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, + complete_style=complete_style, + input_processors=[ + # Highlight matching brackets while editing. + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ), + # Render \t as 4 spaces instead of "^I" + TabsProcessor(char1=" ", char2=" "), + ], + auto_suggest=AutoSuggestFromHistory(), + tempfile_suffix=".sql", + # N.b. pgcli's multi-line mode controls submit-on-Enter (which + # overrides the default behaviour of prompt_toolkit) and is + # distinct from prompt_toolkit's multiline mode here, which + # controls layout/display of the prompt/buffer + multiline=True, + history=history, + completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)), + complete_while_typing=True, + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, + search_ignore_case=True, + ) + + return prompt_app + + def _should_limit_output(self, sql, cur): + """returns True if the output should be truncated, False otherwise.""" + if self.explain_mode: + return False + if not is_select(sql): + return False + + return ( + not self._has_limit(sql) + and self.row_limit != 0 + and cur + and cur.rowcount > self.row_limit + ) + + def _has_limit(self, sql): + if not sql: + return False + return "limit " in sql.lower() + + def _limit_output(self, cur): + limit = min(self.row_limit, cur.rowcount) + new_cur = itertools.islice(cur, limit) + new_status = "SELECT " + str(limit) + click.secho("The result was limited to %s rows" % limit, fg="red") + + return new_cur, new_status + + def _evaluate_command(self, text): + """Used to run a command entered by the user during CLI operation + (Puts the E in REPL) + + returns (results, MetaQuery) + """ + logger = self.logger + logger.debug("sql: %r", text) + + # set query to formatter in order to parse table name + self.formatter.query = text + all_success = True + meta_changed = False # CREATE, ALTER, DROP, etc + mutated = False # INSERT, DELETE, etc + db_changed = False + path_changed = False + output = [] + total = 0 + execution = 0 + + # Run the query. + start = time() + on_error_resume = self.on_error == "RESUME" + res = self.pgexecute.run( + text, + self.pgspecial, + exception_formatter, + on_error_resume, + explain_mode=self.explain_mode, + ) + + is_special = None + + for title, cur, headers, status, sql, success, is_special in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + + if self._should_limit_output(sql, cur): + cur, status = self._limit_output(cur) + + if self.pgspecial.auto_expand or self.auto_expand: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + expanded = self.pgspecial.expanded_output or self.expanded_output + settings = OutputSettings( + table_format=self.table_format, + dcmlfmt=self.decimal_format, + floatfmt=self.float_format, + missingval=self.null_string, + expanded=expanded, + max_width=max_width, + case_function=( + self.completer.case + if self.settings["case_column_headers"] + else lambda x: x + ), + style_output=self.style_output, + max_field_width=self.max_field_width, + ) + execution = time() - start + formatted = format_output( + title, cur, headers, status, settings, self.explain_mode + ) + + output.extend(formatted) + total = time() - start + + # Keep track of whether any of the queries are mutating or changing + # the database + if success: + mutated = mutated or is_mutating(status) + db_changed = db_changed or has_change_db_cmd(sql) + meta_changed = meta_changed or has_meta_cmd(sql) + path_changed = path_changed or has_change_path_cmd(sql) + else: + all_success = False + + meta_query = MetaQuery( + text, + all_success, + total, + execution, + meta_changed, + db_changed, + path_changed, + mutated, + is_special, + ) + + return output, meta_query + + def _handle_server_closed_connection(self, text): + """Used during CLI execution.""" + try: + click.secho("Reconnecting...", fg="green") + self.pgexecute.connect() + click.secho("Reconnected!", fg="green") + except OperationalError as e: + click.secho("Reconnect Failed", fg="red") + click.secho(str(e), err=True, fg="red") + else: + retry = self.auto_retry_closed_connection or confirm( + "Run the query from before reconnecting?" + ) + if retry: + click.secho("Running query...", fg="green") + # Don't get stuck in a retry loop + self.execute_command(text, handle_closed_connection=False) + + def refresh_completions(self, history=None, persist_priorities="all"): + """Refresh outdated completions + + :param history: A prompt_toolkit.history.FileHistory object. Used to + load keyword and identifier preferences + + :param persist_priorities: 'all' or 'keywords' + """ + + callback = functools.partial( + self._on_completions_refreshed, persist_priorities=persist_priorities + ) + return self.completion_refresher.refresh( + self.pgexecute, + self.pgspecial, + callback, + history=history, + settings=self.settings, + ) + + def _on_completions_refreshed(self, new_completer, persist_priorities): + self._swap_completer_objects(new_completer, persist_priorities) + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def _swap_completer_objects(self, new_completer, persist_priorities): + """Swap the completer object with the newly created completer. + + persist_priorities is a string specifying how the old completer's + learned prioritizer should be transferred to the new completer. + + 'none' - The new prioritizer is left in a new/clean state + + 'all' - The new prioritizer is updated to exactly reflect + the old one + + 'keywords' - The new prioritizer is updated with old keyword + priorities, but not any other. + + """ + with self._completer_lock: + old_completer = self.completer + self.completer = new_completer + + if persist_priorities == "all": + # Just swap over the entire prioritizer + new_completer.prioritizer = old_completer.prioritizer + elif persist_priorities == "keywords": + # Swap over the entire prioritizer, but clear name priorities, + # leaving learned keyword priorities alone + new_completer.prioritizer = old_completer.prioritizer + new_completer.prioritizer.clear_names() + elif persist_priorities == "none": + # Leave the new prioritizer as is + pass + self.completer = new_completer + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + # should be before replacing \\d + string = string.replace("\\dsn_alias", self.dsn_alias or "") + string = string.replace("\\t", self.now.strftime("%x %X")) + string = string.replace("\\u", self.pgexecute.user or "(none)") + string = string.replace("\\H", self.pgexecute.host or "(none)") + string = string.replace("\\h", self.pgexecute.short_host or "(none)") + string = string.replace("\\d", self.pgexecute.dbname or "(none)") + string = string.replace( + "\\p", + str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", + ) + string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") + string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") + string = string.replace("\\n", "\n") + return string + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + def is_too_wide(self, line): + """Will this line be too wide to fit into terminal?""" + if not self.prompt_app: + return False + return ( + len(COLOR_CODE_REGEX.sub("", line)) + > self.prompt_app.output.get_size().columns + ) + + def is_too_tall(self, lines): + """Are there too many lines to fit into terminal?""" + if not self.prompt_app: + return False + return len(lines) >= (self.prompt_app.output.get_size().rows - 4) + + def echo_via_pager(self, text, color=None): + if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: + click.echo(text, color=color) + elif ( + self.pgspecial.pager_config == PAGER_LONG_OUTPUT + and self.table_format != "csv" + ): + lines = text.split("\n") + + # The last 4 lines are reserved for the pgcli menu and padding + if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines): + click.echo_via_pager(text, color=color) + else: + click.echo(text, color=color) + else: + click.echo_via_pager(text, color) + + +@click.command() +# Default host is '' so psycopg can default to either localhost or unix socket +@click.option( + "-h", + "--host", + default="", + envvar="PGHOST", + help="Host address of the postgres database.", +) +@click.option( + "-p", + "--port", + default=5432, + help="Port number at which the " "postgres instance is listening.", + envvar="PGPORT", + type=click.INT, +) +@click.option( + "-U", + "--username", + "username_opt", + help="Username to connect to the postgres database.", +) +@click.option( + "-u", "--user", "username_opt", help="Username to connect to the postgres database." +) +@click.option( + "-W", + "--password", + "prompt_passwd", + is_flag=True, + default=False, + help="Force password prompt.", +) +@click.option( + "-w", + "--no-password", + "never_prompt", + is_flag=True, + default=False, + help="Never prompt for password.", +) +@click.option( + "--single-connection", + "single_connection", + is_flag=True, + default=False, + help="Do not use a separate connection for completions.", +) +@click.option("-v", "--version", is_flag=True, help="Version of pgcli.") +@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.") +@click.option( + "--pgclirc", + default=config_location() + "config", + envvar="PGCLIRC", + help="Location of pgclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "-D", + "--dsn", + default="", + envvar="DSN", + help="Use DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--list-dsn", + "list_dsn", + is_flag=True, + help="list of DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--row-limit", + default=None, + envvar="PGROWLIMIT", + type=click.INT, + help="Set threshold for row limit prompt. Use 0 to disable prompt.", +) +@click.option( + "--less-chatty", + "less_chatty", + is_flag=True, + default=False, + help="Skip intro on startup and goodbye on exit.", +) +@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') +@click.option( + "--prompt-dsn", + help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").', +) +@click.option( + "-l", + "--list", + "list_databases", + is_flag=True, + help="list available databases, then exit.", +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "--warn", + default=None, + help="Warn before running a destructive query.", +) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) +@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) +@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) +def cli( + dbname, + username_opt, + host, + port, + prompt_passwd, + never_prompt, + single_connection, + dbname_opt, + username, + version, + pgclirc, + dsn, + row_limit, + less_chatty, + prompt, + prompt_dsn, + list_databases, + auto_vertical_output, + list_dsn, + warn, + ssh_tunnel: str, +): + if version: + print("Version:", __version__) + sys.exit(0) + + config_dir = os.path.dirname(config_location()) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + + # Migrate the config file from old location. + config_full_path = config_location() + "config" + if os.path.exists(os.path.expanduser("~/.pgclirc")): + if not os.path.exists(config_full_path): + shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path) + print("Config file (~/.pgclirc) moved to new location", config_full_path) + else: + print("Config file is now located at", config_full_path) + print( + "Please move the existing config file ~/.pgclirc to", + config_full_path, + ) + if list_dsn: + try: + cfg = load_config(pgclirc, config_full_path) + for alias in cfg["alias_dsn"]: + click.secho(alias + " : " + cfg["alias_dsn"][alias]) + sys.exit(0) + except Exception as err: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + exit(1) + + pgcli = PGCli( + prompt_passwd, + never_prompt, + pgclirc_file=pgclirc, + row_limit=row_limit, + single_connection=single_connection, + less_chatty=less_chatty, + prompt=prompt, + prompt_dsn=prompt_dsn, + auto_vertical_output=auto_vertical_output, + warn=warn, + ssh_tunnel_url=ssh_tunnel, + ) + + # Choose which ever one has a valid value. + if dbname_opt and dbname: + # work as psql: when database is given as option and argument use the argument as user + username = dbname + database = dbname_opt or dbname or "" + user = username_opt or username + service = None + if database.startswith("service="): + service = database[8:] + elif os.getenv("PGSERVICE") is not None: + service = os.getenv("PGSERVICE") + # because option --list or -l are not supposed to have a db name + if list_databases: + database = "postgres" + + if dsn != "": + try: + cfg = load_config(pgclirc, config_full_path) + dsn_config = cfg["alias_dsn"][dsn] + except KeyError: + click.secho( + f"Could not find a DSN with alias {dsn}. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + except Exception: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + pgcli.connect_uri(dsn_config) + pgcli.dsn_alias = dsn + elif "://" in database: + pgcli.connect_uri(database) + elif "=" in database and service is None: + pgcli.connect_dsn(database, user=user) + elif service is not None: + pgcli.connect_service(service, user) + else: + pgcli.connect(database, host, user, port) + + if list_databases: + cur, headers, status = pgcli.pgexecute.full_databases() + + title = "List of databases" + settings = OutputSettings(table_format="ascii", missingval="<null>") + formatted = format_output(title, cur, headers, status, settings) + pgcli.echo_via_pager("\n".join(formatted)) + + sys.exit(0) + + pgcli.logger.debug( + "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + database, + user, + host, + port, + ) + + if setproctitle: + obfuscate_process_password() + + pgcli.run_cli() + + +def obfuscate_process_password(): + process_title = setproctitle.getproctitle() + if "://" in process_title: + process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) + elif "=" in process_title: + process_title = re.sub( + r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title + ) + + setproctitle.setproctitle(process_title) + + +def has_meta_cmd(query): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop, commit or rollback.""" + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"): + return True + except Exception: + return False + + return False + + +def has_change_db_cmd(query): + """Determines if the statement is a database switch such as 'use' or '\\c'""" + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\c", "\\connect"): + return True + except Exception: + return False + + return False + + +def has_change_path_cmd(sql): + """Determines if the search_path should be refreshed by checking if the + sql has 'set search_path'.""" + return "set search_path" in sql.lower() + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = {"insert", "update", "delete"} + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +def exception_formatter(e): + return click.style(str(e), fg="red") + + +def format_output(title, cur, headers, status, settings, explain_mode=False): + output = [] + expanded = settings.expanded or settings.table_format == "vertical" + table_format = "vertical" if settings.expanded else settings.table_format + max_width = settings.max_width + case_function = settings.case_function + if explain_mode: + formatter = ExplainOutputFormatter(max_width or 100) + else: + formatter = TabularOutputFormatter(format_name=table_format) + + def format_array(val): + if val is None: + return settings.missingval + if not isinstance(val, list): + return val + return "{" + ",".join(str(format_array(e)) for e in val) + "}" + + def format_arrays(data, headers, **_): + data = list(data) + for row in data: + row[:] = [ + format_array(val) if isinstance(val, list) else val for val in row + ] + + return data, headers + + def format_status(cur, status): + # redshift does not return rowcount as part of status. + # See https://github.com/dbcli/pgcli/issues/1320 + if cur and hasattr(cur, "rowcount") and cur.rowcount is not None: + if status and not status.endswith(str(cur.rowcount)): + status += " %s" % cur.rowcount + return status + + output_kwargs = { + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": settings.missingval, + "integer_format": settings.dcmlfmt, + "float_format": settings.floatfmt, + "preprocessors": (format_numbers, format_arrays), + "disable_numparse": True, + "preserve_whitespace": True, + "style": settings.style_output, + "max_field_width": settings.max_field_width, + } + if not settings.floatfmt: + output_kwargs["preprocessors"] = (align_decimals,) + + if table_format == "csv": + # The default CSV dialect is "excel" which is not handling newline values correctly + # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n' + # as the line terminator + # https://github.com/dbcli/pgcli/issues/1102 + dialect = "excel" if platform.system() == "Windows" else "unix" + output_kwargs["dialect"] = dialect + + if title: # Only print the title if it's not None. + output.append(title) + + if cur: + headers = [case_function(x) for x in headers] + if max_width is not None: + cur = list(cur) + column_types = None + if hasattr(cur, "description"): + column_types = [] + for d in cur.description: + col_type = cur.adapters.types.get(d.type_code) + type_name = col_type.name if col_type else None + if type_name in ("numeric", "float4", "float8"): + column_types.append(float) + if type_name in ("int2", "int4", "int8"): + column_types.append(int) + else: + column_types.append(str) + + formatted = formatter.format_output(cur, headers, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + first_line = next(formatted) + formatted = itertools.chain([first_line], formatted) + if ( + not explain_mode + and not expanded + and max_width + and len(strip_ansi(first_line)) > max_width + and headers + ): + formatted = formatter.format_output( + cur, + headers, + format_name="vertical", + column_types=column_types, + **output_kwargs, + ) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + + output = itertools.chain(output, formatted) + + # Only print the status if it's not None + if status: + output = itertools.chain(output, [format_status(cur, status)]) + + return output + + +def parse_service_info(service): + service = service or os.getenv("PGSERVICE") + service_file = os.getenv("PGSERVICEFILE") + if not service_file: + # try ~/.pg_service.conf (if that exists) + if platform.system() == "Windows": + service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf" + elif os.getenv("PGSYSCONFDIR"): + service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") + else: + service_file = os.path.expanduser("~/.pg_service.conf") + if not service or not os.path.exists(service_file): + # nothing to do + return None, service_file + with open(service_file, newline="") as f: + skipped_lines = skip_initial_comment(f) + try: + service_file_config = ConfigObj(f) + except ParseError as err: + err.line_number += skipped_lines + raise err + if service not in service_file_config: + return None, service_file + service_conf = service_file_config.get(service) + return service_conf, service_file + + +if __name__ == "__main__": + cli() diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pgcli/packages/__init__.py diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/pgcli/packages/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py new file mode 100644 index 0000000..5224eff --- /dev/null +++ b/pgcli/packages/formatter/sqlformatter.py @@ -0,0 +1,74 @@ +# coding=utf-8 + +from pgcli.packages.parseutils.tables import extract_tables + + +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) + +preprocessors = () + + +def escape_for_sql_statement(value): + if value is None: + return "NULL" + + if isinstance(value, bytes): + return f"X'{value.hex()}'" + + return "'{}'".format(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "DUAL" + if table_format == "sql-insert": + h = '", "'.join(headers) + yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith("sql-update"): + s = table_format.split("-") + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield 'UPDATE "{}" SET'.format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield '{}"{}" = {}'.format( + prefix, headers[i], escape_for_sql_statement(v) + ) + if prefix == " ": + prefix = ", " + f = '"{}" = {}' + where = ( + f.format(headers[i], escape_for_sql_statement(d[i])) + for i in range(keys) + ) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {"table_format": sql_format} + ) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py new file mode 100644 index 0000000..023e13b --- /dev/null +++ b/pgcli/packages/parseutils/__init__.py @@ -0,0 +1,58 @@ +import sqlparse + + +BASE_KEYWORDS = [ + "drop", + "shutdown", + "delete", + "truncate", + "alter", + "unconditional_update", +] +ALL_KEYWORDS = BASE_KEYWORDS + ["update"] + + +def query_starts_with(formatted_sql, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + + +def is_destructive(queries, keywords): + """Returns if any of the queries in *queries* is destructive.""" + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if "unconditional_update" in keywords and query_is_unconditional_update( + formatted_sql + ): + return True + if query_starts_with(formatted_sql, keywords): + return True + return False + + +def parse_destructive_warning(warning_level): + """Converts a deprecated destructive warning option to a list of command keywords.""" + if not warning_level: + return [] + + if not isinstance(warning_level, list): + if "," in warning_level: + return warning_level.split(",") + warning_level = [warning_level] + + return { + "true": ALL_KEYWORDS, + "false": [], + "all": ALL_KEYWORDS, + "moderate": BASE_KEYWORDS, + "off": [], + "": [], + }.get(warning_level[0], warning_level) diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py new file mode 100644 index 0000000..e1f9088 --- /dev/null +++ b/pgcli/packages/parseutils/ctes.py @@ -0,0 +1,141 @@ +from sqlparse import parse +from sqlparse.tokens import Keyword, CTE, DML +from sqlparse.sql import Identifier, IdentifierList, Parenthesis +from collections import namedtuple +from .meta import TableMetadata, ColumnMetadata + + +# TableExpression is a namedtuple representing a CTE, used internally +# name: cte alias assigned in the query +# columns: list of column names +# start: index into the original string of the left parens starting the CTE +# stop: index into the original string of the right parens ending the CTE +TableExpression = namedtuple("TableExpression", "name columns start stop") + + +def isolate_query_ctes(full_text, text_before_cursor): + """Simplify a query by converting CTEs into table metadata objects""" + + if not full_text or not full_text.strip(): + return full_text, text_before_cursor, tuple() + + ctes, remainder = extract_ctes(full_text) + if not ctes: + return full_text, text_before_cursor, () + + current_position = len(text_before_cursor) + meta = [] + + for cte in ctes: + if cte.start < current_position < cte.stop: + # Currently editing a cte - treat its body as the current full_text + text_before_cursor = full_text[cte.start : current_position] + full_text = full_text[cte.start : cte.stop] + return full_text, text_before_cursor, meta + + # Append this cte to the list of available table metadata + cols = (ColumnMetadata(name, None, ()) for name in cte.columns) + meta.append(TableMetadata(cte.name, cols)) + + # Editing past the last cte (ie the main body of the query) + full_text = full_text[ctes[-1].stop :] + text_before_cursor = text_before_cursor[ctes[-1].stop : current_position] + + return full_text, text_before_cursor, tuple(meta) + + +def extract_ctes(sql): + """Extract constant table expresseions from a query + + Returns tuple (ctes, remainder_sql) + + ctes is a list of TableExpression namedtuples + remainder_sql is the text from the original query after the CTEs have + been stripped. + """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + if not tok: + return ([], "") + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = "".join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ("insert", "update", "delete"): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, "returning")) + elif not tok_val == "select": + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py new file mode 100644 index 0000000..333cab5 --- /dev/null +++ b/pgcli/packages/parseutils/meta.py @@ -0,0 +1,170 @@ +from collections import namedtuple + +_ColumnMetadata = namedtuple( + "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] +) + + +def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): + return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default) + + +ForeignKey = namedtuple( + "ForeignKey", + [ + "parentschema", + "parenttable", + "parentcolumn", + "childschema", + "childtable", + "childcolumn", + ], +) +TableMetadata = namedtuple("TableMetadata", "name columns") + + +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = "" + in_quote = None + for char in defaults_string: + if current == "" and char == " ": + # Skip space after comma separating default expressions + continue + if char == '"' or char == "'": + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == "," and not in_quote: + # End of expression + yield current + current = "" + continue + current += char + yield current + + +class FunctionMetadata: + def __init__( + self, + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + + self.arg_modes = tuple(arg_modes) if arg_modes else None + self.arg_names = tuple(arg_names) if arg_names else None + + # Be flexible in not requiring arg_types -- use None as a placeholder + # for each arg. (Used for compatibility with old versions of postgresql + # where such info is hard to get. + if arg_types: + self.arg_types = tuple(arg_types) + elif arg_modes: + self.arg_types = tuple([None] * len(arg_modes)) + elif arg_names: + self.arg_types = tuple([None] * len(arg_names)) + else: + self.arg_types = None + + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + self.is_extension = bool(is_extension) + self.is_public = self.schema_name and self.schema_name == "public" + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + def _signature(self): + return ( + self.schema_name, + self.func_name, + self.arg_names, + self.arg_types, + self.arg_modes, + self.return_type, + self.is_aggregate, + self.is_window, + self.is_set_returning, + self.is_extension, + self.arg_defaults, + ) + + def __hash__(self): + return hash(self._signature()) + + def __repr__(self): + return ( + "%s(schema_name=%r, func_name=%r, arg_names=%r, " + "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, " + "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)" + ) % ((self.__class__.__name__,) + self._signature()) + + def has_variadic(self): + return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ["i"] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ("i", "b", "v") # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] + if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + + def fields(self): + """Returns a list of output-field ColumnMetadata namedtuples""" + + if self.return_type.lower() == "void": + return [] + elif not self.arg_modes: + # For functions without output parameters, the function name + # is used as the name of the output column. + # E.g. 'SELECT unnest FROM unnest(...);' + return [ColumnMetadata(self.func_name, self.return_type, [])] + + return [ + ColumnMetadata(name, typ, []) + for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes) + if mode in ("o", "b", "t") + ] # OUT, INOUT, TABLE diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py new file mode 100644 index 0000000..9098115 --- /dev/null +++ b/pgcli/packages/parseutils/tables.py @@ -0,0 +1,165 @@ +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +TableReference = namedtuple( + "TableReference", ["schema", "name", "alias", "is_function"] +) +TableReference.ref = property( + lambda self: self.alias + or ( + self.name + if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"' + ) +) + + +# This code is borrowed from sqlparse example script. +# <url> +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): + return True + return False + + +def _identifier_is_function(identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + yield from extract_from_part(item, stop_at_punctuation) + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # `return`. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream, allow_functions=True): + """yields tuples of TableReference namedtuples""" + + # We need to do some massaging of the names because postgres is case- + # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is) + def parse_identifier(item): + name = item.get_real_name() + schema_name = item.get_parent_name() + alias = item.get_alias() + if not name: + schema_name = None + name = item.get_name() + alias = alias or name + schema_quoted = schema_name and item.value[0] == '"' + if schema_name and not schema_quoted: + schema_name = schema_name.lower() + quote_count = item.value.count('"') + name_quoted = quote_count > 2 or (quote_count and not schema_quoted) + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): + if not alias: + alias = name + name = name.lower() + return schema_name, name, alias + + try: + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = allow_functions and _identifier_is_function( + identifier + ) + except AttributeError: + continue + if real_name: + yield TableReference( + schema_name, real_name, identifier.get_alias(), is_function + ) + elif isinstance(item, Identifier): + schema_name, real_name, alias = parse_identifier(item) + is_function = allow_functions and _identifier_is_function(item) + + yield TableReference(schema_name, real_name, alias, is_function) + elif isinstance(item, Function): + schema_name, real_name, alias = parse_identifier(item) + yield TableReference(None, real_name, alias, allow_functions) + except StopIteration: + return + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statement. + + Returns a list of TableReference namedtuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return () + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + + # Kludge: sqlparse mistakenly identifies insert statements as + # function calls due to the parenthesized column list, e.g. interprets + # "insert into foo (bar, baz)" as a function call to foo with arguments + # (bar, baz). So don't allow any identifiers in insert statements + # to have is_function=True + identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt) + # In the case 'sche.<cursor>', we get an empty TableReference; remove that + return tuple(i for i in identifiers if i.name) diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py new file mode 100644 index 0000000..034c96e --- /dev/null +++ b/pgcli/packages/parseutils/utils.py @@ -0,0 +1,140 @@ +import re +import sqlparse +from sqlparse.sql import Identifier +from sqlparse.tokens import Token, Error + +cleanup_regex = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile(r"([^\s]+)$"), +} + + +def last_word(text, include="alphanum_underscore"): + r""" + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +def find_prev_keyword(sql, n_skip=0): + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + flattened = flattened[: len(flattened) - n_skip] + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index throws an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r"^\$[^$]*\$$") + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + # Look for unmatched single quotes, or unmatched dollar sign quotes + return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten()) + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_by(m=(Error, '"'))[1]: + # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pgcli/packages/pgliterals/__init__.py diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py new file mode 100644 index 0000000..5c39296 --- /dev/null +++ b/pgcli/packages/pgliterals/main.py @@ -0,0 +1,15 @@ +import os +import json + +root = os.path.dirname(__file__) +literal_file = os.path.join(root, "pgliterals.json") + +with open(literal_file) as f: + literals = json.load(f) + + +def get_literals(literal_type, type_=tuple): + # Where `literal_type` is one of 'keywords', 'functions', 'datatypes', + # returns a tuple of literal values of that type. + + return type_(literals[literal_type]) diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json new file mode 100644 index 0000000..df00817 --- /dev/null +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -0,0 +1,630 @@ +{ + "keywords": { + "ACCESS": [], + "ADD": [], + "ALL": [], + "ALTER": [ + "AGGREGATE", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DEFAULT", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "LARGE OBJECT", + "MATERIALIZED VIEW", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "SYSTEM", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "USER", + "VIEW" + ], + "AND": [], + "ANY": [], + "AS": [], + "ASC": [], + "AUDIT": [], + "BEGIN": [], + "BETWEEN": [], + "BY": [], + "CASE": [], + "CHAR": [], + "CHECK": [], + "CLUSTER": [], + "COLUMN": [], + "COMMENT": [], + "COMMIT": [], + "COMPRESS": [], + "CONCURRENTLY": [], + "CONNECT": [], + "COPY": [], + "CREATE": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN EXTENSION", + "FUNCTION", + "GLOBAL", + "GROUP", + "IF NOT EXISTS", + "INDEX", + "LANGUAGE", + "LOCAL", + "MATERIALIZED VIEW", + "OPERATOR", + "OR REPLACE", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEMPORARY", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "UNIQUE", + "UNLOGGED", + "USER", + "USER MAPPING", + "VIEW" + ], + "CURRENT": [], + "DATABASE": [], + "DATE": [], + "DECIMAL": [], + "DEFAULT": [], + "DELETE FROM": [], + "DELIMITER": [], + "DESC": [], + "DESCRIBE": [], + "DISTINCT": [], + "DROP": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN TABLE", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "MATERIALIZED VIEW", + "OPERATOR", + "OWNED", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRANSFORM", + "TRIGGER", + "TYPE", + "USER", + "USER MAPPING", + "VIEW" + ], + "EXPLAIN": [], + "ELSE": [], + "ENCODING": [], + "ESCAPE": [], + "EXCLUSIVE": [], + "EXISTS": [], + "EXTENSION": [], + "FILE": [], + "FLOAT": [], + "FOR": [], + "FORMAT": [], + "FORCE_QUOTE": [], + "FORCE_NOT_NULL": [], + "FREEZE": [], + "FROM": [], + "FULL": [], + "FUNCTION": [], + "GRANT": [], + "GROUP BY": [], + "HAVING": [], + "HEADER": [], + "IDENTIFIED": [], + "IMMEDIATE": [], + "IN": [], + "INCREMENT": [], + "INDEX": [], + "INITIAL": [], + "INSERT INTO": [], + "INTEGER": [], + "INTERSECT": [], + "INTERVAL": [], + "INTO": [], + "IS": [], + "JOIN": [], + "LANGUAGE": [], + "LEFT": [], + "LEVEL": [], + "LIKE": [], + "LIMIT": [], + "LOCK": [], + "LONG": [], + "MATERIALIZED VIEW": [], + "MAXEXTENTS": [], + "MINUS": [], + "MLSLABEL": [], + "MODE": [], + "MODIFY": [], + "NOT": [], + "NOAUDIT": [], + "NOTICE": [], + "NOCOMPRESS": [], + "NOWAIT": [], + "NULL": [], + "NUMBER": [], + "OIDS": [], + "OF": [], + "OFFLINE": [], + "ON": [], + "ONLINE": [], + "OPTION": [], + "OR": [], + "ORDER BY": [], + "OUTER": [], + "OWNER": [], + "PCTFREE": [], + "PRIMARY": [], + "PRIOR": [], + "PRIVILEGES": [], + "QUOTE": [], + "RAISE": [], + "RENAME": [], + "REPLACE": [], + "RESET": ["ALL"], + "RAW": [], + "REFRESH MATERIALIZED VIEW": [], + "RESOURCE": [], + "RETURNS": [], + "REVOKE": [], + "RIGHT": [], + "ROLLBACK": [], + "ROW": [], + "ROWID": [], + "ROWNUM": [], + "ROWS": [], + "SELECT": [], + "SESSION": [], + "SET": [], + "SHARE": [], + "SHOW": [], + "SIZE": [], + "SMALLINT": [], + "START": [], + "SUCCESSFUL": [], + "SYNONYM": [], + "SYSDATE": [], + "TABLE": [], + "TEMPLATE": [], + "THEN": [], + "TO": [], + "TRIGGER": [], + "TRUNCATE": [], + "UID": [], + "UNION": [], + "UNIQUE": [], + "UPDATE": [], + "USE": [], + "USER": [], + "USING": [], + "VALIDATE": [], + "VALUES": [], + "VARCHAR": [], + "VARCHAR2": [], + "VIEW": [], + "WHEN": [], + "WHENEVER": [], + "WHERE": [], + "WITH": [] + }, + "functions": [ + "ABBREV", + "ABS", + "AGE", + "AREA", + "ARRAY_AGG", + "ARRAY_APPEND", + "ARRAY_CAT", + "ARRAY_DIMS", + "ARRAY_FILL", + "ARRAY_LENGTH", + "ARRAY_LOWER", + "ARRAY_NDIMS", + "ARRAY_POSITION", + "ARRAY_POSITIONS", + "ARRAY_PREPEND", + "ARRAY_REMOVE", + "ARRAY_REPLACE", + "ARRAY_TO_STRING", + "ARRAY_UPPER", + "ASCII", + "AVG", + "BIT_AND", + "BIT_LENGTH", + "BIT_OR", + "BOOL_AND", + "BOOL_OR", + "BOUND_BOX", + "BOX", + "BROADCAST", + "BTRIM", + "CARDINALITY", + "CBRT", + "CEIL", + "CEILING", + "CENTER", + "CHAR_LENGTH", + "CHR", + "CIRCLE", + "CLOCK_TIMESTAMP", + "CONCAT", + "CONCAT_WS", + "CONVERT", + "CONVERT_FROM", + "CONVERT_TO", + "COUNT", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATE_PART", + "DATE_TRUNC", + "DECODE", + "DEGREES", + "DENSE_RANK", + "DIAMETER", + "DIV", + "ENCODE", + "ENUM_FIRST", + "ENUM_LAST", + "ENUM_RANGE", + "EVERY", + "EXP", + "EXTRACT", + "FAMILY", + "FIRST_VALUE", + "FLOOR", + "FORMAT", + "GET_BIT", + "GET_BYTE", + "HEIGHT", + "HOST", + "HOSTMASK", + "INET_MERGE", + "INET_SAME_FAMILY", + "INITCAP", + "ISCLOSED", + "ISFINITE", + "ISOPEN", + "JUSTIFY_DAYS", + "JUSTIFY_HOURS", + "JUSTIFY_INTERVAL", + "LAG", + "LAST_VALUE", + "LEAD", + "LEFT", + "LENGTH", + "LINE", + "LN", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOG", + "LOG10", + "LOWER", + "LPAD", + "LSEG", + "LTRIM", + "MAKE_DATE", + "MAKE_INTERVAL", + "MAKE_TIME", + "MAKE_TIMESTAMP", + "MAKE_TIMESTAMPTZ", + "MASKLEN", + "MAX", + "MD5", + "MIN", + "MOD", + "NETMASK", + "NETWORK", + "NOW", + "NPOINTS", + "NTH_VALUE", + "NTILE", + "NUM_NONNULLS", + "NUM_NULLS", + "OCTET_LENGTH", + "OVERLAY", + "PARSE_IDENT", + "PATH", + "PCLOSE", + "PERCENT_RANK", + "PG_CLIENT_ENCODING", + "PI", + "POINT", + "POLYGON", + "POPEN", + "POSITION", + "POWER", + "QUOTE_IDENT", + "QUOTE_LITERAL", + "QUOTE_NULLABLE", + "RADIANS", + "RADIUS", + "RANDOM", + "RANK", + "REGEXP_MATCH", + "REGEXP_MATCHES", + "REGEXP_REPLACE", + "REGEXP_SPLIT_TO_ARRAY", + "REGEXP_SPLIT_TO_TABLE", + "REPEAT", + "REPLACE", + "REVERSE", + "RIGHT", + "ROUND", + "ROW_NUMBER", + "RPAD", + "RTRIM", + "SCALE", + "SET_BIT", + "SET_BYTE", + "SET_MASKLEN", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "SIGN", + "SPLIT_PART", + "SQRT", + "STARTS_WITH", + "STATEMENT_TIMESTAMP", + "STRING_TO_ARRAY", + "STRPOS", + "SUBSTR", + "SUBSTRING", + "SUM", + "TEXT", + "TIMEOFDAY", + "TO_ASCII", + "TO_CHAR", + "TO_DATE", + "TO_HEX", + "TO_NUMBER", + "TO_TIMESTAMP", + "TRANSACTION_TIMESTAMP", + "TRANSLATE", + "TRIM", + "TRUNC", + "UNNEST", + "UPPER", + "WIDTH", + "WIDTH_BUCKET", + "XMLAGG" + ], + "datatypes": [ + "ANY", + "ANYARRAY", + "ANYELEMENT", + "ANYENUM", + "ANYNONARRAY", + "ANYRANGE", + "BIGINT", + "BIGSERIAL", + "BIT", + "BIT VARYING", + "BOOL", + "BOOLEAN", + "BOX", + "BYTEA", + "CHAR", + "CHARACTER", + "CHARACTER VARYING", + "CIDR", + "CIRCLE", + "CSTRING", + "DATE", + "DECIMAL", + "DOUBLE PRECISION", + "EVENT_TRIGGER", + "FDW_HANDLER", + "FLOAT4", + "FLOAT8", + "INET", + "INT", + "INT2", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERVAL", + "JSON", + "JSONB", + "LANGUAGE_HANDLER", + "LINE", + "LSEG", + "MACADDR", + "MACADDR8", + "MONEY", + "NUMERIC", + "OID", + "OPAQUE", + "PATH", + "PG_LSN", + "POINT", + "POLYGON", + "REAL", + "RECORD", + "REGCLASS", + "REGCONFIG", + "REGDICTIONARY", + "REGNAMESPACE", + "REGOPER", + "REGOPERATOR", + "REGPROC", + "REGPROCEDURE", + "REGROLE", + "REGTYPE", + "SERIAL", + "SERIAL2", + "SERIAL4", + "SERIAL8", + "SMALLINT", + "SMALLSERIAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TRIGGER", + "TSQUERY", + "TSVECTOR", + "TXID_SNAPSHOT", + "UUID", + "VARBIT", + "VARCHAR", + "VOID", + "XML" + ], + "reserved": [ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "AUTHORIZATION", + "BINARY", + "COLLATION", + "CONCURRENTLY", + "CROSS", + "CURRENT_SCHEMA", + "FREEZE", + "FULL", + "ILIKE", + "INNER", + "IS", + "ISNULL", + "JOIN", + "LEFT", + "LIKE", + "NATURAL", + "NOTNULL", + "OUTER", + "OVERLAPS", + "RIGHT", + "SIMILAR", + "TABLESAMPLE", + "VERBOSE" + ] +} diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py new file mode 100644 index 0000000..f5a9cb5 --- /dev/null +++ b/pgcli/packages/prioritization.py @@ -0,0 +1,51 @@ +import re +import sqlparse +from sqlparse.tokens import Name +from collections import defaultdict +from .pgliterals.main import get_literals + + +white_space_regex = re.compile("\\s+", re.MULTILINE) + + +def _compile_regex(keyword): + # Surround the keyword with word boundaries and replace interior whitespace + # with whitespace wildcards + pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b" + return re.compile(pattern, re.MULTILINE | re.IGNORECASE) + + +keywords = get_literals("keywords") +keyword_regexs = {kw: _compile_regex(kw) for kw in keywords} + + +class PrevalenceCounter: + def __init__(self): + self.keyword_counts = defaultdict(int) + self.name_counts = defaultdict(int) + + def update(self, text): + self.update_keywords(text) + self.update_names(text) + + def update_names(self, text): + for parsed in sqlparse.parse(text): + for token in parsed.flatten(): + if token.ttype in Name: + self.name_counts[token.value] += 1 + + def clear_names(self): + self.name_counts = defaultdict(int) + + def update_keywords(self, text): + # Count keywords. Can't rely for sqlparse for this, because it's + # database agnostic + for keyword, regex in keyword_regexs.items(): + for _ in regex.finditer(text): + self.keyword_counts[keyword] += 1 + + def keyword_count(self, keyword): + return self.keyword_counts[keyword] + + def name_count(self, name): + return self.name_counts[name] diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py new file mode 100644 index 0000000..997b86e --- /dev/null +++ b/pgcli/packages/prompt_utils.py @@ -0,0 +1,37 @@ +import sys +import click +from .parseutils import is_destructive + + +def confirm_destructive_query(queries, keywords, alias): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + info = "You're about to run a destructive command" + if alias: + info += f" in {click.style(alias, fg='red')}" + + prompt_text = f"{info}.\nDo you want to proceed?" + if is_destructive(queries, keywords) and sys.stdin.isatty(): + return confirm(prompt_text) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py new file mode 100644 index 0000000..b78edd6 --- /dev/null +++ b/pgcli/packages/sqlcompletion.py @@ -0,0 +1,605 @@ +import sys +import re +import sqlparse +from collections import namedtuple +from sqlparse.sql import Comparison, Identifier, Where +from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier +from .parseutils.tables import extract_tables +from .parseutils.ctes import isolate_query_ctes +from pgspecial.main import parse_special_command + + +Special = namedtuple("Special", []) +Database = namedtuple("Database", []) +Schema = namedtuple("Schema", ["quoted"]) +Schema.__new__.__defaults__ = (False,) +# FromClauseItem is a table/view/function used in the FROM clause +# `table_refs` contains the list of tables/... already in the statement, +# used to ensure that the alias we suggest is unique +FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables") +Table = namedtuple("Table", ["schema", "table_refs", "local_tables"]) +TableFormat = namedtuple("TableFormat", []) +View = namedtuple("View", ["schema", "table_refs"]) +# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' +JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"]) +# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' +Join = namedtuple("Join", ["table_refs", "schema"]) + +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) +# For convenience, don't require the `usage` argument in Function constructor +Function.__new__.__defaults__ = (None, tuple(), None) +Table.__new__.__defaults__ = (None, tuple(), tuple()) +View.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) + +Column = namedtuple( + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], +) +Column.__new__.__defaults__ = (None, None, tuple(), False, None) + +Keyword = namedtuple("Keyword", ["last_token"]) +Keyword.__new__.__defaults__ = (None,) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) + +Path = namedtuple("Path", []) + + +class SqlStatement: + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include="many_punctuations" + ) + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) + + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. + if self.word_before_cursor: + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" + + def is_insert(self): + return self.parsed.token_first().value.lower() == "insert" + + def get_tables(self, scope="full"): + """Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. + """ + tables = extract_tables( + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self, n_skip=0): + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) + return prev_keyword + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + if full_text.startswith("\\i "): + return (Path(),) + + # This is a temporary hack; the exception handling + # here should be removed once sqlparse has been fixed + try: + stmt = SqlStatement(full_text, text_before_cursor) + except (TypeError, AttributeError): + return [] + + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value.startswith("\\"): + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub("", txt) + return txt + + +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + + +def _split_multiple_statements(full_text, text_before_cursor, parsed): + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds + # the current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(str(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == "FUNCTION": + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) + return full_text, text_before_cursor, statement + + +SPECIALS_SUGGESTION = { + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, +} + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return (Special(),) + + if cmd in ("\\c", "\\connect"): + return (Database(),) + + if cmd == "\\T": + return (TableFormat(),) + + if cmd == "\\dn": + return (Schema(),) + + if arg: + # Try to distinguish "\d name" from "\d schema.name" + # Note that this will fail to obtain a schema name if wildcards are + # used, e.g. "\d schema???.name" + parsed = sqlparse.parse(arg)[0].tokens[0] + try: + schema = parsed.get_parent_name() + except AttributeError: + schema = None + else: + schema = None + + if cmd[1:] == "d": + # \d can describe tables or views + if schema: + return (Table(schema=schema), View(schema=schema)) + else: + return (Schema(), Table(schema=None), View(schema=None)) + elif cmd[1:] in SPECIALS_SUGGESTION: + rel_type = SPECIALS_SUGGESTION[cmd[1:]] + if schema: + if rel_type == Function: + return (Function(schema=schema, usage="special"),) + return (rel_type(schema=schema),) + else: + if rel_type == Function: + return (Schema(), Function(schema=None, usage="special")) + return (Schema(), rel_type(schema=None)) + + if cmd in ["\\n", "\\ns", "\\nd"]: + return (NamedQuery(),) + + return (Keyword(), Special()) + + +def suggest_based_on_last_token(token, stmt): + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword = stmt.reduce_to_prev_keyword() + return suggest_based_on_last_token(prev_keyword, stmt) + elif isinstance(token, Identifier): + # If the previous token is an identifier, we can suggest datatypes if + # we're in a parenthesized column/field list, e.g.: + # CREATE TABLE foo (Identifier <CURSOR> + # CREATE FUNCTION foo (Identifier <CURSOR> + # If we're not in a parenthesized list, the most likely scenario is the + # user is about to specify an alias, e.g.: + # SELECT Identifier <CURSOR> + # SELECT foo FROM Identifier <CURSOR> + prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) + if prev_keyword and prev_keyword.value == "(": + # Suggest datatypes + return suggest_based_on_last_token("type", stmt) + else: + return (Keyword(),) + else: + token_v = token.value.lower() + + if not token: + return (Keyword(), Special()) + elif token_v.endswith("("): + p = sqlparse.parse(stmt.text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token("where", stmt) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + prev_tok = where.token_prev(len(where.tokens) - 1)[1] + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return (Keyword(),) + else: + return column_suggestions + + # Get the token before the parens + prev_tok = p.token_prev(len(p.tokens) - 1)[1] + + if ( + prev_tok + and prev_tok.value + and prev_tok.value.lower().split(" ")[-1] == "using" + ): + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = stmt.get_tables("before") + + # suggest columns that are present in more than one table + return ( + Column( + table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables, + ), + ) + + elif p.token_first().value.lower() == "select": + # If the lparen is preceded by a space chances are we're about to + # do a sub-select. + if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) + # We're probably in a function argument list + return _suggest_expression(token_v, stmt) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): + return _suggest_expression(token_v, stmt) + elif token_v == "as": + # Don't suggest anything for aliases + return () + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith("join") and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables("before") + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == "function": + schema = stmt.get_identifier_schema() + + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in ("drop", "alter", "create", "create or replace"): + # Suggest functions from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + suggest.append(Function(schema=schema, usage="signature")) + return tuple(suggest) + + except ValueError: + pass + return tuple() + + elif token_v in ("table", "view"): + # E.g. 'ALTER TABLE <tablname>' + rel_type = {"table": Table, "view": View, "function": Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == "column": + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == "on": + tables = stmt.get_tables("before") + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent.<suggestion>" + # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON <suggestion> + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) + else: + return (Alias(aliases=aliases),) + + elif token_v in ("c", "use", "database", "template"): + # "\c <db", "use <db>", "DROP DATABASE <db>", + # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" + return (Database(),) + elif token_v == "schema": + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == "set" + return (Schema(quoted),) + elif token_v.endswith(",") or token_v in ("=", "and", "or"): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ("type", "::"): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {"alter", "create", "drop"}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def _suggest_expression(token_v, stmt): + """ + Return suggestions for an expression, taking account of any partially-typed + identifier's parent, which may be a table alias or schema name. + """ + parent = stmt.identifier.get_parent_name() if stmt.identifier else [] + tables = stmt.get_tables() + + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) + + return ( + Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()), + ) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = <cursor> + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ("on", "and", "or") + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b <cursor> + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py new file mode 100644 index 0000000..c236c13 --- /dev/null +++ b/pgcli/pgbuffer.py @@ -0,0 +1,61 @@ +import logging + +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app +from .packages.parseutils.utils import is_open_quote + +_logger = logging.getLogger(__name__) + + +def _is_complete(sql): + # A complete command is an sql statement that ends with a semicolon, unless + # there's an open quote surrounding it, as is common when writing a + # CREATE FUNCTION command + return sql.endswith(";") and not is_open_quote(sql) + + +""" +Returns True if the buffer contents should be handled (i.e. the query/command +executed) immediately. This is necessary as we use prompt_toolkit in multiline +mode, which by default will insert new lines on Enter. +""" + + +def safe_multi_line_mode(pgcli): + @Condition + def cond(): + _logger.debug( + 'Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode + ) + return pgcli.multi_line and (pgcli.multiline_mode == "safe") + + return cond + + +def buffer_should_be_handled(pgcli): + @Condition + def cond(): + if not pgcli.multi_line: + _logger.debug("Not in multi-line mode. Handle the buffer.") + return True + + if pgcli.multiline_mode == "safe": + _logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.") + return False + + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + text = doc.text.strip() + + return ( + text.startswith("\\") # Special Command + or text.endswith(r"\e") # Special Command + or text.endswith(r"\G") # Ended with \e which should launch the editor + or _is_complete(text) # A complete SQL command + or (text == "exit") # Exit doesn't need semi-colon + or (text == "quit") # Quit doesn't need semi-colon + or (text == ":q") # To all the vim fans out there + or (text == "") # Just a plain enter without any text + ) + + return cond diff --git a/pgcli/pgclirc b/pgcli/pgclirc new file mode 100644 index 0000000..51f7eae --- /dev/null +++ b/pgcli/pgclirc @@ -0,0 +1,236 @@ +# vi: ft=dosini +[main] + +# Enables context sensitive auto-completion. If this is disabled, all +# possible completions will be listed. +smart_completion = True + +# Display the completions in several columns. (More completions will be +# visible.) +wider_completion_menu = False + +# Do not create new connections for refreshing completions; Equivalent to +# always running with the --single-connection flag. +always_use_single_connection = False + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute +# the current input if the input ends in a semicolon. +# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always +# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute +# a command. +multi_line_mode = psql + +# Destructive warning will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database", +# "shutdown", "delete", or "update". +# You can pass a list of destructive commands or leave it empty if you want to skip all warnings. +# "unconditional_update" will warn you of update statements that don't have a where clause +destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update + +# Destructive warning can restart the connection if this is enabled and the +# user declines. This means that any current uncommitted transaction can be +# aborted if the user doesn't want to proceed with a destructive_warning +# statement. +destructive_warning_restarts_connection = False + +# When this option is on (and if `destructive_warning` is not empty), +# destructive statements are not executed when outside of a transaction. +destructive_statements_require_transaction = False + +# Enables expand mode, which is similar to `\x` in psql. +expand = False + +# Enables auto expand mode, which is similar to `\x auto` in psql. +auto_expand = False + +# Auto-retry queries on connection failures and other operational errors. If +# False, will prompt to rerun the failed query instead of auto-retrying. +auto_retry_closed_connection = True + +# If set to True, table suggestions will include a table alias +generate_aliases = False + +# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True +# the format for this file should be: +# { +# "some_table_name": "desired_alias", +# "some_other_table_name": "another_alias" +# } +alias_map_file = + +# log_file location. +# In Unix/Linux: ~/.config/pgcli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# keyword casing preference. Possible values: "lower", "upper", "auto" +keyword_casing = auto + +# casing_file location. +# In Unix/Linux: ~/.config/pgcli/casing +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing +# %USERPROFILE% is typically C:\Users\{username} +casing_file = default + +# If generate_casing_file is set to True and there is no file in the above +# location, one will be generated based on usage in SQL/PLPGSQL functions. +generate_casing_file = False + +# Casing of column headers based on the casing_file described above +case_column_headers = True + +# history_file location. +# In Unix/Linux: ~/.config/pgcli/history +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history +# %USERPROFILE% is typically C:\Users\{username} +history_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Order of columns when expanding * to column list +# Possible values: "table_order" and "alphabetic" +asterisk_column_order = table_order + +# Whether to qualify with table alias/name when suggesting columns +# Possible values: "always", "never" and "if_more_than_one_table" +qualify_columns = if_more_than_one_table + +# When no schema is entered, only suggest objects in search_path +search_path_filter = False + +# Default pager. See https://www.pgcli.com/pager for more information on settings. +# By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS' +# environment variable is not set, then LESS='-SRXF' will be automatically set. +# pager = less + +# Timing of sql statements and table rendering. +timing = True + +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + +# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, +# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, +# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update, +# sql-update-1, sql-update-2 (formatter with sql-* prefix can format query +# output to executable insertion or updating sql). +# Recommended: psql, fancy_grid and grid. +table_format = psql + +# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt, +# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark, +# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity +syntax_style = default + +# Keybindings: +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E +# for end are available in the REPL. +vi = False + +# Error handling +# When one of multiple SQL statements causes an error, choose to either +# continue executing the remaining statements, or stopping +# Possible values "STOP" or "RESUME" +on_error = STOP + +# Set threshold for row limit. Use 0 to disable limiting. +row_limit = 1000 + +# Truncate long text fields to this value for tabular display (does not apply to csv). +# Leave unset to disable truncation. Example: "max_field_width = " +# Be aware that formatting might get slow with values larger than 500 and tables with +# lots of records. +max_field_width = 500 + +# Skip intro on startup and goodbye on exit +less_chatty = False + +# Postgres prompt +# \t - Current date and time +# \u - Username +# \h - Short hostname of the server (up to first '.') +# \H - Hostname of the server +# \d - Database name +# \p - Database port +# \i - Postgres PID +# \# - "@" sign if logged in as superuser, '>' in other case +# \n - Newline +# \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise) +# \x1b[...m - insert ANSI escape sequence +# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' +prompt = '\u@\h:\d> ' + +# Number of lines to reserve for the suggestion menu +min_num_menu_lines = 4 + +# Character used to left pad multi-line queries to match the prompt size. +multiline_continuation_char = '' + +# The string used in place of a null value. +null_string = '<null>' + +# manage pager on startup +enable_pager = True + +# Use keyring to automatically save and load password in a secure manner +keyring = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' +# These three values can be used to further refine the syntax highlighting. +# They are commented out by default, since they have priority over the theme set +# with the `syntax_style` setting and overriding its behavior can be confusing. +# literal.string = '#ba2121' +# literal.number = '#666666' +# keyword = 'bold #008000' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +output.null = "#808080" + +# Named queries are queries you can execute by name. +[named queries] + +# Here's where you can provide a list of connection string aliases. +# You can use it by passing the -D option. `pgcli -D example_dsn` +[alias_dsn] +# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] + +# Format for number representation +# for decimal "d" - 12345678, ",d" - 12,345,678 +# for float "g" - 123456.78, ",g" - 123,456.78 +[data_formats] +decimal = "" +float = "" diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py new file mode 100644 index 0000000..17fc540 --- /dev/null +++ b/pgcli/pgcompleter.py @@ -0,0 +1,1072 @@ +import json +import logging +import re +from itertools import count, repeat, chain +import operator +from collections import namedtuple, defaultdict, OrderedDict +from cli_helpers.tabular_output import TabularOutputFormatter +from pgspecial.namedqueries import NamedQueries +from prompt_toolkit.completion import Completer, Completion, PathCompleter +from prompt_toolkit.document import Document +from .packages.sqlcompletion import ( + FromClauseItem, + suggest_type, + Special, + Database, + Schema, + Table, + TableFormat, + Function, + Column, + View, + Keyword, + NamedQuery, + Datatype, + Alias, + Path, + JoinCondition, + Join, +) +from .packages.parseutils.meta import ColumnMetadata, ForeignKey +from .packages.parseutils.utils import last_word +from .packages.parseutils.tables import TableReference +from .packages.pgliterals.main import get_literals +from .packages.prioritization import PrevalenceCounter +from .config import load_config, config_location + +_logger = logging.getLogger(__name__) + +Match = namedtuple("Match", ["completion", "priority"]) + +_SchemaObject = namedtuple("SchemaObject", "name schema meta") + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) + + +_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") + +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' + + +def generate_alias(tbl, alias_map=None): + """Generate a table alias, consisting of all upper-case letters in + the table name, or, if there are no upper-case letters, the first letter + + all letters preceded by _ + param tbl - unescaped name of the table to alias + """ + if alias_map and tbl in alias_map: + return alias_map[tbl] + return "".join( + [l for l in tbl if l.isupper()] + or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] + ) + + +class InvalidMapFile(ValueError): + pass + + +def load_alias_map_file(path): + try: + with open(path) as fo: + alias_map = json.load(fo) + except FileNotFoundError as err: + raise InvalidMapFile( + f"Cannot read alias_map_file - {err.filename} does not exist" + ) + except json.JSONDecodeError: + raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json") + else: + return alias_map + + +class PGCompleter(Completer): + # keywords_tree: A dict mapping keywords to well known following keywords. + # e.g. 'CREATE': ['TABLE', 'USER', ...], + keywords_tree = get_literals("keywords", type_=dict) + keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) + functions = get_literals("functions") + datatypes = get_literals("datatypes") + reserved_words = set(get_literals("reserved")) + + def __init__(self, smart_completion=True, pgspecial=None, settings=None): + super().__init__() + self.smart_completion = smart_completion + self.pgspecial = pgspecial + self.prioritizer = PrevalenceCounter() + settings = settings or {} + self.signature_arg_style = settings.get( + "signature_arg_style", "{arg_name} {arg_type}" + ) + self.call_arg_style = settings.get( + "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" + ) + self.call_arg_display_style = settings.get( + "call_arg_display_style", "{arg_name}" + ) + self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) + self.search_path_filter = settings.get("search_path_filter") + self.generate_aliases = settings.get("generate_aliases") + alias_map_file = settings.get("alias_map_file") + if alias_map_file is not None: + self.alias_map = load_alias_map_file(alias_map_file) + else: + self.alias_map = None + self.casing_file = settings.get("casing_file") + self.insert_col_skip_patterns = [ + re.compile(pattern) + for pattern in settings.get( + "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] + ) + ] + self.generate_casing_file = settings.get("generate_casing_file") + self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") + self.asterisk_column_order = settings.get( + "asterisk_column_order", "table_order" + ) + + keyword_casing = settings.get("keyword_casing", "upper").lower() + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "upper" + self.keyword_casing = keyword_casing + self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") + + self.databases = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.search_path = [] + self.casing = {} + + self.all_completions = set(self.keywords + self.functions) + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = '"%s"' % name + + return name + + def escape_schema(self, name): + return "'{}'".format(self.unescape_name(name)) + + def unescape_name(self, name): + """Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schemata): + # schemata is a list of schema names + schemata = self.escaped_names(schemata) + metadata = self.dbmetadata["tables"] + for schema in schemata: + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + for schema in schemata: + metadata[schema] = {} + + self.all_completions.update(schemata) + + def extend_casing(self, words): + """extend casing data + + :return: + """ + # casing should be a dict {lowercasename:PreferredCasingName} + self.casing = {word.lower(): word for word in words} + + def extend_relations(self, data, kind): + """extend metadata for tables or views. + + :param data: list of (schema_name, rel_name) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + + data = [self.escaped_names(d) for d in data] + + # dbmetadata['tables']['schema_name']['table_name'] should be an + # OrderedDict {column_name:ColumnMetaData}. + metadata = self.dbmetadata[kind] + for schema, relname in data: + try: + metadata[schema][relname] = OrderedDict() + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", kind, relname, schema + ) + self.all_completions.add(relname) + + def extend_columns(self, column_data, kind): + """extend column metadata. + + :param column_data: list of (schema_name, rel_name, column_name, + column_type, has_default, default) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + metadata = self.dbmetadata[kind] + for schema, relname, colname, datatype, has_default, default in column_data: + (schema, relname, colname) = self.escaped_names([schema, relname, colname]) + column = ColumnMetadata( + name=colname, + datatype=datatype, + has_default=has_default, + default=default, + ) + metadata[schema][relname][colname] = column + self.all_completions.add(colname) + + def extend_functions(self, func_data): + # func_data is a list of function metadata namedtuples + + # dbmetadata['schema_name']['functions']['function_name'] should return + # the function metadata namedtuple for the corresponding function + metadata = self.dbmetadata["functions"] + + for f in func_data: + schema, func = self.escaped_names([f.schema_name, f.func_name]) + + if func in metadata[schema]: + metadata[schema][func].append(f) + else: + metadata[schema][func] = [f] + + self.all_completions.add(func) + + self._refresh_arg_list_cache() + + def _refresh_arg_list_cache(self): + # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} + # This is used when suggesting functions, to avoid the latency that would result + # if we'd recalculate the arg lists each time we suggest functions (in large DBs) + self._arg_list_cache = { + usage: { + meta: self._arg_list(meta, usage) + for sch, funcs in self.dbmetadata["functions"].items() + for func, metas in funcs.items() + for meta in metas + } + for usage in ("call", "call_display", "signature") + } + + def extend_foreignkeys(self, fk_data): + # fk_data is a list of ForeignKey namedtuples, with fields + # parentschema, childschema, parenttable, childtable, + # parentcolumns, childcolumns + + # These are added as a list of ForeignKey namedtuples to the + # ColumnMetadata namedtuple for both the child and parent + meta = self.dbmetadata["tables"] + + for fk in fk_data: + e = self.escaped_names + parentschema, childschema = e([fk.parentschema, fk.childschema]) + parenttable, childtable = e([fk.parenttable, fk.childtable]) + childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) + childcolmeta = meta[childschema][childtable][childcol] + parcolmeta = meta[parentschema][parenttable][parcol] + fk = ForeignKey( + parentschema, parenttable, parcol, childschema, childtable, childcol + ) + childcolmeta.foreignkeys.append(fk) + parcolmeta.foreignkeys.append(fk) + + def extend_datatypes(self, type_data): + # dbmetadata['datatypes'][schema_name][type_name] should store type + # metadata, such as composite type field names. Currently, we're not + # storing any metadata beyond typename, so just store None + meta = self.dbmetadata["datatypes"] + + for t in type_data: + schema, type_name = self.escaped_names(t) + meta[schema][type_name] = None + self.all_completions.add(type_name) + + def extend_query_history(self, text, is_init=False): + if is_init: + # During completer initialization, only load keyword preferences, + # not names + self.prioritizer.update_keywords(text) + else: + self.prioritizer.update(text) + + def set_search_path(self, search_path): + self.search_path = self.escaped_names(search_path) + + def reset_completions(self): + self.databases = [] + self.special_commands = [] + self.search_path = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.all_completions = set(self.keywords + self.functions) + + def find_matches(self, text, collection, mode="fuzzy", meta=None): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + `collection` can be either a list of strings or a list of Candidate + namedtuples. + `mode` can be either 'fuzzy', or 'strict' + 'fuzzy': fuzzy matching, ties broken by name prevalance + `keyword`: start only matching, ties broken by keyword prevalance + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + + """ + if not collection: + return [] + prio_order = [ + "keyword", + "function", + "view", + "table", + "datatype", + "database", + "schema", + "column", + "table alias", + "join", + "name join", + "fk join", + "table format", + ] + type_priority = prio_order.index(meta) if meta in prio_order else -1 + text = last_word(text, include="most_punctuations").lower() + text_len = len(text) + + if text and text[0] == '"': + # text starts with double quote; user is manually escaping a name + # Match on everything that follows the double-quote. Note that + # text_len is calculated before removing the quote, so the + # Completion.position value is correct + text = text[1:] + + if mode == "fuzzy": + fuzzy = True + priority_func = self.prioritizer.name_count + else: + fuzzy = False + priority_func = self.prioritizer.keyword_count + + # Construct a `_match` function for either fuzzy or non-fuzzy matching + # The match function returns a 2-tuple used for sorting the matches, + # or None if the item doesn't match + # Note: higher priority values mean more important, so use negative + # signs to flip the direction of the tuple + if fuzzy: + regex = ".*?".join(map(re.escape, text)) + pat = re.compile("(%s)" % regex) + + def _match(item): + if item.lower()[: len(text) + 1] in (text, text + " "): + # Exact match of first word in suggestion + # This is to get exact alias matches to the top + # E.g. for input `e`, 'Entries E' should be on top + # (before e.g. `EndUsers EU`) + return float("Infinity"), -1 + r = pat.search(self.unescape_name(item.lower())) + if r: + return -len(r.group()), -r.start() + + else: + match_end_limit = len(text) + + def _match(item): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + # Use negative infinity to force keywords to sort after all + # fuzzy matches + return -float("Infinity"), -match_point + + matches = [] + for cand in collection: + if isinstance(cand, _Candidate): + item, prio, display_meta, synonyms, prio2, display = cand + if display_meta is None: + display_meta = meta + syn_matches = (_match(x) for x in synonyms) + # Nones need to be removed to avoid max() crashing in Python 3 + syn_matches = [m for m in syn_matches if m] + sort_key = max(syn_matches) if syn_matches else None + else: + item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand + sort_key = _match(cand) + + if sort_key: + if display_meta and len(display_meta) > 50: + # Truncate meta-text to 50 characters, if necessary + display_meta = display_meta[:47] + "..." + + # Lexical order of items in the collection, used for + # tiebreaking items with the same match group length and start + # position. Since we use *higher* priority to mean "more + # important," we use -ord(c) to prioritize "aa" > "ab" and end + # with 1 to prioritize shorter strings (ie "user" > "users"). + # We first do a case-insensitive sort and then a + # case-sensitive one as a tie breaker. + # We also use the unescape_name to make sure quoted names have + # the same priority as unquoted names. + lexical_priority = ( + tuple( + 0 if c in " _" else -ord(c) + for c in self.unescape_name(item.lower()) + ) + + (1,) + + tuple(c for c in item) + ) + + item = self.case(item) + display = self.case(display) + priority = ( + sort_key, + type_priority, + prio, + priority_func(item), + prio2, + lexical_priority, + ) + matches.append( + Match( + completion=Completion( + text=item, + start_position=-text_len, + display_meta=display_meta, + display=display, + ), + priority=priority, + ) + ) + return matches + + def case(self, word): + return self.casing.get(word, word) + + def get_completions(self, document, complete_event, smart_completion=None): + word_before_cursor = document.get_word_before_cursor(WORD=True) + + if smart_completion is None: + smart_completion = self.smart_completion + + # If smart_completion is off then match any word that starts with + # 'word_before_cursor'. + if not smart_completion: + matches = self.find_matches( + word_before_cursor, self.all_completions, mode="strict" + ) + completions = [m.completion for m in matches] + return sorted(completions, key=operator.attrgetter("text")) + + matches = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + suggestion_type = type(suggestion) + _logger.debug("Suggestion type: %r", suggestion_type) + + # Map suggestion type to method + # e.g. 'table' -> self.get_table_matches + matcher = self.suggestion_matchers[suggestion_type] + matches.extend(matcher(self, suggestion, word_before_cursor)) + + # Sort matches so highest priorities are first + matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True) + + return [m.completion for m in matches] + + def get_column_matches(self, suggestion, word_before_cursor): + tables = suggestion.table_refs + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + ) + qualify = lambda col, tbl: ( + (tbl + "." + self.case(col)) if do_qualify else self.case(col) + ) + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + + def make_cand(name, ref): + synonyms = (name, generate_alias(self.case(name))) + return Candidate(qualify(name, ref), 0, "column", synonyms) + + def flat_cols(): + return [ + make_cand(c.name, t.ref) + for t, cols in scoped_cols.items() + for c in cols + ] + + if suggestion.require_last_table: + # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should + # suggest only columns that appear in the last table and one more + ltbl = tables[-1].ref + other_tbl_cols = { + c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs + } + scoped_cols = { + t: [col for col in cols if col.name in other_tbl_cols] + for t, cols in scoped_cols.items() + if t.ref == ltbl + } + lastword = last_word(word_before_cursor, include="most_punctuations") + if lastword == "*": + if suggestion.context == "insert": + + def filter(col): + if not col.has_default: + return True + return not any( + p.match(col.default) for p in self.insert_col_skip_patterns + ) + + scoped_cols = { + t: [col for col in cols if filter(col)] + for t, cols in scoped_cols.items() + } + if self.asterisk_column_order == "alphabetic": + for cols in scoped_cols.values(): + cols.sort(key=operator.attrgetter("name")) + if ( + lastword != word_before_cursor + and len(tables) == 1 + and word_before_cursor[-len(lastword) - 1] == "." + ): + # User typed x.*; replicate "x." for all columns except the + # first, which gets the original (as we only replace the "*"") + sep = ", " + word_before_cursor[:-1] + collist = sep.join(self.case(c.completion) for c in flat_cols()) + else: + collist = ", ".join( + qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs + ) + + return [ + Match( + completion=Completion( + collist, -1, display_meta="columns", display="*" + ), + priority=(1, 1, 1), + ) + ] + + return self.find_matches(word_before_cursor, flat_cols(), meta="column") + + def alias(self, tbl, tbls): + """Generate a unique table alias + tbl - name of the table to alias, quoted if it needs to be + tbls - TableReference iterable of tables already in query + """ + tbl = self.case(tbl) + tbls = {normalize_ref(t.ref) for t in tbls} + if self.generate_aliases: + tbl = generate_alias(self.unescape_name(tbl)) + if normalize_ref(tbl) not in tbls: + return tbl + elif tbl[0] == '"': + aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2)) + else: + aliases = (tbl + str(i) for i in count(2)) + return next(a for a in aliases if normalize_ref(a) not in tbls) + + def get_join_matches(self, suggestion, word_before_cursor): + tbls = suggestion.table_refs + cols = self.populate_scoped_cols(tbls) + # Set up some data structures for efficient access + qualified = {normalize_ref(t.ref): t.schema for t in tbls} + ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)} + refs = {normalize_ref(t.ref) for t in tbls} + other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]} + joins = [] + # Iterate over FKs in existing tables to find potential joins + fks = ( + (fk, rtbl, rcol) + for rtbl, rcols in cols.items() + for rcol in rcols + for fk in rcol.foreignkeys + ) + col = namedtuple("col", "schema tbl col") + for fk, rtbl, rcol in fks: + right = col(rtbl.schema, rtbl.name, rcol.name) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left = child if parent == right else parent + if suggestion.schema and left.schema != suggestion.schema: + continue + c = self.case + if self.generate_aliases or normalize_ref(left.tbl) in refs: + lref = self.alias(left.tbl, suggestion.table_refs) + join = "{0} {4} ON {4}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref + ) + else: + join = "{0} ON {0}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col) + ) + alias = generate_alias(self.case(left.tbl)) + synonyms = [ + join, + "{0} ON {0}.{1} = {2}.{3}".format( + alias, c(left.col), rtbl.ref, c(right.col) + ), + ] + # Schema-qualify if (1) new table in same schema as old, and old + # is schema-qualified, or (2) new in other schema, except public + if not suggestion.schema and ( + qualified[normalize_ref(rtbl.ref)] + and left.schema == right.schema + or left.schema not in (right.schema, "public") + ): + join = left.schema + "." + join + prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( + 0 if (left.schema, left.tbl) in other_tbls else 1 + ) + joins.append(Candidate(join, prio, "join", synonyms=synonyms)) + + return self.find_matches(word_before_cursor, joins, meta="join") + + def get_join_condition_matches(self, suggestion, word_before_cursor): + col = namedtuple("col", "schema tbl col") + tbls = self.populate_scoped_cols(suggestion.table_refs).items + cols = [(t, c) for t, cs in tbls() for c in cs] + try: + lref = (suggestion.parent or suggestion.table_refs[-1]).ref + ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] + except IndexError: # The user typed an incorrect table qualifier + return [] + conds, found_conds = [], set() + + def add_cond(lcol, rcol, rref, prio, meta): + prefix = "" if suggestion.parent else ltbl.ref + "." + case = self.case + cond = prefix + case(lcol) + " = " + rref + "." + case(rcol) + if cond not in found_conds: + found_conds.add(cond) + conds.append(Candidate(cond, prio + ref_prio[rref], meta)) + + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + d = defaultdict(list) + for pair in pairs: + d[pair[0]].append(pair[1]) + return d + + # Tables that are closer to the cursor get higher prio + ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)} + # Map (schema, table, col) to tables + coldict = list_dict( + ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref + ) + # For each fk from the left table, generate a join condition if + # the other table is also in the scope + fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) + for fk, lcol in fks: + left = col(ltbl.schema, ltbl.name, lcol) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left, right = (child, par) if left == child else (par, child) + for rtbl in coldict[right]: + add_cond(left.col, right.col, rtbl.ref, 2000, "fk join") + # For name matching, use a {(colname, coltype): TableReference} dict + coltyp = namedtuple("coltyp", "name datatype") + col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) + # Find all name-match join conditions + for c in (coltyp(c.name, c.datatype) for c in lcols): + for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): + prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0 + add_cond(c.name, c.name, rtbl.ref, prio, "name join") + + return self.find_matches(word_before_cursor, conds, meta="join") + + def get_function_matches(self, suggestion, word_before_cursor, alias=False): + if suggestion.usage == "from": + # Only suggest functions allowed in FROM clause + + def filt(f): + return ( + not f.is_aggregate + and not f.is_window + and not f.is_extension + and ( + f.is_public + or f.schema_name in self.search_path + or f.schema_name == suggestion.schema + ) + ) + + else: + alias = False + + def filt(f): + return not f.is_extension and ( + f.is_public or f.schema_name == suggestion.schema + ) + + arg_mode = {"signature": "signature", "special": None}.get( + suggestion.usage, "call" + ) + + # Function overloading means we way have multiple functions of the same + # name at this point, so keep unique names only + all_functions = self.populate_functions(suggestion.schema, filt) + funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions} + + matches = self.find_matches(word_before_cursor, funcs, meta="function") + + if not suggestion.schema and not suggestion.usage: + # also suggest hardcoded functions using startswith matching + predefined_funcs = self.find_matches( + word_before_cursor, self.functions, mode="strict", meta="function" + ) + matches.extend(predefined_funcs) + + return matches + + def get_schema_matches(self, suggestion, word_before_cursor): + schema_names = self.dbmetadata["tables"].keys() + + # Unless we're sure the user really wants them, hide schema names + # starting with pg_, which are mostly temporary schemas + if not word_before_cursor.startswith("pg_"): + schema_names = [s for s in schema_names if not s.startswith("pg_")] + + if suggestion.quoted: + schema_names = [self.escape_schema(s) for s in schema_names] + + return self.find_matches(word_before_cursor, schema_names, meta="schema") + + def get_from_clause_item_matches(self, suggestion, word_before_cursor): + alias = self.generate_aliases + s = suggestion + t_sug = Table(s.schema, s.table_refs, s.local_tables) + v_sug = View(s.schema, s.table_refs) + f_sug = Function(s.schema, s.table_refs, usage="from") + return ( + self.get_table_matches(t_sug, word_before_cursor, alias) + + self.get_view_matches(v_sug, word_before_cursor, alias) + + self.get_function_matches(f_sug, word_before_cursor, alias) + ) + + def _arg_list(self, func, usage): + """Returns a an arg list string, e.g. `(_foo:=23)` for a func. + + :param func is a FunctionMetadata object + :param usage is 'call', 'call_display' or 'signature' + + """ + template = { + "call": self.call_arg_style, + "call_display": self.call_arg_display_style, + "signature": self.signature_arg_style, + }[usage] + args = func.args() + if not template: + return "()" + elif usage == "call" and len(args) < 2: + return "()" + elif usage == "call" and func.has_variadic(): + return "()" + multiline = usage == "call" and len(args) > self.call_arg_oneliner_max + max_arg_len = max(len(a.name) for a in args) if multiline else 0 + args = ( + self._format_arg(template, arg, arg_num + 1, max_arg_len) + for arg_num, arg in enumerate(args) + ) + if multiline: + return "(" + ",".join("\n " + a for a in args if a) + "\n)" + else: + return "(" + ", ".join(a for a in args if a) + ")" + + def _format_arg(self, template, arg, arg_num, max_arg_len): + if not template: + return None + if arg.has_default: + arg_default = "NULL" if arg.default is None else arg.default + # Remove trailing ::(schema.)type + arg_default = arg_default_type_strip_regex.sub("", arg_default) + else: + arg_default = "" + return template.format( + max_arg_len=max_arg_len, + arg_name=self.case(arg.name), + arg_num=arg_num, + arg_type=arg.datatype, + arg_default=arg_default, + ) + + def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): + """Returns a Candidate namedtuple. + + :param tbl is a SchemaObject + :param arg_mode determines what type of arg list to suffix for functions. + Possible values: call, signature + + """ + cased_tbl = self.case(tbl.name) + if do_alias: + alias = self.alias(cased_tbl, suggestion.table_refs) + synonyms = (cased_tbl, generate_alias(cased_tbl)) + maybe_alias = (" " + alias) if do_alias else "" + maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" + if arg_mode == "call": + display_suffix = self._arg_list_cache["call_display"][tbl.meta] + elif arg_mode == "signature": + display_suffix = self._arg_list_cache["signature"][tbl.meta] + else: + display_suffix = "" + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias + prio2 = 0 if tbl.schema else 1 + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) + + def get_table_matches(self, suggestion, word_before_cursor, alias=False): + tables = self.populate_schema_objects(suggestion.schema, "tables") + tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) + + # Unless we're sure the user really wants them, don't suggest the + # pg_catalog tables that are implicitly on the search path + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + tables = [t for t in tables if not t.name.startswith("pg_")] + tables = [self._make_cand(t, alias, suggestion) for t in tables] + return self.find_matches(word_before_cursor, tables, meta="table") + + def get_table_formats(self, _, word_before_cursor): + formats = TabularOutputFormatter().supported_formats + return self.find_matches(word_before_cursor, formats, meta="table format") + + def get_view_matches(self, suggestion, word_before_cursor, alias=False): + views = self.populate_schema_objects(suggestion.schema, "views") + + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + views = [v for v in views if not v.name.startswith("pg_")] + views = [self._make_cand(v, alias, suggestion) for v in views] + return self.find_matches(word_before_cursor, views, meta="view") + + def get_alias_matches(self, suggestion, word_before_cursor): + aliases = suggestion.aliases + return self.find_matches(word_before_cursor, aliases, meta="table alias") + + def get_database_matches(self, _, word_before_cursor): + return self.find_matches(word_before_cursor, self.databases, meta="database") + + def get_keyword_matches(self, suggestion, word_before_cursor): + keywords = self.keywords_tree.keys() + # Get well known following keywords for the last token. If any, narrow + # candidates to this list. + next_keywords = self.keywords_tree.get(suggestion.last_token, []) + if next_keywords: + keywords = next_keywords + + casing = self.keyword_casing + if casing == "auto": + if word_before_cursor and word_before_cursor[-1].islower(): + casing = "lower" + else: + casing = "upper" + + if casing == "upper": + keywords = [k.upper() for k in keywords] + else: + keywords = [k.lower() for k in keywords] + + return self.find_matches( + word_before_cursor, keywords, mode="strict", meta="keyword" + ) + + def get_path_matches(self, _, word_before_cursor): + completer = PathCompleter(expanduser=True) + document = Document( + text=word_before_cursor, cursor_position=len(word_before_cursor) + ) + for c in completer.get_completions(document, None): + yield Match(completion=c, priority=(0,)) + + def get_special_matches(self, _, word_before_cursor): + if not self.pgspecial: + return [] + + commands = self.pgspecial.commands + cmds = commands.keys() + cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] + return self.find_matches(word_before_cursor, cmds, mode="strict") + + def get_datatype_matches(self, suggestion, word_before_cursor): + # suggest custom datatypes + types = self.populate_schema_objects(suggestion.schema, "datatypes") + types = [self._make_cand(t, False, suggestion) for t in types] + matches = self.find_matches(word_before_cursor, types, meta="datatype") + + if not suggestion.schema: + # Also suggest hardcoded types + matches.extend( + self.find_matches( + word_before_cursor, self.datatypes, mode="strict", meta="datatype" + ) + ) + + return matches + + def get_namedquery_matches(self, _, word_before_cursor): + return self.find_matches( + word_before_cursor, NamedQueries.instance.list(), meta="named query" + ) + + suggestion_matchers = { + FromClauseItem: get_from_clause_item_matches, + JoinCondition: get_join_condition_matches, + Join: get_join_matches, + Column: get_column_matches, + Function: get_function_matches, + Schema: get_schema_matches, + Table: get_table_matches, + TableFormat: get_table_formats, + View: get_view_matches, + Alias: get_alias_matches, + Database: get_database_matches, + Keyword: get_keyword_matches, + Special: get_special_matches, + Datatype: get_datatype_matches, + NamedQuery: get_namedquery_matches, + Path: get_path_matches, + } + + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): + """Find all columns in a set of scoped_tables. + + :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) + :return: {TableReference:{colname:ColumnMetaData}} + + """ + ctes = {normalize_ref(t.name): t.columns for t in local_tbls} + columns = OrderedDict() + meta = self.dbmetadata + + def addcols(schema, rel, alias, reltype, cols): + tbl = TableReference(schema, rel, alias, reltype == "functions") + if tbl not in columns: + columns[tbl] = [] + columns[tbl].extend(cols) + + for tbl in scoped_tbls: + # Local tables should shadow database tables + if tbl.schema is None and normalize_ref(tbl.name) in ctes: + cols = ctes[normalize_ref(tbl.name)] + addcols(None, tbl.name, "CTE", tbl.alias, cols) + continue + schemas = [tbl.schema] if tbl.schema else self.search_path + for schema in schemas: + relname = self.escape_name(tbl.name) + schema = self.escape_name(schema) + if tbl.is_function: + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta["functions"].get(schema, {}).get(relname) + for func in functions or []: + # func is a FunctionMetadata object + cols = func.fields() + addcols(schema, relname, tbl.alias, "functions", cols) + else: + for reltype in ("tables", "views"): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + cols = cols.values() + addcols(schema, relname, tbl.alias, reltype, cols) + break + + return columns + + def _get_schemas(self, obj_typ, schema): + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + + """ + metadata = self.dbmetadata[obj_typ] + if schema: + schema = self.escape_name(schema) + return [schema] if schema in metadata else [] + return self.search_path if self.search_path_filter else metadata.keys() + + def _maybe_schema(self, schema, parent): + return None if parent or schema in self.search_path else schema + + def populate_schema_objects(self, schema, obj_type): + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + + """ + + return [ + SchemaObject( + name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) + ) + for sch in self._get_schemas(obj_type, schema) + for obj in self.dbmetadata[obj_type][sch].keys() + ] + + def populate_functions(self, schema, filter_func): + """Returns a list of function SchemaObjects. + + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded + + """ + + # Because of multiple dispatch, we can have multiple functions + # with the same name, which is why `for meta in metas` is necessary + # in the comprehensions below + return [ + SchemaObject( + name=func, + schema=(self._maybe_schema(schema=sch, parent=schema)), + meta=meta, + ) + for sch in self._get_schemas("functions", schema) + for (func, metas) in self.dbmetadata["functions"][sch].items() + for meta in metas + if filter_func(meta) + ] diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py new file mode 100644 index 0000000..497d681 --- /dev/null +++ b/pgcli/pgexecute.py @@ -0,0 +1,868 @@ +import logging +import traceback +from collections import namedtuple +import re +import pgspecial as special +import psycopg +import psycopg.sql +from psycopg.conninfo import make_conninfo +import sqlparse + +from .packages.parseutils.meta import FunctionMetadata, ForeignKey + +_logger = logging.getLogger(__name__) + +ViewDef = namedtuple( + "ViewDef", "nspname relname relkind viewdef reloptions checkoption" +) + + +# we added this funcion to strip beginning comments +# because sqlparse didn't handle tem well. It won't be needed if sqlparse +# does parsing of this situation better + + +def remove_beginning_comments(command): + # Regular expression pattern to match comments + pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)" + + # Find and remove all comments from the beginning + cleaned_command = command + comments = [] + match = re.match(pattern, cleaned_command, re.DOTALL) + while match: + comments.append(match.group()) + cleaned_command = cleaned_command[len(match.group()) :].lstrip() + match = re.match(pattern, cleaned_command, re.DOTALL) + + return [cleaned_command, comments] + + +def register_typecasters(connection): + """Casts date and timestamp values to string, resolves issues with out-of-range + dates (e.g. BC) which psycopg can't handle""" + for forced_text_type in [ + "date", + "time", + "timestamp", + "timestamptz", + "bytea", + "json", + "jsonb", + ]: + connection.adapters.register_loader( + forced_text_type, psycopg.types.string.TextLoader + ) + + +# pg3: I don't know what is this +class ProtocolSafeCursor(psycopg.Cursor): + """This class wraps and suppresses Protocol Errors with pgbouncer database. + See https://github.com/dbcli/pgcli/pull/1097. + Pgbouncer database is a virtual database with its own set of commands.""" + + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + # def mogrify(self, query, params): + # args = [Literal(v).as_string(self.connection) for v in params] + # return query % tuple(args) + # + def execute(self, *args, **kwargs): + try: + super().execute(*args, **kwargs) + self.protocol_error = False + self.protocol_message = "" + except psycopg.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = str(ex) + _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) + + +class PGExecute: + # The boolean argument to the current_schemas function indicates whether + # implicit schemas, e.g. pg_catalog + search_path_query = """ + SELECT * FROM unnest(current_schemas(true))""" + + schemata_query = """ + SELECT nspname + FROM pg_catalog.pg_namespace + ORDER BY 1 """ + + tables_query = """ + SELECT n.nspname schema_name, + c.relname table_name + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE c.relkind = ANY(%s) + ORDER BY 1,2;""" + + databases_query = """ + SELECT d.datname + FROM pg_catalog.pg_database d + ORDER BY 1""" + + full_databases_query = """ + SELECT d.datname as "Name", + pg_catalog.pg_get_userbyid(d.datdba) as "Owner", + pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding", + d.datcollate as "Collate", + d.datctype as "Ctype", + pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges" + FROM pg_catalog.pg_database d + ORDER BY 1""" + + socket_directory_query = """ + SELECT setting + FROM pg_settings + WHERE name = 'unix_socket_directories' + """ + + view_definition_query = """ + WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid) + SELECT nspname, relname, relkind, + pg_catalog.pg_get_viewdef(c.oid, true), + array_remove(array_remove(c.reloptions,'check_option=local'), + 'check_option=cascaded') AS reloptions, + CASE + WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text + WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text + ELSE NULL + END AS checkoption + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) + JOIN v ON (c.oid = v.v_oid)""" + + function_definition_query = """ + WITH f AS + (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid) + SELECT pg_catalog.pg_get_functiondef(f.f_oid) + FROM f""" + + def __init__( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + self._conn_params = {} + self._is_virtual_database = None + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None + self.server_version = None + self.extra_args = None + self.connect(database, user, password, host, port, dsn, **kwargs) + self.reset_expanded = None + + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) + + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + conn_params = self._conn_params.copy() + + new_params = { + "dbname": database, + "user": user, + "password": password, + "host": host, + "port": port, + "dsn": dsn, + } + new_params.update(kwargs) + + if new_params["dsn"]: + new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} + + if new_params["password"]: + new_params["dsn"] = make_conninfo( + new_params["dsn"], password=new_params.pop("password") + ) + + conn_params.update({k: v for k, v in new_params.items() if v}) + + if "dsn" in conn_params: + other_params = {k: v for k, v in conn_params.items() if k != "dsn"} + conn_info = make_conninfo(conn_params["dsn"], **other_params) + else: + conn_info = make_conninfo(**conn_params) + conn = psycopg.connect(conn_info) + conn.cursor_factory = ProtocolSafeCursor + + self._conn_params = conn_params + if self.conn: + self.conn.close() + self.conn = conn + self.conn.autocommit = True + + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. + dsn_parameters = conn.info.get_parameters() + + if dsn_parameters: + self.dbname = dsn_parameters.get("dbname") + self.user = dsn_parameters.get("user") + self.host = dsn_parameters.get("host") + self.port = dsn_parameters.get("port") + else: + self.dbname = conn_params.get("database") + self.user = conn_params.get("user") + self.host = conn_params.get("host") + self.port = conn_params.get("port") + + self.password = password + self.extra_args = kwargs + + if not self.host: + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) + + self.pid = conn.info.backend_pid + self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1") + self.server_version = conn.info.parameter_status("server_version") or "" + + # _set_wait_callback(self.is_virtual_database()) + + if not self.is_virtual_database(): + register_typecasters(conn) + + @property + def short_host(self): + if "," in self.host: + host, _, _ = self.host.partition(",") + else: + host = self.host + short_host, _, _ = host.partition(".") + return short_host + + def _select_one(self, cur, sql): + """ + Helper method to run a select and retrieve a single field value + :param cur: cursor + :param sql: string + :return: string + """ + cur.execute(sql) + return cur.fetchone() + + def failed_transaction(self): + return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR + + def valid_transaction(self): + status = self.conn.info.transaction_status + return ( + status == psycopg.pq.TransactionStatus.ACTIVE + or status == psycopg.pq.TransactionStatus.INTRANS + ) + + def run( + self, + statement, + pgspecial=None, + exception_formatter=None, + on_error_resume=False, + explain_mode=False, + ): + """Execute the sql in the database and return the results. + + :param statement: A string containing one or more sql statements + :param pgspecial: PGSpecial object + :param exception_formatter: A callable that accepts an Exception and + returns a formatted (title, rows, headers, status) tuple that can + act as a query result. If an exception_formatter is not supplied, + psycopg2 exceptions are always raised. + :param on_error_resume: Bool. If true, queries following an exception + (assuming exception_formatter has been supplied) continue to + execute. + + :return: Generator yielding tuples containing + (title, rows, headers, status, query, success, is_special) + """ + + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield None, None, None, None, statement, False, False + + # sql parse doesn't split on a comment first + special + # so we're going to do it + + removed_comments = [] + sqlarr = [] + cleaned_command = "" + + # could skip if statement doesn't match ^-- or ^/* + cleaned_command, removed_comments = remove_beginning_comments(statement) + + sqlarr = sqlparse.split(cleaned_command) + + # now re-add the beginning comments if there are any, so that they show up in + # log files etc when running these commands + + if len(removed_comments) > 0: + sqlarr = removed_comments + sqlarr + + # run each sql query + for sql in sqlarr: + # Remove spaces, eol and semi-colons. + sql = sql.rstrip(";") + sql = sqlparse.format(sql, strip_comments=False).strip() + if not sql: + continue + try: + if explain_mode: + sql = self.explain_prefix() + sql + elif pgspecial: + # \G is treated specially since we have to set the expanded output. + if sql.endswith("\\G"): + if not pgspecial.expanded_output: + pgspecial.expanded_output = True + self.reset_expanded = True + sql = sql[:-2].strip() + + # First try to run each query as special + _logger.debug("Trying a pgspecial command. sql: %r", sql) + try: + cur = self.conn.cursor() + except psycopg.InterfaceError: + # edge case when connection is already closed, but we + # don't need cursor for special_cmd.arg_type == NO_QUERY. + # See https://github.com/dbcli/pgcli/issues/1014. + cur = None + try: + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: + # e.g. execute_from_file already appends these + if len(result) < 7: + yield result + (sql, True, True) + else: + yield result + continue + except special.CommandNotFound: + pass + + # Not a special command, so execute as normal sql + yield self.execute_normal_sql(sql) + (sql, True, False) + except psycopg.DatabaseError as e: + _logger.error("sql: %r, error: %r", sql, e) + _logger.error("traceback: %r", traceback.format_exc()) + + if self._must_raise(e) or not exception_formatter: + raise + + yield None, None, None, exception_formatter(e), sql, False, False + + if not on_error_resume: + break + finally: + if self.reset_expanded: + pgspecial.expanded_output = False + self.reset_expanded = None + + def _must_raise(self, e): + """Return true if e is an error that should not be caught in ``run``. + + An uncaught error will prompt the user to reconnect; as long as we + detect that the connection is still open, we catch the error, as + reconnecting won't solve that problem. + + :param e: DatabaseError. An exception raised while executing a query. + + :return: Bool. True if ``run`` must raise this exception. + + """ + return self.conn.closed != 0 + + def execute_normal_sql(self, split_sql): + """Returns tuple (title, rows, headers, status)""" + _logger.debug("Regular sql statement. sql: %r", split_sql) + + title = "" + + def handle_notices(n): + nonlocal title + title = f"{n.message_primary}\n{n.message_detail}\n{title}" + + self.conn.add_notice_handler(handle_notices) + + if self.is_virtual_database() and "show help" in split_sql.lower(): + # see https://github.com/psycopg/psycopg/issues/303 + # special case "show help" in pgbouncer + res = self.conn.pgconn.exec_(split_sql.encode()) + return title, None, None, res.command_status.decode() + + cur = self.conn.cursor() + cur.execute(split_sql) + + # cur.description will be None for operations that do not return + # rows. + if cur.description: + headers = [x[0] for x in cur.description] + return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message + else: + _logger.debug("No rows in result.") + return title, None, None, cur.statusmessage + + def search_path(self): + """Returns the current search path as a list of schema names""" + + try: + with self.conn.cursor() as cur: + _logger.debug("Search path query. sql: %r", self.search_path_query) + cur.execute(self.search_path_query) + return [x[0] for x in cur.fetchall()] + except psycopg.ProgrammingError: + fallback = "SELECT * FROM current_schemas(true)" + with self.conn.cursor() as cur: + _logger.debug("Search path query. sql: %r", fallback) + cur.execute(fallback) + return cur.fetchone()[0] + + def view_definition(self, spec): + """Returns the SQL defining views described by `spec`""" + + # 2: relkind, v or m (materialized) + # 4: reloptions, null + # 5: checkoption: local or cascaded + with self.conn.cursor() as cur: + sql = self.view_definition_query + _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + except psycopg.ProgrammingError: + raise RuntimeError(f"View {spec} does not exist.") + result = ViewDef(*cur.fetchone()) + if result.relkind == "m": + template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}" + else: + template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}" + return ( + psycopg.sql.SQL(template) + .format( + name=psycopg.sql.Identifier(result.nspname, result.relname), + stmt=psycopg.sql.SQL(result.viewdef), + ) + .as_string(self.conn) + ) + + def function_definition(self, spec): + """Returns the SQL defining functions described by `spec`""" + + with self.conn.cursor() as cur: + sql = self.function_definition_query + _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + result = cur.fetchone() + return result[0] + except psycopg.ProgrammingError: + raise RuntimeError(f"Function {spec} does not exist.") + + def schemata(self): + """Returns a list of schema names in the database""" + + with self.conn.cursor() as cur: + _logger.debug("Schemata Query. sql: %r", self.schemata_query) + cur.execute(self.schemata_query) + return [x[0] for x in cur.fetchall()] + + def _relations(self, kinds=("r", "p", "f", "v", "m")): + """Get table or view name metadata + + :param kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: (schema_name, rel_name) tuples + """ + + with self.conn.cursor() as cur: + # sql = cur.mogrify(self.tables_query, kinds) + # _logger.debug("Tables Query. sql: %r", sql) + cur.execute(self.tables_query, [kinds]) + yield from cur + + def tables(self): + """Yields (schema_name, table_name) tuples""" + yield from self._relations(kinds=["r", "p", "f"]) + + def views(self): + """Yields (schema_name, view_name) tuples. + + Includes both views and and materialized views + """ + yield from self._relations(kinds=["v", "m"]) + + def _columns(self, kinds=("r", "p", "f", "v", "m")): + """Get column metadata for tables and views + + :param kinds: kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: list of (schema_name, relation_name, column_name, column_type) tuples + """ + + if self.conn.info.server_version >= 80400: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + att.atttypid::regtype::text type_name, + att.atthasdef AS has_default, + pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + LEFT OUTER JOIN pg_attrdef def + ON def.adrelid = att.attrelid + AND def.adnum = att.attnum + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + else: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + typ.typname type_name, + NULL AS has_default, + NULL AS default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + INNER JOIN pg_catalog.pg_type typ + ON typ.oid = att.atttypid + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + + with self.conn.cursor() as cur: + # sql = cur.mogrify(columns_query, kinds) + # _logger.debug("Columns Query. sql: %r", sql) + cur.execute(columns_query, [kinds]) + yield from cur + + def table_columns(self): + yield from self._columns(kinds=["r", "p", "f"]) + + def view_columns(self): + yield from self._columns(kinds=["v", "m"]) + + def databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + cur.execute(self.databases_query) + return [x[0] for x in cur.fetchall()] + + def full_databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.full_databases_query) + cur.execute(self.full_databases_query) + headers = [x[0] for x in cur.description] + return cur.fetchall(), headers, cur.statusmessage + + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + + def get_socket_directory(self): + with self.conn.cursor() as cur: + _logger.debug( + "Socket directory Query. sql: %r", self.socket_directory_query + ) + cur.execute(self.socket_directory_query) + result = cur.fetchone() + return result[0] if result else "" + + def foreignkeys(self): + """Yields ForeignKey named tuples""" + + if self.conn.info.server_version < 90000: + return + + with self.conn.cursor() as cur: + query = """ + SELECT s_p.nspname AS parentschema, + t_p.relname AS parenttable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.confrelid + )) AS parentcolumn, + s_c.nspname AS childschema, + t_c.relname AS childtable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.conrelid + )) AS childcolumn + FROM pg_catalog.pg_constraint fk + JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid + JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace + JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid + JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace + WHERE fk.contype = 'f'; + """ + _logger.debug("Functions Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield ForeignKey(*row) + + def functions(self): + """Yields FunctionMetadata named tuples""" + + if self.conn.info.server_version >= 110000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.prokind = 'a' is_aggregate, + p.prokind = 'w' is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.info.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.proisagg is_aggregate, + p.proiswindow is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.info.server_version >= 80400: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + else: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + NULL arg_types, + NULL arg_modes, + '' ret_type, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + + with self.conn.cursor() as cur: + _logger.debug("Functions Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield FunctionMetadata(*row) + + def datatypes(self): + """Yields tuples of (schema_name, type_name)""" + + with self.conn.cursor() as cur: + if self.conn.info.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + t.typname type_name + FROM pg_catalog.pg_type t + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = t.typnamespace + WHERE ( t.typrelid = 0 -- non-composite types + OR ( -- composite type, but not a table + SELECT c.relkind = 'c' + FROM pg_catalog.pg_class c + WHERE c.oid = t.typrelid + ) + ) + AND NOT EXISTS( -- ignore array types + SELECT 1 + FROM pg_catalog.pg_type el + WHERE el.oid = t.typelem AND el.typarray = t.oid + ) + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + ORDER BY 1, 2; + """ + else: + query = """ + SELECT n.nspname schema_name, + pg_catalog.format_type(t.oid, NULL) type_name + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) + AND t.typname !~ '^_' + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + AND pg_catalog.pg_type_is_visible(t.oid) + ORDER BY 1, 2; + """ + _logger.debug("Datatypes Query. sql: %r", query) + cur.execute(query) + yield from cur + + def casing(self): + """Yields the most common casing for names used in db functions""" + with self.conn.cursor() as cur: + query = r""" + WITH Words AS ( + SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1) + FROM pg_catalog.pg_proc P + JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace + JOIN pg_catalog.pg_language L ON L.oid = P.prolang + WHERE L.lanname IN ('sql', 'plpgsql') + AND N.nspname NOT IN ('pg_catalog', 'information_schema') + GROUP BY Word + ), + OrderWords AS ( + SELECT Word, + ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC) + FROM Words + WHERE Word ~* '.*[a-z].*' + ), + Names AS ( + --Column names + SELECT attname AS Name + FROM pg_catalog.pg_attribute + UNION -- Table/view names + SELECT relname + FROM pg_catalog.pg_class + UNION -- Function names + SELECT proname + FROM pg_catalog.pg_proc + UNION -- Type names + SELECT typname + FROM pg_catalog.pg_type + UNION -- Schema names + SELECT nspname + FROM pg_catalog.pg_namespace + UNION -- Parameter names + SELECT unnest(proargnames) + FROM pg_proc + ) + SELECT Word + FROM OrderWords + WHERE LOWER(Word) IN (SELECT Name FROM Names) + AND Row_Number = 1; + """ + _logger.debug("Casing Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield row[0] + + def explain_prefix(self): + return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) " diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py new file mode 100644 index 0000000..77874f4 --- /dev/null +++ b/pgcli/pgstyle.py @@ -0,0 +1,116 @@ +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). +TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Output.Null: "output.null", + Token.Literal.String: "literal.string", + Token.Literal.Number: "literal.number", + Token.Keyword: "keyword", + Token.Prompt: "prompt", + Token.Continuation: "continuation", +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} + + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError: + return token_type, style_dict[token_name] + + +def style_factory(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name) + except ClassNotFound: + style = pygments.styles.get_style_by_name("native") + + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith("Token."): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style(token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error("Unhandled style / class name: %s", token) + else: + # treat as prompt style name (2.0). See default style names here: + # https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles( + [style_from_pygments_cls(style), override_style, Style(prompt_styles)] + ) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name("native").styles + + for token in cli_style: + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error("Unhandled style / class name: %s", token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style + + return OutputStyle diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py new file mode 100644 index 0000000..4a12ff4 --- /dev/null +++ b/pgcli/pgtoolbar.py @@ -0,0 +1,71 @@ +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.application import get_app + +vi_modes = { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", +} +# REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6 +if "REPLACE_SINGLE" in {e.name for e in InputMode}: + vi_modes[InputMode.REPLACE_SINGLE] = "R" + + +def _get_vi_mode(): + return vi_modes[get_app().vi_state.input_mode] + + +def create_toolbar_tokens_func(pgcli): + """Return a function that generates the toolbar tokens.""" + + def get_toolbar_tokens(): + result = [] + result.append(("class:bottom-toolbar", " ")) + + if pgcli.completer.smart_completion: + result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF ")) + + if pgcli.multi_line: + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) + + if pgcli.multi_line: + if pgcli.multiline_mode == "safe": + result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) ")) + else: + result.append( + ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") + ) + + if pgcli.vi_mode: + result.append( + ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ") + ) + else: + result.append(("class:bottom-toolbar", "[F4] Emacs-mode ")) + + if pgcli.explain_mode: + result.append(("class:bottom-toolbar", "[F5] Explain: ON ")) + else: + result.append(("class:bottom-toolbar", "[F5] Explain: OFF ")) + + if pgcli.pgexecute.failed_transaction(): + result.append( + ("class:bottom-toolbar.transaction.failed", " Failed transaction") + ) + + if pgcli.pgexecute.valid_transaction(): + result.append( + ("class:bottom-toolbar.transaction.valid", " Transaction") + ) + + if pgcli.completion_refresher.is_refreshing(): + result.append(("class:bottom-toolbar", " Refreshing completions...")) + + return result + + return get_toolbar_tokens diff --git a/pgcli/pyev.py b/pgcli/pyev.py new file mode 100644 index 0000000..2886c9c --- /dev/null +++ b/pgcli/pyev.py @@ -0,0 +1,439 @@ +import textwrap +import re +from click import style as color + +DESCRIPTIONS = { + "Append": "Used in a UNION to merge multiple record sets by appending them together.", + "Limit": "Returns a specified number of rows from a record set.", + "Sort": "Sorts a record set based on the specified sort key.", + "Nested Loop": "Merges two record sets by looping through every record in the first set and trying to find a match in the second set. All matching records are returned.", + "Merge Join": "Merges two record sets by first sorting them on a join key.", + "Hash": "Generates a hash table from the records in the input recordset. Hash is used by Hash Join.", + "Hash Join": "Joins to record sets by hashing one of them (using a Hash Scan).", + "Aggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()).", + "Hashaggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()). Hash Aggregate uses a hash to first organize the records by a key.", + "Sequence Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Seq Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).", + "Index Scan": "Finds relevant records based on an Index. Index Scans perform 2 read operations: one to read the index and another to read the actual value from the table.", + "Index Only Scan": "Finds relevant records based on an Index. Index Only Scans perform a single read operation from the index and do not read from the corresponding table.", + "Bitmap Heap Scan": "Searches through the pages returned by the Bitmap Index Scan for relevant rows.", + "Bitmap Index Scan": "Uses a Bitmap Index (index which uses 1 bit per page) to find all relevant pages. Results of this node are fed to the Bitmap Heap Scan.", + "CTEScan": "Performs a sequential scan of Common Table Expression (CTE) query results. Note that results of a CTE are materialized (calculated and temporarily stored).", + "ProjectSet": "ProjectSet appears when the SELECT or ORDER BY clause of the query. They basically just execute the set-returning function(s) for each tuple until none of the functions return any more records.", + "Result": "Returns result", +} + + +class Visualizer: + def __init__(self, terminal_width=100, color=True): + self.color = color + self.terminal_width = terminal_width + self.string_lines = [] + + def load(self, explain_dict): + self.plan = explain_dict.pop("Plan") + self.explain = explain_dict + self.process_all() + self.generate_lines() + + def process_all(self): + self.plan = self.process_plan(self.plan) + self.plan = self.calculate_outlier_nodes(self.plan) + + # + def process_plan(self, plan): + plan = self.calculate_planner_estimate(plan) + plan = self.calculate_actuals(plan) + self.calculate_maximums(plan) + # + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.process_plan(_plan) + return plan + + def prefix_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def tag_format(self, v): + if self.color: + return color(v, fg="white", bg="red") + return v + + def muted_format(self, v): + if self.color: + return color(v, fg="bright_black") + return v + + def bold_format(self, v): + if self.color: + return color(v, fg="white") + return v + + def good_format(self, v): + if self.color: + return color(v, fg="green") + return v + + def warning_format(self, v): + if self.color: + return color(v, fg="yellow") + return v + + def critical_format(self, v): + if self.color: + return color(v, fg="red") + return v + + def output_format(self, v): + if self.color: + return color(v, fg="cyan") + return v + + def calculate_planner_estimate(self, plan): + plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Under" + + if plan["Plan Rows"] == plan["Actual Rows"]: + return plan + + if plan["Plan Rows"] != 0: + plan["Planner Row Estimate Factor"] = ( + plan["Actual Rows"] / plan["Plan Rows"] + ) + + if plan["Planner Row Estimate Factor"] < 10: + plan["Planner Row Estimate Factor"] = 0 + plan["Planner Row Estimate Direction"] = "Over" + if plan["Actual Rows"] != 0: + plan["Planner Row Estimate Factor"] = ( + plan["Plan Rows"] / plan["Actual Rows"] + ) + return plan + + # + def calculate_actuals(self, plan): + plan["Actual Duration"] = plan["Actual Total Time"] + plan["Actual Cost"] = plan["Total Cost"] + + for child in plan.get("Plans", []): + if child["Node Type"] != "CTEScan": + plan["Actual Duration"] = ( + plan["Actual Duration"] - child["Actual Total Time"] + ) + plan["Actual Cost"] = plan["Actual Cost"] - child["Total Cost"] + + if plan["Actual Cost"] < 0: + plan["Actual Cost"] = 0 + + plan["Actual Duration"] = plan["Actual Duration"] * plan["Actual Loops"] + return plan + + def calculate_outlier_nodes(self, plan): + plan["Costliest"] = plan["Actual Cost"] == self.explain["Max Cost"] + plan["Largest"] = plan["Actual Rows"] == self.explain["Max Rows"] + plan["Slowest"] = plan["Actual Duration"] == self.explain["Max Duration"] + + for index in range(len(plan.get("Plans", []))): + _plan = plan["Plans"][index] + plan["Plans"][index] = self.calculate_outlier_nodes(_plan) + return plan + + def calculate_maximums(self, plan): + if not self.explain.get("Max Rows"): + self.explain["Max Rows"] = plan["Actual Rows"] + elif self.explain.get("Max Rows") < plan["Actual Rows"]: + self.explain["Max Rows"] = plan["Actual Rows"] + + if not self.explain.get("Max Cost"): + self.explain["Max Cost"] = plan["Actual Cost"] + elif self.explain.get("Max Cost") < plan["Actual Cost"]: + self.explain["Max Cost"] = plan["Actual Cost"] + + if not self.explain.get("Max Duration"): + self.explain["Max Duration"] = plan["Actual Duration"] + elif self.explain.get("Max Duration") < plan["Actual Duration"]: + self.explain["Max Duration"] = plan["Actual Duration"] + + if not self.explain.get("Total Cost"): + self.explain["Total Cost"] = plan["Actual Cost"] + elif self.explain.get("Total Cost") < plan["Actual Cost"]: + self.explain["Total Cost"] = plan["Actual Cost"] + + # + def duration_to_string(self, value): + if value < 1: + return self.good_format("<1 ms") + elif value < 100: + return self.good_format("%.2f ms" % value) + elif value < 1000: + return self.warning_format("%.2f ms" % value) + elif value < 60000: + return self.critical_format( + "%.2f s" % (value / 1000.0), + ) + else: + return self.critical_format( + "%.2f m" % (value / 60000.0), + ) + + # } + # + def format_details(self, plan): + details = [] + + if plan.get("Scan Direction"): + details.append(plan["Scan Direction"]) + + if plan.get("Strategy"): + details.append(plan["Strategy"]) + + if len(details) > 0: + return self.muted_format(" [%s]" % ", ".join(details)) + + return "" + + def format_tags(self, plan): + tags = [] + + if plan["Slowest"]: + tags.append(self.tag_format("slowest")) + if plan["Costliest"]: + tags.append(self.tag_format("costliest")) + if plan["Largest"]: + tags.append(self.tag_format("largest")) + if plan.get("Planner Row Estimate Factor", 0) >= 100: + tags.append(self.tag_format("bad estimate")) + + return " ".join(tags) + + def get_terminator(self, index, plan): + if index == 0: + if len(plan.get("Plans", [])) == 0: + return "⌡► " + else: + return "├► " + else: + if len(plan.get("Plans", [])) == 0: + return " " + else: + return "│ " + + def wrap_string(self, line, width): + if width == 0: + return [line] + return textwrap.wrap(line, width) + + def intcomma(self, value): + sep = "," + if not isinstance(value, str): + value = int(value) + + orig = str(value) + + new = re.sub(r"^(-?\d+)(\d{3})", rf"\g<1>{sep}\g<2>", orig) + if orig == new: + return new + else: + return self.intcomma(new) + + def output_fn(self, current_prefix, string): + return "%s%s" % (self.prefix_format(current_prefix), string) + + def create_lines(self, plan, prefix, depth, width, last_child): + current_prefix = prefix + self.string_lines.append( + self.output_fn(current_prefix, self.prefix_format("│")) + ) + + joint = "├" + if last_child: + joint = "└" + # + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s%s %s" + % ( + self.prefix_format(joint + "─⌠"), + self.bold_format(plan["Node Type"]), + self.format_details(plan), + self.format_tags(plan), + ), + ) + ) + # + if last_child: + prefix += " " + else: + prefix += "│ " + + current_prefix = prefix + "│ " + + cols = width - len(current_prefix) + + for line in self.wrap_string( + DESCRIPTIONS.get(plan["Node Type"], "Not found : %s" % plan["Node Type"]), + cols, + ): + self.string_lines.append( + self.output_fn(current_prefix, "%s" % self.muted_format(line)) + ) + # + if plan.get("Actual Duration"): + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Duration:", + self.duration_to_string(plan["Actual Duration"]), + (plan["Actual Duration"] / self.explain["Execution Time"]) + * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s (%.0f%%)" + % ( + "Cost:", + self.intcomma(plan["Actual Cost"]), + (plan["Actual Cost"] / self.explain["Total Cost"]) * 100, + ), + ) + ) + + self.string_lines.append( + self.output_fn( + current_prefix, + "○ %s %s" % ("Rows:", self.intcomma(plan["Actual Rows"])), + ) + ) + + current_prefix = current_prefix + " " + + if plan.get("Join Type"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (plan["Join Type"], self.muted_format("join")), + ) + ) + + if plan.get("Relation Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s.%s" + % ( + self.muted_format("on"), + plan.get("Schema", "unknown"), + plan["Relation Name"], + ), + ) + ) + + if plan.get("Index Name"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("using"), plan["Index Name"]), + ) + ) + + if plan.get("Index Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("condition"), plan["Index Condition"]), + ) + ) + + if plan.get("Filter"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s %s" + % ( + self.muted_format("filter"), + plan["Filter"], + self.muted_format( + "[-%s rows]" % self.intcomma(plan["Rows Removed by Filter"]) + ), + ), + ) + ) + + if plan.get("Hash Condition"): + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %s" % (self.muted_format("on"), plan["Hash Condition"]), + ) + ) + + if plan.get("CTE Name"): + self.string_lines.append( + self.output_fn(current_prefix, "CTE %s" % plan["CTE Name"]) + ) + + if plan.get("Planner Row Estimate Factor") != 0: + self.string_lines.append( + self.output_fn( + current_prefix, + "%s %sestimated %s %.2fx" + % ( + self.muted_format("rows"), + plan["Planner Row Estimate Direction"], + self.muted_format("by"), + plan["Planner Row Estimate Factor"], + ), + ) + ) + + current_prefix = prefix + + if len(plan.get("Output", [])) > 0: + for index, line in enumerate( + self.wrap_string(" + ".join(plan["Output"]), cols) + ): + self.string_lines.append( + self.output_fn( + current_prefix, + self.prefix_format(self.get_terminator(index, plan)) + + self.output_format(line), + ) + ) + + for index, nested_plan in enumerate(plan.get("Plans", [])): + self.create_lines( + nested_plan, prefix, depth + 1, width, index == len(plan["Plans"]) - 1 + ) + + def generate_lines(self): + self.string_lines = [ + "○ Total Cost: %s" % self.intcomma(self.explain["Total Cost"]), + "○ Planning Time: %s" + % self.duration_to_string(self.explain["Planning Time"]), + "○ Execution Time: %s" + % self.duration_to_string(self.explain["Execution Time"]), + self.prefix_format("┬"), + ] + self.create_lines( + self.plan, + "", + 0, + self.terminal_width, + len(self.plan.get("Plans", [])) == 1, + ) + + def get_list(self): + return "\n".join(self.string_lines) + + def print(self): + for lin in self.string_lines: + print(lin) diff --git a/post-install b/post-install new file mode 100644 index 0000000..d516a3f --- /dev/null +++ b/post-install @@ -0,0 +1,4 @@ +#!/bin/bash + +echo "Setting up symlink to pgcli" +ln -sf /usr/share/pgcli/bin/pgcli /usr/local/bin/pgcli diff --git a/post-remove b/post-remove new file mode 100644 index 0000000..1013eb4 --- /dev/null +++ b/post-remove @@ -0,0 +1,4 @@ +#!/bin/bash + +echo "Removing symlink to pgcli" +rm /usr/local/bin/pgcli diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..4aa448d --- /dev/null +++ b/pylintrc @@ -0,0 +1,2 @@ +[MESSAGES CONTROL] +disable=missing-docstring,invalid-name
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8477d72 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.black] +line-length = 88 +target-version = ['py38'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | \.cache + | \.pytest_cache + | _build + | buck-out + | build + | dist + | tests/data +)/ +''' diff --git a/release.py b/release.py new file mode 100644 index 0000000..42a72a9 --- /dev/null +++ b/release.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +"""A script to publish a release of pgcli to PyPI.""" + +import io +from optparse import OptionParser +import re +import subprocess +import sys + +import click + +DEBUG = False +CONFIRM_STEPS = False +DRY_RUN = False + + +def skip_step(): + """ + Asks for user's response whether to run a step. Default is yes. + :return: boolean + """ + global CONFIRM_STEPS + + if CONFIRM_STEPS: + return not click.confirm("--- Run this step?", default=True) + return False + + +def run_step(*args): + """ + Prints out the command and asks if it should be run. + If yes (default), runs it. + :param args: list of strings (command and args) + """ + global DRY_RUN + + cmd = args + print(" ".join(cmd)) + if skip_step(): + print("--- Skipping...") + elif DRY_RUN: + print("--- Pretending to run...") + else: + subprocess.check_output(cmd) + + +def version(version_file): + _version_re = re.compile( + r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)' + ) + + with io.open(version_file, encoding="utf-8") as f: + ver = _version_re.search(f.read()).group("version") + + return ver + + +def commit_for_release(version_file, ver): + run_step("git", "reset") + run_step("git", "add", "-u") + run_step("git", "commit", "--message", "Releasing version {}".format(ver)) + + +def create_git_tag(tag_name): + run_step("git", "tag", tag_name) + + +def create_distribution_files(): + run_step("python", "setup.py", "clean", "--all", "sdist", "bdist_wheel") + + +def upload_distribution_files(): + run_step("twine", "upload", "dist/*") + + +def push_to_github(): + run_step("git", "push", "origin", "main") + + +def push_tags_to_github(): + run_step("git", "push", "--tags", "origin") + + +def checklist(questions): + for question in questions: + if not click.confirm("--- {}".format(question), default=False): + sys.exit(1) + + +if __name__ == "__main__": + if DEBUG: + subprocess.check_output = lambda x: x + + checks = [ + "Have you updated the AUTHORS file?", + "Have you updated the `Usage` section of the README?", + ] + checklist(checks) + + ver = version("pgcli/__init__.py") + print("Releasing Version:", ver) + + parser = OptionParser() + parser.add_option( + "-c", + "--confirm-steps", + action="store_true", + dest="confirm_steps", + default=False, + help=( + "Confirm every step. If the step is not " "confirmed, it will be skipped." + ), + ) + parser.add_option( + "-d", + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help="Print out, but not actually run any steps.", + ) + + popts, pargs = parser.parse_args() + CONFIRM_STEPS = popts.confirm_steps + DRY_RUN = popts.dry_run + + if not click.confirm("Are you sure?", default=False): + sys.exit(1) + + commit_for_release("pgcli/__init__.py", ver) + create_git_tag("v{}".format(ver)) + create_distribution_files() + push_to_github() + push_tags_to_github() + upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..15505a7 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,13 @@ +pytest>=2.7.0 +tox>=1.9.2 +behave>=1.2.4 +black>=23.3.0 +pexpect==3.3; platform_system != "Windows" +pre-commit>=1.16.0 +coverage>=5.0.4 +codecov>=1.5.1 +docutils>=0.13.1 +autopep8>=1.3.3 +twine>=1.11.0 +wheel>=0.33.6 +sshtunnel>=0.4.0 diff --git a/sanity_checks.txt b/sanity_checks.txt new file mode 100644 index 0000000..d8a4898 --- /dev/null +++ b/sanity_checks.txt @@ -0,0 +1,37 @@ +# vi: ft=vimwiki + +* Launch pgcli with different inputs. + * pgcli test_db + * pgcli postgres://localhost/test_db + * pgcli postgres://localhost:5432/test_db + * pgcli postgres://amjith@localhost:5432/test_db + * pgcli postgres://amjith:password@localhost:5432/test_db + * pgcli non-existent-db + +* Test special command + * \d + * \d table_name + * \dt + * \l + * \c amjith + * \q + +* Simple execution: + 1 Execute a simple 'select * from users;' test that will pass. + 2 Execute a syntax error: 'insert into users ( ;' + 3 Execute a simple test from step 1 again to see if it still passes. + * Change the database and try steps 1 - 3. + +* Test smart-completion + * Sele - Must auto-complete to SELECT + * SELECT * FROM - Must list the table names. + * INSERT INTO - Must list table names. + * \d <tab> - Must list table names. + * \c <tab> - Database names. + * SELECT * FROM table_name WHERE <tab> - column names (all of it). + +* Test naive-completion - turn off smart completion (using F2 key after launch) + * Sele - autocomplete to select. + * SELECT * FROM - autocomplete list should have everything. + + diff --git a/screenshots/image01.png b/screenshots/image01.png Binary files differnew file mode 100644 index 0000000..58520c5 --- /dev/null +++ b/screenshots/image01.png diff --git a/screenshots/image02.png b/screenshots/image02.png Binary files differnew file mode 100644 index 0000000..c321c86 --- /dev/null +++ b/screenshots/image02.png diff --git a/screenshots/kharkiv-destroyed.jpg b/screenshots/kharkiv-destroyed.jpg Binary files differnew file mode 100644 index 0000000..4f95783 --- /dev/null +++ b/screenshots/kharkiv-destroyed.jpg diff --git a/screenshots/pgcli.gif b/screenshots/pgcli.gif Binary files differnew file mode 100644 index 0000000..9c8e66d --- /dev/null +++ b/screenshots/pgcli.gif diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9a398a4 --- /dev/null +++ b/setup.py @@ -0,0 +1,76 @@ +import platform +from setuptools import setup, find_packages + +from pgcli import __version__ + +description = "CLI for Postgres Database. With auto-completion and syntax highlighting." + +install_requirements = [ + "pgspecial>=2.0.0", + "click >= 4.1", + "Pygments>=2.0", # Pygments has to be Capitalcased. WTF? + # We still need to use pt-2 unless pt-3 released on Fedora32 + # see: https://github.com/dbcli/pgcli/pull/1197 + "prompt_toolkit>=2.0.6,<4.0.0", + "psycopg >= 3.0.14", + "sqlparse >=0.3.0,<0.5", + "configobj >= 5.0.6", + "pendulum>=2.1.0", + "cli_helpers[styles] >= 2.2.1", +] + + +# setproctitle is used to mask the password when running `ps` in command line. +# But this is not necessary in Windows since the password is never shown in the +# task manager. Also setproctitle is a hard dependency to install in Windows, +# so we'll only install it if we're not in Windows. +if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): + install_requirements.append("setproctitle >= 1.1.9") + +# Windows will require the binary psycopg to run pgcli +if platform.system() == "Windows": + install_requirements.append("psycopg-binary >= 3.0.14") + + +setup( + name="pgcli", + author="Pgcli Core Team", + author_email="pgcli-dev@googlegroups.com", + version=__version__, + license="BSD", + url="http://pgcli.com", + packages=find_packages(), + package_data={"pgcli": ["pgclirc", "packages/pgliterals/pgliterals.json"]}, + description=description, + long_description=open("README.rst").read(), + install_requires=install_requirements, + dependency_links=[ + "http://github.com/psycopg/repo/tarball/master#egg=psycopg-3.0.10" + ], + extras_require={ + "keyring": ["keyring >= 12.2.0"], + "sshtunnel": ["sshtunnel >= 0.4.0"], + }, + python_requires=">=3.8", + entry_points=""" + [console_scripts] + pgcli=pgcli.main:cli + """, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..33cddf2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +import os +import pytest +from utils import ( + POSTGRES_HOST, + POSTGRES_PORT, + POSTGRES_USER, + POSTGRES_PASSWORD, + create_db, + db_connection, + drop_tables, +) +import pgcli.pgexecute + + +@pytest.fixture(scope="function") +def connection(): + create_db("_test_db") + connection = db_connection("_test_db") + yield connection + + drop_tables(connection) + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return pgcli.pgexecute.PGExecute( + database="_test_db", + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dsn=None, + ) + + +@pytest.fixture +def exception_formatter(): + return lambda e: str(e) + + +@pytest.fixture(scope="session", autouse=True) +def temp_config(tmpdir_factory): + # this function runs on start of test session. + # use temporary directory for config home so user config will not be used + os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data")) diff --git a/tests/features/__init__.py b/tests/features/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/features/__init__.py diff --git a/tests/features/auto_vertical.feature b/tests/features/auto_vertical.feature new file mode 100644 index 0000000..aa95718 --- /dev/null +++ b/tests/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature new file mode 100644 index 0000000..ee497b9 --- /dev/null +++ b/tests/features/basic_commands.feature @@ -0,0 +1,81 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: run partial select command + When we send partial select command + then we see error message + then we see dbcli prompt + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits + + Scenario: confirm exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "c" + then dbcli exits + + Scenario: cancel exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "a" + then we see dbcli prompt + when we rollback transaction + when we send "ctrl + d" + then dbcli exits + + Scenario: interrupt current query via "ctrl + c" + When we send sleep query + and we send "ctrl + c" + then we see cancelled query warning + when we check for any non-idle sleep queries + then we don't see any non-idle sleep queries + + Scenario: list databases + When we list databases + then we see list of databases + + Scenario: run the cli with --username + When we launch dbcli using --username + and we send "\?" command + then we see help output + + Scenario: run the cli with --user + When we launch dbcli using --user + and we send "\?" command + then we see help output + + Scenario: run the cli with --port + When we launch dbcli using --port + and we send "\?" command + then we see help output + + Scenario: run the cli with --password + When we launch dbcli using --password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output + + Scenario: run the cli with dsn and password + When we launch dbcli using dsn_password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature new file mode 100644 index 0000000..87da4e3 --- /dev/null +++ b/tests/features/crud_database.feature @@ -0,0 +1,17 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we create database + then we see database created + when we drop database + then we respond to the destructive warning: y + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature new file mode 100644 index 0000000..8a43c5c --- /dev/null +++ b/tests/features/crud_table.feature @@ -0,0 +1,45 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we select from table + then we see data selected: initial + when we update table + then we see record updated + when we select from table + then we see data selected: updated + when we delete from table + then we respond to the destructive warning: y + then we see record deleted + when we drop table + then we respond to the destructive warning: y + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: transaction handling, with cancelling on a destructive warning. + When we connect to test database + then we see database connected + when we create table + then we see table created + when we begin transaction + then we see transaction began + when we insert into table + then we see record inserted + when we delete from table + then we respond to the destructive warning: n + when we select from table + then we see data selected: initial + when we rollback transaction + then we see transaction rolled back + when we select from table + then we see select output without data + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py new file mode 100644 index 0000000..595c6c2 --- /dev/null +++ b/tests/features/db_utils.py @@ -0,0 +1,87 @@ +from psycopg import connect + + +def create_db( + hostname="localhost", username=None, password=None, dbname=None, port=None +): + """Create test database. + + :param hostname: string + :param username: string + :param password: string + :param dbname: string + :param port: int + :return: + + """ + cn = create_cn(hostname, password, username, "postgres", port) + + cn.autocommit = True + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + cr.execute(f"create database {dbname}") + + cn.close() + + cn = create_cn(hostname, password, username, dbname, port) + return cn + + +def create_cn(hostname, password, username, dbname, port): + """ + Open connection to database. + :param hostname: + :param password: + :param username: + :param dbname: string + :return: psycopg2.connection + """ + cn = connect( + host=hostname, user=username, dbname=dbname, password=password, port=port + ) + + print(f"Created connection: {cn.info.get_parameters()}.") + return cn + + +def pgbouncer_available(hostname="localhost", password=None, username="postgres"): + cn = None + try: + cn = create_cn(hostname, password, username, "pgbouncer", 6432) + return True + except: + print("Pgbouncer is not available.") + finally: + if cn: + cn.close() + return False + + +def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None): + """ + Drop database. + :param hostname: string + :param username: string + :param password: string + :param dbname: string + """ + cn = create_cn(hostname, password, username, "postgres", port) + + # Needed for DB drop. + cn.autocommit = True + + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + + close_cn(cn) + + +def close_cn(cn=None): + """ + Close connection. + :param connection: psycopg2.connection + """ + if cn: + cn_params = cn.info.get_parameters() + cn.close() + print(f"Closed connection: {cn_params}.") diff --git a/tests/features/environment.py b/tests/features/environment.py new file mode 100644 index 0000000..50ac5fa --- /dev/null +++ b/tests/features/environment.py @@ -0,0 +1,227 @@ +import copy +import os +import sys +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect +import tempfile +import shutil +import signal + + +from steps import wrappers + + +def before_all(context): + """Set env parameters.""" + env_old = copy.deepcopy(dict(os.environ)) + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["PAGER"] = "cat" + os.environ["EDITOR"] = "ex" + os.environ["VISUAL"] = "ex" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") + + print("package root:", context.package_root) + print("fixture dir:", fixture_dir) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join( + context.package_root, ".coveragerc" + ) + + context.exit_sent = False + + vi = "_".join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests") + db_name_full = f"{db_name}_{vi}" + + # Store get params from config. + context.conf = { + "host": context.config.userdata.get( + "pg_test_host", os.getenv("PGHOST", "localhost") + ), + "user": context.config.userdata.get( + "pg_test_user", os.getenv("PGUSER", "postgres") + ), + "pass": context.config.userdata.get( + "pg_test_pass", os.getenv("PGPASSWORD", None) + ), + "port": context.config.userdata.get( + "pg_test_port", os.getenv("PGPORT", "5432") + ), + "cli_command": ( + context.config.userdata.get("pg_cli_command", None) + or '{python} -c "{startup}"'.format( + python=sys.executable, + startup="; ".join( + [ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", + ] + ), + ) + ), + "dbname": db_name_full, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", + } + os.environ["PAGER"] = "{0} {1} {2}".format( + sys.executable, + os.path.join(context.package_root, "tests/features/wrappager.py"), + context.conf["pager_boundary"], + ) + + # Store old env vars. + context.pgenv = { + "PGDATABASE": os.environ.get("PGDATABASE", None), + "PGUSER": os.environ.get("PGUSER", None), + "PGHOST": os.environ.get("PGHOST", None), + "PGPASSWORD": os.environ.get("PGPASSWORD", None), + "PGPORT": os.environ.get("PGPORT", None), + "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None), + "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None), + } + + # Set new env vars. + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGUSER"] = context.conf["user"] + os.environ["PGHOST"] = context.conf["host"] + os.environ["PGPORT"] = context.conf["port"] + os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf") + + if context.conf["pass"]: + os.environ["PGPASSWORD"] = context.conf["pass"] + else: + if "PGPASSWORD" in os.environ: + del os.environ["PGPASSWORD"] + os.environ["BEHAVE_WARN"] = "moderate" + + context.cn = dbutils.create_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + context.pgbouncer_available = dbutils.pgbouncer_available( + hostname=context.conf["host"], + password=context.conf["pass"], + username=context.conf["user"], + ) + context.fixture_data = fixutils.read_fixture_files() + + # use temporary directory as config home + context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_") + os.environ["XDG_CONFIG_HOME"] = context.env_config_home + show_env_changes(env_old, dict(os.environ)) + + +def show_env_changes(env_old, env_new): + """Print out all test-specific env values.""" + print("--- os.environ changed values: ---") + all_keys = env_old.keys() | env_new.keys() + for k in sorted(all_keys): + old_value = env_old.get(k, "") + new_value = env_new.get(k, "") + if new_value and old_value != new_value: + print(f'{k}="{new_value}"') + print("-" * 20) + + +def after_all(context): + """ + Unset env parameters. + """ + dbutils.close_cn(context.cn) + dbutils.drop_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + + # Remove temp config directory + shutil.rmtree(context.env_config_home) + + # Restore env vars. + for k, v in context.pgenv.items(): + if k in os.environ and v is None: + del os.environ[k] + elif v: + os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def is_known_problem(scenario): + """TODO: why is this not working in 3.12?""" + if sys.version_info >= (3, 12): + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + ) + return False + + +def before_scenario(context, scenario): + if scenario.name == "list databases": + # not using the cli for that + return + if is_known_problem(scenario): + scenario.skip() + currentdb = None + if "pgbouncer" in scenario.feature.tags: + if context.pgbouncer_available: + os.environ["PGDATABASE"] = "pgbouncer" + os.environ["PGPORT"] = "6432" + currentdb = "pgbouncer" + else: + scenario.skip() + else: + # set env vars back to normal test database + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGPORT"] = context.conf["port"] + wrappers.run_cli(context, currentdb=currentdb) + wrappers.wait_prompt(context) + + +def after_scenario(context, scenario): + """Cleans up after each scenario completes.""" + if hasattr(context, "cli") and context.cli and not context.exit_sent: + # Quit nicely. + if not getattr(context, "atprompt", False): + dbname = context.currentdb + context.cli.expect_exact(f"{dbname}>", timeout=5) + try: + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") + except Exception as x: + print("Failed cleanup after scenario:") + print(x) + try: + context.cli.expect_exact(pexpect.EOF, timeout=5) + except pexpect.TIMEOUT: + print(f"--- after_scenario {scenario.name}: kill cli") + context.cli.kill(signal.SIGKILL) + if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: + context.tmpfile_sql_help.close() + context.tmpfile_sql_help = None + + +# # TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import pdb; pdb.set_trace() diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature new file mode 100644 index 0000000..e486048 --- /dev/null +++ b/tests/features/expanded.feature @@ -0,0 +1,29 @@ +Feature: expanded mode: + on, off, auto + + Scenario: expanded on + When we prepare the test data + and we set expanded on + and we select from table + then we see expanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded off + When we prepare the test data + and we set expanded off + and we select from table + then we see nonexpanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded auto + When we prepare the test data + and we set expanded auto + and we select from table + then we see auto data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/fixture_data/help.txt b/tests/features/fixture_data/help.txt new file mode 100644 index 0000000..bebb976 --- /dev/null +++ b/tests/features/fixture_data/help.txt @@ -0,0 +1,25 @@ ++--------------------------+------------------------------------------------+ +| Command | Description | +|--------------------------+------------------------------------------------| +| \# | Refresh auto-completions. | +| \? | Show Help. | +| \T [format] | Change the table format used to output results | +| \c[onnect] database_name | Change to a new database. | +| \d [pattern] | List or describe tables, views and sequences. | +| \dT[S+] [pattern] | List data types | +| \df[+] [pattern] | List functions. | +| \di[+] [pattern] | List indexes. | +| \dn[+] [pattern] | List schemas. | +| \ds[+] [pattern] | List sequences. | +| \dt[+] [pattern] | List tables. | +| \du[+] [pattern] | List roles. | +| \dv[+] [pattern] | List views. | +| \e [file] | Edit the query with external editor. | +| \l | List databases. | +| \n[+] [name] | List or execute named queries. | +| \nd [name [query]] | Delete a named query. | +| \ns name query | Save a named query. | +| \refresh | Refresh auto-completions. | +| \timing | Toggle timing of commands. | +| \x | Toggle expanded output. | ++--------------------------+------------------------------------------------+ diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt new file mode 100644 index 0000000..e076661 --- /dev/null +++ b/tests/features/fixture_data/help_commands.txt @@ -0,0 +1,64 @@ +Command +Description +\# +Refresh auto-completions. +\? +Show Commands. +\T [format] +Change the table format used to output results +\c[onnect] database_name +Change to a new database. +\copy [tablename] to/from [filename] +Copy data between a file and a table. +\d[+] [pattern] +List or describe tables, views and sequences. +\dT[S+] [pattern] +List data types +\db[+] [pattern] +List tablespaces. +\df[+] [pattern] +List functions. +\di[+] [pattern] +List indexes. +\dm[+] [pattern] +List materialized views. +\dn[+] [pattern] +List schemas. +\ds[+] [pattern] +List sequences. +\dt[+] [pattern] +List tables. +\du[+] [pattern] +List roles. +\dv[+] [pattern] +List views. +\dx[+] [pattern] +List extensions. +\e [file] +Edit the query with external editor. +\h +Show SQL syntax and help. +\i filename +Execute commands from file. +\l +List databases. +\n[+] [name] [param1 param2 ...] +List or execute named queries. +\nd [name] +Delete a named query. +\ns name query +Save a named query. +\o [filename] +Send all query results to file. +\pager [command] +Set PAGER. Print the query results via PAGER. +\pset [key] [value] +A limited version of traditional \pset +\refresh +Refresh auto-completions. +\sf[+] FUNCNAME +Show a function's definition. +\timing +Toggle timing of commands. +\x +Toggle expanded output. diff --git a/tests/features/fixture_data/mock_pg_service.conf b/tests/features/fixture_data/mock_pg_service.conf new file mode 100644 index 0000000..15f9811 --- /dev/null +++ b/tests/features/fixture_data/mock_pg_service.conf @@ -0,0 +1,4 @@ +[mock_postgres] +dbname=postgres +host=localhost +user=postgres diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py new file mode 100644 index 0000000..70b603d --- /dev/null +++ b/tests/features/fixture_utils.py @@ -0,0 +1,28 @@ +import os +import codecs + + +def read_fixture_lines(filename): + """ + Read lines of text from file. + :param filename: string name + :return: list of strings + """ + lines = [] + for line in codecs.open(filename, "rb", encoding="utf-8"): + lines.append(line.strip()) + return lines + + +def read_fixture_files(): + """Read all files inside fixture_data directory.""" + current_dir = os.path.dirname(__file__) + fixture_dir = os.path.join(current_dir, "fixture_data/") + print(f"reading fixture data: {fixture_dir}") + fixture_dict = {} + for filename in os.listdir(fixture_dir): + if filename not in [".", ".."]: + fullname = os.path.join(fixture_dir, filename) + fixture_dict[filename] = read_fixture_lines(fullname) + + return fixture_dict diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature new file mode 100644 index 0000000..dad7d10 --- /dev/null +++ b/tests/features/iocommands.feature @@ -0,0 +1,17 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type sql in the editor + and we exit the editor + then we see dbcli prompt + and we see the sql in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we query "select 123456" + and we wait for prompt + and we stop teeing output + and we wait for prompt + then we see 123456 in tee output diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature new file mode 100644 index 0000000..74201b9 --- /dev/null +++ b/tests/features/named_queries.feature @@ -0,0 +1,10 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we delete a named query + then we see the named query deleted diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature new file mode 100644 index 0000000..14cc5ad --- /dev/null +++ b/tests/features/pgbouncer.feature @@ -0,0 +1,12 @@ +@pgbouncer +Feature: run pgbouncer, + call the help command, + exit the cli + + Scenario: run "show help" command + When we send "show help" command + then we see the pgbouncer help output + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/tests/features/specials.feature b/tests/features/specials.feature new file mode 100644 index 0000000..63c5cdc --- /dev/null +++ b/tests/features/specials.feature @@ -0,0 +1,6 @@ +Feature: Special commands + + Scenario: run refresh command + When we refresh completions + and we wait for prompt + then we see completions refresh started diff --git a/tests/features/steps/__init__.py b/tests/features/steps/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/features/steps/__init__.py diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py new file mode 100644 index 0000000..d7cdccd --- /dev/null +++ b/tests/features/steps/auto_vertical.py @@ -0,0 +1,99 @@ +from textwrap import dedent +from behave import then, when +import wrappers + + +@when("we run dbcli with {arg}") +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=arg.split("=")) + + +@when("we execute a small query") +def step_execute_small_query(context): + context.cli.sendline("select 1") + + +@when("we execute a large query") +def step_execute_large_query(context): + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) + + +@then("we see small results in horizontal format") +def step_see_small_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + +----------+\r + | ?column? |\r + |----------|\r + | 1 |\r + +----------+\r + SELECT 1\r + """ + ), + timeout=5, + ) + + +@then("we see large results in vertical format") +def step_see_large_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + ?column? | 1\r + ?column? | 2\r + ?column? | 3\r + ?column? | 4\r + ?column? | 5\r + ?column? | 6\r + ?column? | 7\r + ?column? | 8\r + ?column? | 9\r + ?column? | 10\r + ?column? | 11\r + ?column? | 12\r + ?column? | 13\r + ?column? | 14\r + ?column? | 15\r + ?column? | 16\r + ?column? | 17\r + ?column? | 18\r + ?column? | 19\r + ?column? | 20\r + ?column? | 21\r + ?column? | 22\r + ?column? | 23\r + ?column? | 24\r + ?column? | 25\r + ?column? | 26\r + ?column? | 27\r + ?column? | 28\r + ?column? | 29\r + ?column? | 30\r + ?column? | 31\r + ?column? | 32\r + ?column? | 33\r + ?column? | 34\r + ?column? | 35\r + ?column? | 36\r + ?column? | 37\r + ?column? | 38\r + ?column? | 39\r + ?column? | 40\r + ?column? | 41\r + ?column? | 42\r + ?column? | 43\r + ?column? | 44\r + ?column? | 45\r + ?column? | 46\r + ?column? | 47\r + ?column? | 48\r + ?column? | 49\r + SELECT 1\r + """ + ), + timeout=5, + ) diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py new file mode 100644 index 0000000..687bdc0 --- /dev/null +++ b/tests/features/steps/basic_commands.py @@ -0,0 +1,231 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +import pexpect +import subprocess +import tempfile + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we list databases") +def step_list_databases(context): + cmd = ["pgcli", "--list"] + context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root) + + +@then("we see list of databases") +def step_see_list_databases(context): + assert b"List of databases" in context.cmd_output + assert b"postgres" in context.cmd_output + context.cmd_output = None + + +@when("we run dbcli") +def step_run_cli(context): + wrappers.run_cli(context) + + +@when("we launch dbcli using {arg}") +def step_run_cli_using_arg(context, arg): + prompt_check = False + currentdb = None + if arg == "--username": + arg = "--username={}".format(context.conf["user"]) + if arg == "--user": + arg = "--user={}".format(context.conf["user"]) + if arg == "--port": + arg = "--port={}".format(context.conf["port"]) + if arg == "--password": + arg = "--password" + prompt_check = False + # This uses the mock_pg_service.conf file in fixtures folder. + if arg == "dsn_password": + arg = "service=mock_postgres --password" + prompt_check = False + currentdb = "postgres" + wrappers.run_cli( + context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb + ) + + +@when("we wait for prompt") +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """ + Send Ctrl + D to hopefully exit. + """ + step_try_to_ctrl_d(context) + context.cli.expect(pexpect.EOF, timeout=5) + context.exit_sent = True + + +@when('we try to send "ctrl + d"') +def step_try_to_ctrl_d(context): + """ + Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is + ongoing). + """ + # turn off pager before exiting + context.cli.sendcontrol("c") + context.cli.sendline(r"\pset pager off") + wrappers.wait_prompt(context) + context.cli.sendcontrol("d") + + +@when('we send "ctrl + c"') +def step_ctrl_c(context): + """Send Ctrl + c to hopefully interrupt.""" + context.cli.sendcontrol("c") + + +@then("we see cancelled query warning") +def step_see_cancelled_query_warning(context): + """ + Make sure we receive the warning that the current query was cancelled. + """ + wrappers.expect_exact(context, "cancelled query", timeout=2) + + +@then("we see ongoing transaction message") +def step_see_ongoing_transaction_error(context): + """ + Make sure we receive the warning that a transaction is ongoing. + """ + context.cli.expect("A transaction is ongoing.", timeout=2) + + +@when("we send sleep query") +def step_send_sleep_15_seconds(context): + """ + Send query to sleep for 15 seconds. + """ + context.cli.sendline("select pg_sleep(15)") + + +@when("we check for any non-idle sleep queries") +def step_check_for_active_sleep_queries(context): + """ + Send query to check for any non-idle pg_sleep queries. + """ + context.cli.sendline( + "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';" + ) + + +@then("we don't see any non-idle sleep queries") +def step_no_active_sleep_queries(context): + """Confirm that any pg_sleep queries are either idle or not active.""" + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +-------+\r + | state |\r + |-------|\r + +-------+\r + SELECT 0\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@when(r'we send "\?" command') +def step_send_help(context): + r""" + Send \? to see help. + """ + context.cli.sendline(r"\?") + + +@when("we send partial select command") +def step_send_partial_select_command(context): + """ + Send `SELECT a` to see completion. + """ + context.cli.sendline("SELECT a") + + +@then("we see error message") +def step_see_error_message(context): + wrappers.expect_exact(context, 'column "a" does not exist', timeout=2) + + +@when("we send source command") +def step_send_source_command(context): + context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") + context.tmpfile_sql_help.write(rb"\?") + context.tmpfile_sql_help.flush() + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + + +@when("we run query to check application_name") +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;" + ) + + +@then("we see found") +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +----------+\r + | ?column? |\r + |----------|\r + | found |\r + +----------+\r + SELECT 1\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@then("we respond to the destructive warning: {response}") +def step_resppond_to_destructive_command(context, response): + """Respond to destructive command.""" + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline(response.strip()) + + +@then("we send password") +def step_send_password(context): + wrappers.expect_exact(context, "Password for", timeout=5) + context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") + + +@when('we send "{text}"') +def step_send_text(context, text): + context.cli.sendline(text) + # Try to detect whether we are exiting. If so, set `exit_sent` + # so that `after_scenario` correctly cleans up. + try: + context.cli.expect(pexpect.EOF, timeout=0.2) + except pexpect.TIMEOUT: + pass + else: + context.exit_sent = True diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py new file mode 100644 index 0000000..87cdc85 --- /dev/null +++ b/tests/features/steps/crud_database.py @@ -0,0 +1,93 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" +import pexpect + +from behave import when, then +import wrappers + + +@when("we create database") +def step_db_create(context): + """ + Send create database. + """ + context.cli.sendline("create database {};".format(context.conf["dbname_tmp"])) + + context.response = {"database_name": context.conf["dbname_tmp"]} + + +@when("we drop database") +def step_db_drop(context): + """ + Send drop database. + """ + context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"])) + + +@when("we connect to test database") +def step_db_connect_test(context): + """ + Send connect to database. + """ + db_name = context.conf["dbname"] + context.cli.sendline(f"\\connect {db_name}") + + +@when("we connect to dbserver") +def step_db_connect_dbserver(context): + """ + Send connect to database. + """ + context.cli.sendline("\\connect postgres") + context.currentdb = "postgres" + + +@then("dbcli exits") +def step_wait_exit(context): + """ + Make sure the cli exits. + """ + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then("we see dbcli prompt") +def step_see_prompt(context): + """ + Wait to see the prompt. + """ + db_name = getattr(context, "currentdb", context.conf["dbname"]) + wrappers.expect_exact(context, f"{db_name}>", timeout=5) + context.atprompt = True + + +@then("we see help output") +def step_see_help(context): + for expected_line in context.fixture_data["help_commands.txt"]: + wrappers.expect_exact(context, expected_line, timeout=2) + + +@then("we see database created") +def step_see_db_created(context): + """ + Wait to see create database output. + """ + wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5) + + +@then("we see database dropped") +def step_see_db_dropped(context): + """ + Wait to see drop database output. + """ + wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2) + + +@then("we see database connected") +def step_see_db_connected(context): + """ + Wait to see drop database output. + """ + wrappers.expect_exact(context, "You are now connected to database", timeout=2) diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py new file mode 100644 index 0000000..27d543e --- /dev/null +++ b/tests/features/steps/crud_table.py @@ -0,0 +1,185 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +INITIAL_DATA = "xxx" +UPDATED_DATA = "yyy" + + +@when("we create table") +def step_create_table(context): + """ + Send create table. + """ + context.cli.sendline("create table a(x text);") + + +@when("we insert into table") +def step_insert_into_table(context): + """ + Send insert into table. + """ + context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""") + + +@when("we update table") +def step_update_table(context): + """ + Send insert into table. + """ + context.cli.sendline( + f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';""" + ) + + +@when("we select from table") +def step_select_from_table(context): + """ + Send select from table. + """ + context.cli.sendline("select * from a;") + + +@when("we delete from table") +def step_delete_from_table(context): + """ + Send deete from table. + """ + context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""") + + +@when("we drop table") +def step_drop_table(context): + """ + Send drop table. + """ + context.cli.sendline("drop table a;") + + +@when("we alter the table") +def step_alter_table(context): + """ + Alter the table by adding a column. + """ + context.cli.sendline("""alter table a add column y varchar;""") + + +@when("we begin transaction") +def step_begin_transaction(context): + """ + Begin transaction + """ + context.cli.sendline("begin;") + + +@when("we rollback transaction") +def step_rollback_transaction(context): + """ + Rollback transaction + """ + context.cli.sendline("rollback;") + + +@then("we see table created") +def step_see_table_created(context): + """ + Wait to see create table output. + """ + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + + +@then("we see record inserted") +def step_see_record_inserted(context): + """ + Wait to see insert output. + """ + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@then("we see record updated") +def step_see_record_updated(context): + """ + Wait to see update output. + """ + wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) + + +@then("we see data selected: {data}") +def step_see_data_selected(context, data): + """ + Wait to see select output with initial or updated data. + """ + x = UPDATED_DATA if data == "updated" else INITIAL_DATA + wrappers.expect_pager( + context, + dedent( + f"""\ + +-----+\r + | x |\r + |-----|\r + | {x} |\r + +-----+\r + SELECT 1\r + """ + ), + timeout=1, + ) + + +@then("we see select output without data") +def step_see_no_data_selected(context): + """ + Wait to see select output without data. + """ + wrappers.expect_pager( + context, + dedent( + """\ + +---+\r + | x |\r + |---|\r + +---+\r + SELECT 0\r + """ + ), + timeout=1, + ) + + +@then("we see record deleted") +def step_see_data_deleted(context): + """ + Wait to see delete output. + """ + wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2) + + +@then("we see table dropped") +def step_see_table_dropped(context): + """ + Wait to see drop output. + """ + wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) + + +@then("we see transaction began") +def step_see_transaction_began(context): + """ + Wait to see transaction began. + """ + wrappers.expect_pager(context, "BEGIN\r\n", timeout=2) + + +@then("we see transaction rolled back") +def step_see_transaction_rolled_back(context): + """ + Wait to see transaction rollback. + """ + wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py new file mode 100644 index 0000000..302cab9 --- /dev/null +++ b/tests/features/steps/expanded.py @@ -0,0 +1,70 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we prepare the test data") +def step_prepare_data(context): + """Create table, insert a record.""" + context.cli.sendline("drop table if exists a;") + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline("y") + + wrappers.wait_prompt(context) + context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));") + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""") + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@when("we set expanded {mode}") +def step_set_expanded(context, mode): + """Set expanded to mode.""" + context.cli.sendline("\\" + f"x {mode}") + wrappers.expect_exact(context, "Expanded display is", timeout=2) + wrappers.wait_prompt(context) + + +@then("we see {which} data selected") +def step_see_data(context, which): + """Select data from expanded test table.""" + if which == "expanded": + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + x | 1\r + y | 1.0\r + z | 1.0000\r + SELECT 1\r + """ + ), + timeout=1, + ) + else: + wrappers.expect_pager( + context, + dedent( + """\ + +---+-----+--------+\r + | x | y | z |\r + |---+-----+--------|\r + | 1 | 1.0 | 1.0000 |\r + +---+-----+--------+\r + SELECT 1\r + """ + ), + timeout=1, + ) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py new file mode 100644 index 0000000..a614490 --- /dev/null +++ b/tests/features/steps/iocommands.py @@ -0,0 +1,80 @@ +import os +import os.path + +from behave import when, then +import wrappers + + +@when("we start external editor providing a file name") +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, "test_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 + ) + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we type sql in the editor") +def step_edit_type_sql(context): + context.cli.sendline("i") + context.cli.sendline("select * from abc") + context.cli.sendline(".") + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we exit the editor") +def step_edit_quit(context): + context.cli.sendline("x") + wrappers.expect_exact(context, "written", timeout=2) + + +@then("we see the sql in prompt") +def step_edit_done_sql(context): + for match in "select * from abc".split(" "): + wrappers.expect_exact(context, match, timeout=1) + # Cleanup the command line. + context.cli.sendcontrol("c") + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.atprompt = True + + +@when("we tee output") +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name))) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Writing to file", timeout=5) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Time", timeout=5) + + +@when('we query "select 123456"') +def step_query_select_123456(context): + context.cli.sendline("select 123456") + + +@when("we stop teeing output") +def step_notee_output(context): + context.cli.sendline(r"\o") + wrappers.expect_exact(context, "Time", timeout=5) + + +@then("we see 123456 in tee output") +def step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert "123456" in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.atprompt = True diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py new file mode 100644 index 0000000..3f52859 --- /dev/null +++ b/tests/features/steps/named_queries.py @@ -0,0 +1,57 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we save a named query") +def step_save_named_query(context): + """ + Send \ns command + """ + context.cli.sendline("\\ns foo SELECT 12345") + + +@when("we use a named query") +def step_use_named_query(context): + """ + Send \n command + """ + context.cli.sendline("\\n foo") + + +@when("we delete a named query") +def step_delete_named_query(context): + """ + Send \nd command + """ + context.cli.sendline("\\nd foo") + + +@then("we see the named query saved") +def step_see_named_query_saved(context): + """ + Wait to see query saved. + """ + wrappers.expect_exact(context, "Saved.", timeout=2) + + +@then("we see the named query executed") +def step_see_named_query_executed(context): + """ + Wait to see select output. + """ + wrappers.expect_exact(context, "12345", timeout=1) + wrappers.expect_exact(context, "SELECT 1", timeout=1) + + +@then("we see the named query deleted") +def step_see_named_query_deleted(context): + """ + Wait to see query deleted. + """ + wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1) diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py new file mode 100644 index 0000000..f156982 --- /dev/null +++ b/tests/features/steps/pgbouncer.py @@ -0,0 +1,22 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when('we send "show help" command') +def step_send_help_command(context): + context.cli.sendline("show help") + + +@then("we see the pgbouncer help output") +def see_pgbouncer_help(context): + wrappers.expect_exact( + context, + "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", + timeout=3, + ) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py new file mode 100644 index 0000000..a85f371 --- /dev/null +++ b/tests/features/steps/specials.py @@ -0,0 +1,31 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we refresh completions") +def step_refresh_completions(context): + """ + Send refresh command. + """ + context.cli.sendline("\\refresh") + + +@then("we see completions refresh started") +def step_see_refresh_started(context): + """ + Wait to see refresh output. + """ + wrappers.expect_pager( + context, + [ + "Auto-completion refresh started in the background.\r\n", + "Auto-completion refresh restarted.\r\n", + ], + timeout=2, + ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py new file mode 100644 index 0000000..3ebcc92 --- /dev/null +++ b/tests/features/steps/wrappers.py @@ -0,0 +1,71 @@ +import re +import pexpect +from pgcli.main import COLOR_CODE_REGEX +import textwrap + +from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) + raise Exception( + textwrap.dedent( + """\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + """ + ).format(expected, actual, context.logfile.getvalue()) + ) + + +def expect_pager(context, expected, timeout): + formatted = expected if isinstance(expected, list) else [expected] + formatted = [ + f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" + for t in formatted + ] + + expect_exact( + context, + formatted, + timeout=timeout, + ) + + +def run_cli(context, run_args=None, prompt_check=True, currentdb=None): + """Run the process using pexpect.""" + run_args = run_args or [] + cli_cmd = context.conf.get("cli_command") + cmd_parts = [cli_cmd] + run_args + cmd = " ".join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = currentdb or context.conf["dbname"] + context.cli.sendline(r"\pset pager always") + if prompt_check: + wait_prompt(context) + + +def wait_prompt(context): + """Make sure prompt is displayed.""" + prompt_str = "{0}>".format(context.currentdb) + expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3) diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py new file mode 100644 index 0000000..51d4909 --- /dev/null +++ b/tests/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/tests/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py new file mode 100644 index 0000000..016ed95 --- /dev/null +++ b/tests/formatter/test_sqlformatter.py @@ -0,0 +1,111 @@ +# coding=utf-8 + +from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement + +from cli_helpers.tabular_output import TabularOutputFormatter +from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter + + +def test_escape_for_sql_statement_bytes(): + bts = b"837124ab3e8dc0f" + escaped_bytes = escape_for_sql_statement(bts) + assert escaped_bytes == "X'383337313234616233653864633066'" + + +def test_escape_for_sql_statement_number(): + num = 2981 + escaped_bytes = escape_for_sql_statement(num) + assert escaped_bytes == "'2981'" + + +def test_escape_for_sql_statement_str(): + example_str = "example str" + escaped_bytes = escape_for_sql_statement(example_str) + assert escaped_bytes == "'example str'" + + +def test_output_sql_insert(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + None, + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-insert" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "<null>", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + expected = [ + 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, " + + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + ";", + ] + assert expected == output_list + + +def test_output_sql_update(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-update" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "<null>", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + print(output_list) + expected = [ + 'UPDATE "user" SET', + " \"name\" = 'Jackson'", + ", \"email\" = 'jackson_test@gmail.com'", + ", \"phone\" = '132454789'", + ", \"description\" = ''", + ", \"created_at\" = '2022-09-09 19:44:32.712343+08'", + ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'", + "WHERE \"id\" = '1';", + ] + assert expected == output_list diff --git a/tests/metadata.py b/tests/metadata.py new file mode 100644 index 0000000..4ebcccd --- /dev/null +++ b/tests/metadata.py @@ -0,0 +1,255 @@ +from functools import partial +from itertools import product +from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from unittest.mock import Mock +import pytest + +parametrize = pytest.mark.parametrize + +qual = ["if_more_than_one_table", "always"] +no_qual = ["if_more_than_one_table", "never"] + + +def escape(name): + if not name.islower() or name in ("select", "localtimestamp"): + return '"' + name + '"' + return name + + +def completion(display_meta, text, pos=0): + return Completion(text, start_position=pos, display_meta=display_meta) + + +def function(text, pos=0, display=None): + return Completion( + text, display=display or text, start_position=pos, display_meta="function" + ) + + +def get_result(completer, text, position=None): + position = len(text) if position is None else position + return completer.get_completions( + Document(text=text, cursor_position=position), Mock() + ) + + +def result_set(completer, text, position=None): + return set(get_result(completer, text, position)) + + +# The code below is quivalent to +# def schema(text, pos=0): +# return completion('schema', text, pos) +# and so on +schema = partial(completion, "schema") +table = partial(completion, "table") +view = partial(completion, "view") +column = partial(completion, "column") +keyword = partial(completion, "keyword") +datatype = partial(completion, "datatype") +alias = partial(completion, "table alias") +name_join = partial(completion, "name join") +fk_join = partial(completion, "fk join") +join = partial(completion, "join") + + +def wildcard_expansion(cols, pos=-1): + return Completion(cols, start_position=pos, display_meta="columns", display="*") + + +class MetaData: + def __init__(self, metadata): + self.metadata = metadata + + def builtin_functions(self, pos=0): + return [function(f, pos) for f in self.completer.functions] + + def builtin_datatypes(self, pos=0): + return [datatype(dt, pos) for dt in self.completer.datatypes] + + def keywords(self, pos=0): + return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] + + def specials(self, pos=0): + return [ + Completion(text=k, start_position=pos, display_meta=v.description) + for k, v in self.completer.pgspecial.commands.items() + ] + + def columns(self, tbl, parent="public", typ="tables", pos=0): + if typ == "functions": + fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0] + cols = fun[1] + else: + cols = self.metadata[typ][parent][tbl] + return [column(escape(col), pos) for col in cols] + + def datatypes(self, parent="public", pos=0): + return [ + datatype(escape(x), pos) + for x in self.metadata.get("datatypes", {}).get(parent, []) + ] + + def tables(self, parent="public", pos=0): + return [ + table(escape(x), pos) + for x in self.metadata.get("tables", {}).get(parent, []) + ] + + def views(self, parent="public", pos=0): + return [ + view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) + ] + + def functions(self, parent="public", pos=0): + return [ + function( + escape(x[0]) + + "(" + + ", ".join( + arg_name + " := " + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + pos, + escape(x[0]) + + "(" + + ", ".join( + arg_name + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + ) + for x in self.metadata.get("functions", {}).get(parent, []) + ] + + def schemas(self, pos=0): + schemas = {sch for schs in self.metadata.values() for sch in schs} + return [schema(escape(s), pos=pos) for s in schemas] + + def functions_and_keywords(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.builtin_functions(pos) + + self.keywords(pos) + ) + + # Note that the filtering parameters here only apply to the columns + def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): + return self.functions_and_keywords(pos=pos) + self.columns( + tbl, parent, typ, pos + ) + + def from_clause_items(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.views(parent, pos) + + self.tables(parent, pos) + ) + + def schemas_and_from_clause_items(self, parent="public", pos=0): + return self.from_clause_items(parent, pos) + self.schemas(pos) + + def types(self, parent="public", pos=0): + return self.datatypes(parent, pos) + self.tables(parent, pos) + + @property + def completer(self): + return self.get_completer() + + def get_completers(self, casing): + """ + Returns a function taking three bools `casing`, `filtr`, `aliasing` and + the list `qualify`, all defaulting to None. + Returns a list of completers. + These parameters specify the allowed values for the corresponding + completer parameters, `None` meaning any, i.e. (None, None, None, None) + results in all 24 possible completers, whereas e.g. + (True, False, True, ['never']) results in the one completer with + casing, without `search_path` filtering of objects, with table + aliasing, and without column qualification. + """ + + def _cfg(_casing, filtr, aliasing, qualify): + cfg = {"settings": {}} + if _casing: + cfg["casing"] = casing + cfg["settings"]["search_path_filter"] = filtr + cfg["settings"]["generate_aliases"] = aliasing + cfg["settings"]["qualify_columns"] = qualify + return cfg + + def _cfgs(casing, filtr, aliasing, qualify): + casings = [True, False] if casing is None else [casing] + filtrs = [True, False] if filtr is None else [filtr] + aliases = [True, False] if aliasing is None else [aliasing] + qualifys = qualify or ["always", "if_more_than_one_table", "never"] + return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)] + + def completers(casing=None, filtr=None, aliasing=None, qualify=None): + get_comp = self.get_completer + return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)] + + return completers + + def _make_col(self, sch, tbl, col): + defaults = self.metadata.get("defaults", {}).get(sch, {}) + return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col))) + + def get_completer(self, settings=None, casing=None): + metadata = self.metadata + from pgcli.pgcompleter import PGCompleter + from pgspecial import PGSpecial + + comp = PGCompleter( + smart_completion=True, settings=settings, pgspecial=PGSpecial() + ) + + schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] + + for sch, tbls in metadata["tables"].items(): + schemata.append(sch) + + for tbl, cols in tbls.items(): + tables.append((sch, tbl)) + # Let all columns be text columns + tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + for sch, tbls in metadata.get("views", {}).items(): + for tbl, cols in tbls.items(): + views.append((sch, tbl)) + # Let all columns be text columns + view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + functions = [ + FunctionMetadata(sch, *func_meta, arg_defaults=None) + for sch, funcs in metadata["functions"].items() + for func_meta in funcs + ] + + datatypes = [ + (sch, typ) + for sch, datatypes in metadata["datatypes"].items() + for typ in datatypes + ] + + foreignkeys = [ + ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks + ] + + comp.extend_schemata(schemata) + comp.extend_relations(tables, kind="tables") + comp.extend_relations(views, kind="views") + comp.extend_columns(tbl_cols, kind="tables") + comp.extend_columns(view_cols, kind="views") + comp.extend_functions(functions) + comp.extend_datatypes(datatypes) + comp.extend_foreignkeys(foreignkeys) + comp.set_search_path(["public"]) + comp.extend_casing(casing or []) + + return comp diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py new file mode 100644 index 0000000..3e89cca --- /dev/null +++ b/tests/parseutils/test_ctes.py @@ -0,0 +1,137 @@ +import pytest +from sqlparse import parse +from pgcli.packages.parseutils.ctes import ( + token_start_pos, + extract_ctes, + extract_column_names as _extract_column_names, +) + + +def extract_column_names(sql): + p = parse(sql)[0] + return _extract_column_names(p) + + +def test_token_str_pos(): + sql = "SELECT * FROM xxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ") + + sql = "SELECT * FROM \nxxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n") + + +def test_single_column_name_extraction(): + sql = "SELECT abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_aliased_single_column_name_extraction(): + sql = "SELECT abc def FROM xxx" + assert extract_column_names(sql) == ("def",) + + +def test_aliased_expression_name_extraction(): + sql = "SELECT 99 abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_multiple_column_name_extraction(): + sql = "SELECT abc, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_missing_column_name_handled_gracefully(): + sql = "SELECT abc, 99 FROM xxx" + assert extract_column_names(sql) == ("abc",) + + sql = "SELECT abc, 99, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_aliased_multiple_column_name_extraction(): + sql = "SELECT abc def, ghi jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +def test_table_qualified_column_name_extraction(): + sql = "SELECT abc.def, ghi.jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +@pytest.mark.parametrize( + "sql", + [ + "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y", + "DELETE FROM foo WHERE x > y RETURNING x, y", + "UPDATE foo SET x = 9 RETURNING x, y", + ], +) +def test_extract_column_names_from_returning_clause(sql): + assert extract_column_names(sql) == ("x", "y") + + +def test_simple_cte_extraction(): + sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a" + start_pos = len("WITH a AS ") + stop_pos = len("WITH a AS (SELECT abc FROM xxx)") + ctes, remainder = extract_ctes(sql) + + assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_cte_extraction_around_comments(): + sql = """--blah blah blah + WITH a AS (SELECT abc def FROM x) + SELECT * FROM a""" + start_pos = len( + """--blah blah blah + WITH a AS """ + ) + stop_pos = len( + """--blah blah blah + WITH a AS (SELECT abc def FROM x)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_multiple_cte_extraction(): + sql = """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y) + SELECT * FROM a, b""" + + start1 = len( + """WITH + x AS """ + ) + + stop1 = len( + """WITH + x AS (SELECT abc, def FROM x)""" + ) + + start2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS """ + ) + + stop2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == ( + ("x", ("abc", "def"), start1, stop1), + ("y", ("ghi", "jkl"), start2, stop2), + ) diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py new file mode 100644 index 0000000..0350e2a --- /dev/null +++ b/tests/parseutils/test_function_metadata.py @@ -0,0 +1,19 @@ +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def test_function_metadata_eq(): + f1 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f2 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f3 = FunctionMetadata( + "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + assert f1 == f2 + assert f1 != f3 + assert not (f1 != f2) + assert not (f1 == f3) + assert hash(f1) == hash(f2) + assert hash(f1) != hash(f3) diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py new file mode 100644 index 0000000..349cbd0 --- /dev/null +++ b/tests/parseutils/test_parseutils.py @@ -0,0 +1,310 @@ +import pytest +from pgcli.packages.parseutils import ( + is_destructive, + parse_destructive_warning, + BASE_KEYWORDS, + ALL_KEYWORDS, +) +from pgcli.packages.parseutils.tables import extract_tables +from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote + + +def test_empty_string(): + tables = extract_tables("") + assert tables == () + + +def test_simple_select_single_table(): + tables = extract_tables("select * from abc") + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize( + "sql", ['select * from "abc"."def"', 'select * from abc."def"'] +) +def test_simple_select_single_table_schema_qualified_quoted_table(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", '"def"', False),) + + +@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def']) +def test_simple_select_single_table_schema_qualified(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_single_table_double_quoted(): + tables = extract_tables('select * from "Abc"') + assert tables == ((None, "Abc", None, False),) + + +def test_simple_select_multiple_tables(): + tables = extract_tables("select * from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_multiple_tables_double_quoted(): + tables = extract_tables('select * from "Abc", "Def"') + assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)} + + +def test_simple_select_single_table_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a') + assert tables == ((None, "Abc", "a", False),) + + +def test_simple_select_multiple_tables_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a, "Def" d') + assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)} + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables("select * from abc.def, ghi.jkl") + assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)} + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables("select a,b from abc") + assert tables == ((None, "abc", None, False),) + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables("select a,b from abc.def") + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables("select a,b from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_with_cols_multiple_qualified_tables(): + tables = extract_tables("select a,b from abc.def, def.ghi") + assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)} + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables("select a, from abc") + assert tables == ((None, "abc", None, False),) + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables("select a, from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)} + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # AND mistakenly identifies the field list as + # assert tables == ((None, 'abc', 'abc', False),) + + assert tables == ((None, "abc", "abc", False),) + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == (("abc", "def", None, False),) + + +def test_simple_update_table_no_schema(): + tables = extract_tables("update abc set id = 1") + assert tables == ((None, "abc", None, False),) + + +def test_simple_update_table_with_schema(): + tables = extract_tables("update abc.def set id = 1") + assert tables == (("abc", "def", None, False),) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +def test_join_table(join_type): + sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num" + tables = extract_tables(sql) + assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)} + + +def test_join_table_schema_qualified(): + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)} + + +def test_incomplete_join_clause(): + sql = """select a.x, b.y + from abc a join bcd b + on a.id = """ + tables = extract_tables(sql) + assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False)) + + +def test_join_as_table(): + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == ((None, "my_table", "m", False),) + + +def test_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + tables = extract_tables(sql) + assert tables == ( + (None, "t1", None, False), + (None, "t2", None, False), + (None, "t3", None, False), + ) + + +def test_subselect_tables(): + sql = "SELECT * FROM (SELECT FROM abc" + tables = extract_tables(sql) + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"]) +def test_extract_no_tables(text): + tables = extract_tables(text) + assert tables == tuple() + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list})") + assert tables == ((None, "foo", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_schema_qualified_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})") + assert tables == (("foo", "bar", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_aliased_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar") + assert tables == ((None, "foo", "bar", True),) + + +def test_simple_table_and_function(): + tables = extract_tables("SELECT * FROM foo JOIN bar()") + assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)} + + +def test_complex_table_and_function(): + tables = extract_tables( + """SELECT * FROM foo.bar baz + JOIN bar.qux(x, y, z) quux""" + ) + assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)} + + +def test_find_prev_keyword_using(): + q = "select * from tbl1 inner join tbl2 using (col1, " + kw, q2 = find_prev_keyword(q) + assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using (" + + +@pytest.mark.parametrize( + "sql", + [ + "select * from foo where bar", + "select * from foo where bar = 1 and baz or ", + "select * from foo where bar = 1 and baz between qux and ", + ], +) +def test_find_prev_keyword_where(sql): + kw, stripped = find_prev_keyword(sql) + assert kw.value == "where" and stripped == "select * from foo where" + + +@pytest.mark.parametrize( + "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "] +) +def test_find_prev_keyword_open_parens(sql): + kw, _ = find_prev_keyword(sql) + assert kw.value == "(" + + +@pytest.mark.parametrize( + "sql", + [ + "", + "$$ foo $$", + "$$ 'foo' $$", + '$$ "foo" $$', + "$$ $a$ $$", + "$a$ $$ $a$", + "foo bar $$ baz $$", + ], +) +def test_is_open_quote__closed(sql): + assert not is_open_quote(sql) + + +@pytest.mark.parametrize( + "sql", + [ + "$$", + ";;;$$", + "foo $$ bar $$; foo $$", + "$$ foo $a$", + "foo 'bar baz", + "$a$ foo ", + '$$ "foo" ', + "$$ $a$ ", + "foo bar $$ baz", + ], +) +def test_is_open_quote__open(sql): + assert is_open_quote(sql) + + +@pytest.mark.parametrize( + ("sql", "keywords", "expected"), + [ + ("update abc set x = 1", ALL_KEYWORDS, True), + ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True), + ("update abc set x = 1", BASE_KEYWORDS, True), + ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False), + ("select x, y, z from abc", ALL_KEYWORDS, False), + ("drop abc", ALL_KEYWORDS, True), + ("alter abc", ALL_KEYWORDS, True), + ("delete abc", ALL_KEYWORDS, True), + ("truncate abc", ALL_KEYWORDS, True), + ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ], +) +def test_is_destructive(sql, keywords, expected): + assert is_destructive(sql, keywords) == expected + + +@pytest.mark.parametrize( + ("warning_level", "expected"), + [ + ("true", ALL_KEYWORDS), + ("false", []), + ("all", ALL_KEYWORDS), + ("moderate", BASE_KEYWORDS), + ("off", []), + ("", []), + (None, []), + (ALL_KEYWORDS, ALL_KEYWORDS), + (BASE_KEYWORDS, BASE_KEYWORDS), + ("insert", ["insert"]), + ("drop,alter,delete", ["drop", "alter", "delete"]), + (["drop", "alter", "delete"], ["drop", "alter", "delete"]), + ], +) +def test_parse_destructive_warning(warning_level, expected): + assert parse_destructive_warning(warning_level) == expected diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..f787740 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts=--capture=sys --showlocals
\ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a517a89 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,40 @@ +import pytest +from unittest import mock +from pgcli import auth + + +@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)]) +def test_keyring_initialize(enabled, call_count): + logger = mock.MagicMock() + + with mock.patch("importlib.import_module", return_value=True) as import_method: + auth.keyring_initialize(enabled, logger=logger) + assert import_method.call_count == call_count + + +def test_keyring_get_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"): + assert auth.keyring_get_password("test") == "abc123" + + +def test_keyring_get_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.get_password", side_effect=Exception("Boom!") + ): + assert auth.keyring_get_password("test") == "" + + +def test_keyring_set_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.set_password"): + auth.keyring_set_password("test", "abc123") + + +def test_keyring_set_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.set_password", side_effect=Exception("Boom!") + ): + auth.keyring_set_password("test", "abc123") diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py new file mode 100644 index 0000000..a5529d6 --- /dev/null +++ b/tests/test_completion_refresher.py @@ -0,0 +1,95 @@ +import time +import pytest +from unittest.mock import Mock, patch + + +@pytest.fixture +def refresher(): + from pgcli.completion_refresher import CompletionRefresher + + return CompletionRefresher() + + +def test_ctor(refresher): + """ + Refresher object should contain a few handlers + :param refresher: + :return: + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = [ + "schemata", + "tables", + "views", + "types", + "databases", + "casing", + "functions", + ] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + with patch.object(refresher, "_bg_refresh") as bg_refresh: + actual = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None) + + +def test_refresh_called_twice(refresher): + """ + If refresh is called a second time, it should be restarted + :param refresher: + :return: + """ + callbacks = Mock() + + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == "Auto-completion refresh started in the background." + + actual2 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == "Auto-completion refresh restarted." + + +def test_refresh_with_callbacks(refresher): + """ + Callbacks must be called + :param refresher: + """ + callbacks = [Mock()] + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + pgexecute.extra_args = {} + special = Mock() + + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..08fe74e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +import io +import os +import stat + +import pytest + +from pgcli.config import ensure_dir_exists, skip_initial_comment + + +def test_ensure_file_parent(tmpdir): + subdir = tmpdir.join("subdir") + rcfile = subdir.join("rcfile") + ensure_dir_exists(str(rcfile)) + + +def test_ensure_existing_dir(tmpdir): + rcfile = str(tmpdir.mkdir("subdir").join("rcfile")) + + # should just not raise + ensure_dir_exists(rcfile) + + +def test_ensure_other_create_error(tmpdir): + subdir = tmpdir.join('subdir"') + rcfile = subdir.join("rcfile") + + # trigger an oserror that isn't "directory already exists" + os.chmod(str(tmpdir), stat.S_IREAD) + + with pytest.raises(OSError): + ensure_dir_exists(str(rcfile)) + + +@pytest.mark.parametrize( + "text, skipped_lines", + ( + ("abc\n", 1), + ("#[section]\ndef\n[section]", 2), + ("[section]", 0), + ), +) +def test_skip_initial_comment(text, skipped_lines): + assert skip_initial_comment(io.StringIO(text)) == skipped_lines diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py new file mode 100644 index 0000000..8f8f2cd --- /dev/null +++ b/tests/test_fuzzy_completion.py @@ -0,0 +1,87 @@ +import pytest + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter() + + +def test_ranking_ignores_identifier_quotes(completer): + """When calculating result rank, identifier quotes should be ignored. + + The result ranking algorithm ignores identifier quotes. Without this + correction, the match "user", which Postgres requires to be quoted + since it is also a reserved word, would incorrectly fall below the + match user_action because the literal quotation marks in "user" + alter the position of the match. + + This test checks that the fuzzy ranking algorithm correctly ignores + quotation marks when computing match ranks. + + """ + + text = "user" + collection = ["user_action", '"user"'] + matches = completer.find_matches(text, collection) + assert len(matches) == 2 + + +def test_ranking_based_on_shortest_match(completer): + """Fuzzy result rank should be based on shortest match. + + Result ranking in fuzzy searching is partially based on the length + of matches: shorter matches are considered more relevant than + longer ones. When searching for the text 'user', the length + component of the match 'user_group' could be either 4 ('user') or + 7 ('user_gr'). + + This test checks that the fuzzy ranking algorithm uses the shorter + match when calculating result rank. + + """ + + text = "user" + collection = ["api_user", "user_group"] + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +@pytest.mark.parametrize( + "collection", + [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]], +) +def test_should_break_ties_using_lexical_order(completer, collection): + """Fuzzy result rank should use lexical order to break ties. + + When fuzzy matching, if multiple matches have the same match length and + start position, present them in lexical (rather than arbitrary) order. For + example, if we have tables 'user', 'user_action', and 'user_group', a + search for the text 'user' should present these tables in this order. + + The input collections to this test are out of order; each run checks that + the search text 'user' results in the input tables being reordered + lexically. + + """ + + text = "user" + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +def test_matching_should_be_case_insensitive(completer): + """Fuzzy matching should keep matches even if letter casing doesn't match. + + This test checks that variations of the text which have different casing + are still matched. + """ + + text = "foo" + collection = ["Foo", "FOO", "fOO"] + matches = completer.find_matches(text, collection) + + assert len(matches) == 3 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..cbf20a6 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,490 @@ +import os +import platform +from unittest import mock + +import pytest + +try: + import setproctitle +except ImportError: + setproctitle = None + +from pgcli.main import ( + obfuscate_process_password, + format_output, + PGCli, + OutputSettings, + COLOR_CODE_REGEX, +) +from pgcli.pgexecute import PGExecute +from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS +from utils import dbtest, run +from collections import namedtuple + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows") +@pytest.mark.skipif(not setproctitle, reason="setproctitle not available") +def test_obfuscate_process_password(): + original_title = setproctitle.getproctitle() + + setproctitle.setproctitle("pgcli user=root password=secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx" + assert title == expected + + setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli postgres://root:xxxx@localhost/db" + assert title == expected + + setproctitle.setproctitle(original_title) + + +def test_format_output(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") + results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + expected = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(results) == expected + + +def test_format_output_truncate_on(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10 + ) + results = format_output( + None, + [("first field value", "second field value")], + ["head1", "head2"], + None, + settings, + ) + expected = [ + "+------------+------------+", + "| head1 | head2 |", + "|------------+------------|", + "| first f... | second ... |", + "+------------+------------+", + ] + assert list(results) == expected + + +def test_format_output_truncate_off(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None + ) + long_field_value = ("first field " * 100).strip() + results = format_output(None, [(long_field_value,)], ["head1"], None, settings) + lines = list(results) + assert lines[3] == f"| {long_field_value} |" + + +@dbtest +def test_format_array_output(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement) + expected = [ + "+--------------+----------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|--------------+----------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | <null> | {<null>} |", + "+--------------+----------------------+--------------+", + "SELECT 2", + ] + assert list(results) == expected + + +@dbtest +def test_format_array_output_expanded(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement, expanded=True) + expected = [ + "-[ RECORD 1 ]-------------------------", + "bigint_array | {1,2,3}", + "nested_numeric_array | {{1,2},{3,4}}", + "配列 | {å,魚,текст}", + "-[ RECORD 2 ]-------------------------", + "bigint_array | {}", + "nested_numeric_array | <null>", + "配列 | {<null>}", + "SELECT 2", + ] + assert "\n".join(results) == "\n".join(expected) + + +def test_format_output_auto_expand(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 + ) + table_results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + table = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(table_results) == table + expanded_results = format_output( + "Title", + [("abc", "def")], + ["head1", "head2"], + "test status", + settings._replace(max_width=1), + ) + expanded = [ + "Title", + "-[ RECORD 1 ]-------------------------", + "head1 | abc", + "head2 | def", + "test status", + ] + assert "\n".join(expanded_results) == "\n".join(expanded) + + +termsize = namedtuple("termsize", ["rows", "columns"]) +test_line = "-" * 10 +test_data = [ + (10, 10, "\n".join([test_line] * 7)), + (10, 10, "\n".join([test_line] * 6)), + (10, 10, "\n".join([test_line] * 5)), + (10, 10, "-" * 11), + (10, 10, "-" * 10), + (10, 10, "-" * 9), +] + +# 4 lines are reserved at the bottom of the terminal for pgcli's prompt +use_pager_when_on = [True, True, False, True, False, False] + +# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL +test_ids = [ + "Output longer than terminal height", + "Output equal to terminal height", + "Output shorter than terminal height", + "Output longer than terminal width", + "Output equal to terminal width", + "Output shorter than terminal width", +] + + +@pytest.fixture +def pset_pager_mocks(): + cli = PGCli() + cli.watch_command = None + with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( + "pgcli.main.click.echo_via_pager" + ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: + yield cli, mock_echo, mock_echo_via_pager, mock_app + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): + cli.echo_via_pager(text) + + mock_echo.assert_called() + mock_echo_via_pager.assert_not_called() + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): + cli.echo_via_pager(text) + + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + + +pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] + + +@pytest.mark.parametrize( + "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids +) +def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): + cli.echo_via_pager(text) + + if use_pager: + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + else: + mock_echo_via_pager.assert_not_called() + mock_echo.assert_called() + + +@pytest.mark.parametrize( + "text,expected_length", + [ + ( + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + 78, + ), + ("=\u001b[m=", 2), + ("-\u001b]23\u0007-", 2), + ], +) +def test_color_pattern(text, expected_length, pset_pager_mocks): + cli = pset_pager_mocks[0] + assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length + + +@dbtest +def test_i_works(tmpdir, executor): + sqlfile = tmpdir.join("test.sql") + sqlfile.write("SELECT NOW()") + rcfile = str(tmpdir.join("rcfile")) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + statement = r"\i {0}".format(sqlfile) + run(executor, statement, pgspecial=cli.pgspecial) + + +@dbtest +def test_echo_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\echo asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_qecho_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\qecho asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_watch_works(executor): + cli = PGCli(pgexecute=executor) + + def run_with_watch( + query, target_call_count=1, expected_output="", expected_timing=None + ): + """ + :param query: Input to the CLI + :param target_call_count: Number of times the user lets the command run before Ctrl-C + :param expected_output: Substring expected to be found for each executed query + :param expected_timing: value `time.sleep` expected to be called with on every invocation + """ + with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch( + "pgcli.main.sleep" + ) as mock_sleep: + mock_sleep.side_effect = [None] * (target_call_count - 1) + [ + KeyboardInterrupt + ] + cli.handle_watch_command(query) + # Validate that sleep was called with the right timing + for i in range(target_call_count - 1): + assert mock_sleep.call_args_list[i][0][0] == expected_timing + # Validate that the output of the query was expected + assert mock_echo.call_count == target_call_count + for i in range(target_call_count): + assert expected_output in mock_echo.call_args_list[i][0][0] + + # With no history, it errors. + with mock.patch("pgcli.main.click.secho") as mock_secho: + cli.handle_watch_command(r"\watch 2") + mock_secho.assert_called() + assert ( + r"\watch cannot be used with an empty query" + in mock_secho.call_args_list[0][0][0] + ) + + # Usage 1: Run a query and then re-run it with \watch across two prompts. + run_with_watch("SELECT 111", expected_output="111") + run_with_watch( + "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10 + ) + + # Usage 2: Run a query and \watch via the same prompt. + run_with_watch( + "SELECT 222; \\watch 4", + target_call_count=3, + expected_output="222", + expected_timing=4, + ) + + # Usage 3: Re-run the last watched command with a new timing + run_with_watch( + "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5 + ) + + +def test_missing_rc_dir(tmpdir): + rcfile = str(tmpdir.join("subdir").join("rcfile")) + + PGCli(pgclirc_file=rcfile) + assert os.path.exists(rcfile) + + +def test_quoted_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") + mock_connect.assert_called_with( + database="testdb[", host="baz.com", user="bar^", passwd="]foo" + ) + + +def test_pg_service_file(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: + service_conf.write( + """File begins with a comment + that is not a comment + # or maybe a comment after all + because psql is crazy + + [myservice] + host=a_host + user=a_user + port=5433 + password=much_secure + dbname=a_dbname + + [my_other_service] + host=b_host + user=b_user + port=5435 + dbname=b_dbname + """ + ) + os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath + cli.connect_service("myservice", "another_user") + mock_connect.assert_called_with( + database="a_dbname", + host="a_host", + user="another_user", + port="5433", + passwd="much_secure", + ) + + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + os.environ["PGPASSWORD"] = "very_secure" + cli.connect_service("my_other_service", None) + mock_pgexecute.assert_called_with( + "b_dbname", + "b_user", + "very_secure", + "b_host", + "5435", + "", + application_name="pgcli", + ) + del os.environ["PGPASSWORD"] + del os.environ["PGSERVICEFILE"] + + +def test_ssl_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" + "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + ) + mock_connect.assert_called_with( + database="testdb[", + host="baz.com", + user="bar^", + passwd="]foo", + sslmode="verify-full", + sslcert="my.pem", + sslkey="my-key.pem", + sslrootcert="ca.pem", + ) + + +def test_port_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") + mock_connect.assert_called_with( + database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" + ) + + +def test_multihost_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + ) + mock_connect.assert_called_with( + database="testdb", + host="baz1.com,baz2.com,baz3.com", + user="bar", + passwd="foo", + port="2543,2543,2543", + ) + + +def test_application_name_db_uri(tmpdir): + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar@baz.com/?application_name=cow") + mock_pgexecute.assert_called_with( + "bar", "bar", "", "baz.com", "", "", application_name="cow" + ) diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py new file mode 100644 index 0000000..5b93661 --- /dev/null +++ b/tests/test_naive_completion.py @@ -0,0 +1,133 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from utils import completions_to_set + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set([Completion(text="SELECT", start_position=-3)]) + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set( + [ + Completion(text="MATERIALIZED VIEW", start_position=-2), + Completion(text="MAX", start_position=-2), + Completion(text="MAXEXTENTS", start_position=-2), + Completion(text="MAKE_DATE", start_position=-2), + Completion(text="MAKE_TIME", start_position=-2), + Completion(text="MAKE_TIMESTAMPTZ", start_position=-2), + Completion(text="MAKE_INTERVAL", start_position=-2), + Completion(text="MASKLEN", start_position=-2), + Completion(text="MAKE_TIMESTAMP", start_position=-2), + ] + ) + + +def test_column_name_completion(completer, complete_event): + text = "SELECT FROM users" + position = len("SELECT ") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_alter_well_known_keywords_completion(completer, complete_event): + text = "ALTER " + position = len(text) + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result > completions_to_set( + [ + Completion(text="DATABASE", display_meta="keyword"), + Completion(text="TABLE", display_meta="keyword"), + Completion(text="SYSTEM", display_meta="keyword"), + ] + ) + assert ( + completions_to_set([Completion(text="CREATE", display_meta="keyword")]) + not in result + ) + + +def test_special_name_completion(completer, complete_event): + text = "\\" + position = len("\\") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + # Special commands will NOT be suggested during naive completion mode. + assert result == completions_to_set([]) + + +def test_datatype_name_completion(completer, complete_event): + text = "SELECT price::IN" + position = len("SELECT price::IN") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result == completions_to_set( + [ + Completion(text="INET", display_meta="datatype"), + Completion(text="INT", display_meta="datatype"), + Completion(text="INT2", display_meta="datatype"), + Completion(text="INT4", display_meta="datatype"), + Completion(text="INT8", display_meta="datatype"), + Completion(text="INTEGER", display_meta="datatype"), + Completion(text="INTERNAL", display_meta="datatype"), + Completion(text="INTERVAL", display_meta="datatype"), + ] + ) diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py new file mode 100644 index 0000000..909fa0b --- /dev/null +++ b/tests/test_pgcompleter.py @@ -0,0 +1,76 @@ +import pytest +from pgcli import pgcompleter + + +def test_load_alias_map_file_missing_file(): + with pytest.raises( + pgcompleter.InvalidMapFile, + match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$", + ): + pgcompleter.load_alias_map_file("/path/to/non-existent/file.json") + + +def test_load_alias_map_file_invalid_json(tmp_path): + fpath = tmp_path / "foo.json" + fpath.write_text("this is not valid json") + with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"): + pgcompleter.load_alias_map_file(str(fpath)) + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("SomE_Table", "SET"), + ("SOmeTabLe", "SOTL"), + ("someTable", "T"), + ], +) +def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("some_tab_le", "stl"), + ("s_ome_table", "sot"), + ("sometable", "s"), + ], +) +def test_generate_alias_uses_first_char_and_every_preceded_by_underscore( + table_name, alias +): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_can_use_alias_map(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_prefers_alias_over_upper_case_name( + table_name, alias_map, alias +): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("Some_tablE", "SE"), + ("SomeTab_le", "ST"), + ], +) +def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py new file mode 100644 index 0000000..636795b --- /dev/null +++ b/tests/test_pgexecute.py @@ -0,0 +1,773 @@ +from textwrap import dedent + +import psycopg +import pytest +from unittest.mock import patch, MagicMock +from pgspecial.main import PGSpecial, NO_QUERY +from utils import run, dbtest, requires_json, requires_jsonb + +from pgcli.main import PGCli +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def function_meta_data( + func_name, + schema_name="public", + arg_names=None, + arg_types=None, + arg_modes=None, + return_type=None, + is_aggregate=False, + is_window=False, + is_set_returning=False, + is_extension=False, + arg_defaults=None, +): + return FunctionMetadata( + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ) + + +@dbtest +def test_conn(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_copy(executor): + executor_copy = executor.copy() + run(executor_copy, """create table test(a text)""") + run(executor_copy, """insert into test values('abc')""") + assert run(executor_copy, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_bools_are_treated_as_strings(executor): + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +------+ + | a | + |------| + | True | + +------+ + SELECT 1""" + ) + + +@dbtest +def test_expanded_slash_G(executor, pgspecial): + # Tests whether we reset the expanded output after a \G. + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, r"""select * from test \G""", pgspecial=pgspecial) + assert pgspecial.expanded_output == False + + +@dbtest +def test_schemata_table_views_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + run(executor, "create view d as select 1 as e") + run(executor, "create schema schema1") + run(executor, "create table schema1.c (w text DEFAULT 'meow')") + run(executor, "create schema schema2") + + # schemata + # don't enforce all members of the schemas since they may include postgres + # temporary schemas + assert set(executor.schemata()) >= { + "public", + "pg_catalog", + "information_schema", + "schema1", + "schema2", + } + assert executor.search_path() == ["pg_catalog", "public"] + + # tables + assert set(executor.tables()) >= { + ("public", "a"), + ("public", "b"), + ("schema1", "c"), + } + + assert set(executor.table_columns()) >= { + ("public", "a", "x", "text", False, None), + ("public", "a", "y", "text", False, None), + ("public", "b", "z", "text", False, None), + ("schema1", "c", "w", "text", True, "'meow'::text"), + } + + # views + assert set(executor.views()) >= {("public", "d")} + + assert set(executor.view_columns()) >= { + ("public", "d", "e", "integer", False, None) + } + + +@dbtest +def test_foreign_key_query(executor): + run(executor, "create schema schema1") + run(executor, "create schema schema2") + run(executor, "create table schema1.parent(parentid int PRIMARY KEY)") + run( + executor, + "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", + ) + + assert set(executor.foreignkeys()) >= { + ("schema1", "parent", "parentid", "schema2", "child", "motherid") + } + + +@dbtest +def test_functions_query(executor): + run( + executor, + """create function func1() returns int + language sql as $$select 1$$""", + ) + run(executor, "create schema schema1") + run( + executor, + """create function schema1.func2() returns int + language sql as $$select 2$$""", + ) + + run( + executor, + """create function func3() + returns table(x int, y int) language sql + as $$select 1, 2 from generate_series(1,5)$$;""", + ) + + run( + executor, + """create function func4(x int) returns setof int language sql + as $$select generate_series(1,5)$$;""", + ) + + funcs = set(executor.functions()) + assert funcs >= { + function_meta_data(func_name="func1", return_type="integer"), + function_meta_data( + func_name="func3", + arg_names=["x", "y"], + arg_types=["integer", "integer"], + arg_modes=["t", "t"], + return_type="record", + is_set_returning=True, + ), + function_meta_data( + schema_name="public", + func_name="func4", + arg_names=("x",), + arg_types=("integer",), + return_type="integer", + is_set_returning=True, + ), + function_meta_data( + schema_name="schema1", func_name="func2", return_type="integer" + ), + } + + +@dbtest +def test_datatypes_query(executor): + run(executor, "create type foo AS (a int, b text)") + + types = list(executor.datatypes()) + assert types == [("public", "foo")] + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert "_test_db" in databases + + +@dbtest +def test_invalid_syntax(executor, exception_formatter): + result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) + assert 'syntax error at or near "invalid"' in result[0] + + +@dbtest +def test_invalid_column_name(executor, exception_formatter): + result = run( + executor, "select invalid command", exception_formatter=exception_formatter + ) + assert 'column "invalid" does not exist' in result[0] + + +@pytest.fixture(params=[True, False]) +def expanded(request): + return request.param + + +@dbtest +def test_unicode_support_in_output(executor, expanded): + run(executor, "create table unicodechars(t text)") + run(executor, "insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + assert "é" in run( + executor, "select * from unicodechars", join=True, expanded=expanded + ) + + +@dbtest +def test_not_is_special(executor, pgspecial): + """is_special is set to false for database queries.""" + query = "select 1" + result = list(executor.run(query, pgspecial=pgspecial)) + success, is_special = result[0][5:] + assert success == True + assert is_special == False + + +@dbtest +def test_execute_from_file_no_arg(executor, pgspecial): + r"""\i without a filename returns an error.""" + result = list(executor.run(r"\i", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert "missing required argument" in status + assert success == False + assert is_special == True + + +@dbtest +@patch("pgcli.main.os") +def test_execute_from_file_io_error(os, executor, pgspecial): + r"""\i with an os_error returns an error.""" + # Inject an OSError. + os.path.expanduser.side_effect = OSError("test") + + # Check the result. + result = list(executor.run(r"\i test", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert status == "test" + assert success == False + assert is_special == True + + +@dbtest +def test_execute_from_commented_file_that_executes_another_file( + executor, pgspecial, tmpdir +): + # https://github.com/dbcli/pgcli/issues/1336 + sqlfile1 = tmpdir.join("test01.sql") + sqlfile1.write("-- asdf \n\\h") + sqlfile2 = tmpdir.join("test00.sql") + sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment") + + rcfile = str(tmpdir.join("rcfile")) + print(rcfile) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + assert cli != None + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result != None + assert result[0].find("ALTER TABLE") + + +@dbtest +def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # https://github.com/dbcli/pgcli/issues/1362 + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "--comment1\n--comment2\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + \h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + /*comment 3 + comment4*/ + \\h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """\\h /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + print(result) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: we probably don't want to do this but sqlparse is not parsing things well + # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/ + # style comments after command + + statement = """/*comment1*/ + \h + /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: same for this one + statement = """/*comment1 + comment3 + comment2*/ + \\h + /*comment4 + comment5 + comment6*/""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + +@dbtest +def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1403 + + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # this simulates the original error (1403) without having to add/drop tables + # since it was just an error on reading input files and not the actual + # command itself + + # test that the statement works + statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a \n in the middle + statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a newline in the middle + statement = """VALUES (1, 'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # now add a single comment line + statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + VALUES (1,'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # two comment lines + statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + --comment2 + VALUES (1,'one'), (2, 'two'), (3, 'three'); + """ + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + # + comments after the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three'); +--comment4 +--comment5""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + +@dbtest +def test_multiple_queries_same_line(executor): + result = run(executor, "select 'foo'; select 'bar'") + assert len(result) == 12 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + assert "bar" in result[9] + + +@dbtest +def test_multiple_queries_with_special_command_same_line(executor, pgspecial): + result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial) + assert len(result) == 11 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + # This is a lame check. :( + assert "Schema" in result[7] + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter): + result = run( + executor, + "select 'fooé'; invalid syntax é", + exception_formatter=exception_formatter, + ) + assert "fooé" in result[3] + assert 'syntax error at or near "invalid"' in result[-1] + + +@pytest.fixture +def pgspecial(): + return PGCli().pgspecial + + +@dbtest +def test_special_command_help(executor, pgspecial): + result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|") + assert "Command" in result[1] + assert "Description" in result[2] + + +@dbtest +def test_bytea_field_support_in_output(executor): + run(executor, "create table binarydata(c bytea)") + run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))") + + assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True) + + +@dbtest +def test_unicode_support_in_unknown_type(executor): + assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True) + + +@dbtest +def test_unicode_support_in_enum_type(executor): + run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')") + run(executor, "CREATE TABLE person (name TEXT, current_mood mood)") + run(executor, "INSERT INTO person VALUES ('Moe', '日本語')") + assert "日本語" in run(executor, "SELECT * FROM person", join=True) + + +@requires_json +def test_json_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsontest(d json)") + run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@requires_jsonb +def test_jsonb_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsonbtest(d jsonb)") + run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@dbtest +def test_date_time_types(executor): + run(executor, "SET TIME ZONE UTC") + assert ( + run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] + == "| 00:00:00 |" + ) + assert ( + run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split( + "\n" + )[3] + == "| 00:00:00+14:59 |" + ) + assert ( + run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[ + 3 + ] + == "| 4713-01-01 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True + ).split("\n")[3] + == "| 4713-01-01 00:00:00 BC |" + ) + assert ( + run( + executor, + "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", + join=True, + ).split("\n")[3] + == "| 4713-01-01 00:00:00+00 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True + ).split("\n")[3] + == "| -123456789 days, 12:23:56 |" + ) + + +@dbtest +@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"]) +def test_large_numbers_render_directly(executor, value): + run(executor, "create table numbertest(a numeric)") + run(executor, f"insert into numbertest (a) values ({value})") + assert value in run(executor, "select * from numbertest", join=True) + + +@dbtest +@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"]) +@pytest.mark.parametrize("verbose", ["", "+"]) +@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"]) +def test_describe_special(executor, command, verbose, pattern, pgspecial): + # We don't have any tests for the output of any of the special commands, + # but we can at least make sure they run without error + sql = r"\{command}{verbose} {pattern}".format(**locals()) + list(executor.run(sql, pgspecial=pgspecial)) + + +@dbtest +@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"]) +def test_raises_with_no_formatter(executor, sql): + with pytest.raises(psycopg.ProgrammingError): + list(executor.run(sql)) + + +@dbtest +def test_on_error_resume(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter) + ) + assert len(result) == 3 + + +@dbtest +def test_on_error_stop(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run( + sql, on_error_resume=False, exception_formatter=exception_formatter + ) + ) + assert len(result) == 2 + + +# @dbtest +# def test_unicode_notices(executor): +# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;" +# result = list(executor.run(sql)) +# assert result[0][0] == u'NOTICE: 有人更改\n' + + +@dbtest +def test_nonexistent_function_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_function") + + +@dbtest +def test_function_definition(executor): + run( + executor, + """ + CREATE OR REPLACE FUNCTION public.the_number_three() + RETURNS int + LANGUAGE sql + AS $function$ + select 3; + $function$ + """, + ) + result = executor.function_definition("the_number_three") + + +@dbtest +def test_view_definition(executor): + run(executor, "create table tbl1 (a text, b numeric)") + run(executor, "create view vw1 AS SELECT * FROM tbl1") + run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") + result = executor.view_definition("vw1") + assert 'VIEW "public"."vw1" AS' in result + assert "FROM tbl1" in result + # import pytest; pytest.set_trace() + result = executor.view_definition("mvw1") + assert "MATERIALIZED VIEW" in result + + +@dbtest +def test_nonexistent_view_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_view") + with pytest.raises(RuntimeError): + result = executor.view_definition("mvw1") + + +@dbtest +def test_short_host(executor): + with patch.object(executor, "host", "localhost"): + assert executor.short_host == "localhost" + with patch.object(executor, "host", "localhost.example.org"): + assert executor.short_host == "localhost" + with patch.object( + executor, "host", "localhost1.example.org,localhost2.example.org" + ): + assert executor.short_host == "localhost1" + + +class VirtualCursor: + """Mock a cursor to virtual database like pgbouncer.""" + + def __init__(self): + self.protocol_error = False + self.protocol_message = "" + self.description = None + self.status = None + self.statusmessage = "Error" + + def execute(self, *args, **kwargs): + self.protocol_error = True + self.protocol_message = "Command not supported" + + +@dbtest +def test_exit_without_active_connection(executor): + quit_handler = MagicMock() + pgspecial = PGSpecial() + pgspecial.register( + quit_handler, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + + with patch.object( + executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!") + ): + # we should be able to quit the app, even without active connection + run(executor, "\\q", pgspecial=pgspecial) + quit_handler.assert_called_once() + + # an exception should be raised when running a query without active connection + with pytest.raises(psycopg.InterfaceError): + run(executor, "select 1", pgspecial=pgspecial) + + +@dbtest +def test_virtual_database(executor): + virtual_connection = MagicMock() + virtual_connection.cursor.return_value = VirtualCursor() + with patch.object(executor, "conn", virtual_connection): + result = run(executor, "select 1") + assert "Command not supported" in result diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py new file mode 100644 index 0000000..cd99e32 --- /dev/null +++ b/tests/test_pgspecial.py @@ -0,0 +1,78 @@ +import pytest +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + View, + Function, + Datatype, +) + + +def test_slash_suggests_special(): + suggestions = suggest_type("\\", "\\") + assert set(suggestions) == {Special()} + + +def test_slash_d_suggests_special(): + suggestions = suggest_type("\\d", "\\d") + assert set(suggestions) == {Special()} + + +def test_dn_suggests_schemata(): + suggestions = suggest_type("\\dn ", "\\dn ") + assert suggestions == (Schema(),) + + suggestions = suggest_type("\\dn xxx", "\\dn xxx") + assert suggestions == (Schema(),) + + +def test_d_suggests_tables_views_and_schemas(): + suggestions = suggest_type(r"\d ", r"\d ") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + suggestions = suggest_type(r"\d xxx", r"\d xxx") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + +def test_d_dot_suggests_schema_qualified_tables_or_views(): + suggestions = suggest_type(r"\d myschema.", r"\d myschema.") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + +def test_df_suggests_schema_or_function(): + suggestions = suggest_type("\\df xxx", "\\df xxx") + assert set(suggestions) == {Function(schema=None, usage="special"), Schema()} + + suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx") + assert suggestions == (Function(schema="myschema", usage="special"),) + + +def test_leading_whitespace_ok(): + cmd = "\\dn " + whitespace = " " + suggestions = suggest_type(whitespace + cmd, whitespace + cmd) + assert suggestions == suggest_type(cmd, cmd) + + +def test_dT_suggests_schema_or_datatypes(): + text = "\\dT " + suggestions = suggest_type(text, text) + assert set(suggestions) == {Schema(), Datatype(schema=None)} + + +def test_schema_qualified_dT_suggests_datatypes(): + text = "\\dT foo." + suggestions = suggest_type(text, text) + assert suggestions == (Datatype(schema="foo"),) + + +@pytest.mark.parametrize("command", ["\\c ", "\\connect "]) +def test_c_suggests_databases(command): + suggestions = suggest_type(command, command) + assert suggestions == (Database(),) diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py new file mode 100644 index 0000000..f5b6700 --- /dev/null +++ b/tests/test_prioritization.py @@ -0,0 +1,20 @@ +from pgcli.packages.prioritization import PrevalenceCounter + + +def test_prevalence_counter(): + counter = PrevalenceCounter() + sql = """SELECT * FROM foo WHERE bar GROUP BY baz; + select * from foo; + SELECT * FROM foo WHERE bar GROUP + BY baz""" + counter.update(sql) + + keywords = ["SELECT", "FROM", "GROUP BY"] + expected = [3, 3, 2] + kw_counts = [counter.keyword_count(x) for x in keywords] + assert kw_counts == expected + assert counter.keyword_count("NOSUCHKEYWORD") == 0 + + names = ["foo", "bar", "baz"] + name_counts = [counter.name_count(x) for x in names] + assert name_counts == [3, 2, 2] diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py new file mode 100644 index 0000000..91abe37 --- /dev/null +++ b/tests/test_prompt_utils.py @@ -0,0 +1,17 @@ +import click + +from pgcli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, [], None) is None + + +def test_confirm_destructive_query_with_alias(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, ["drop"], "test") is None diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py new file mode 100644 index 0000000..da916b4 --- /dev/null +++ b/tests/test_rowlimit.py @@ -0,0 +1,79 @@ +import pytest +from unittest.mock import Mock + +from pgcli.main import PGCli + + +# We need this fixtures because we need PGCli object to be created +# after test collection so it has config loaded from temp directory + + +@pytest.fixture(scope="module") +def default_pgcli_obj(): + return PGCli() + + +@pytest.fixture(scope="module") +def DEFAULT(default_pgcli_obj): + return default_pgcli_obj.row_limit + + +@pytest.fixture(scope="module") +def LIMIT(DEFAULT): + return DEFAULT + 1000 + + +@pytest.fixture(scope="module") +def over_default(DEFAULT): + over_default_cursor = Mock() + over_default_cursor.configure_mock(rowcount=DEFAULT + 10) + return over_default_cursor + + +@pytest.fixture(scope="module") +def over_limit(LIMIT): + over_limit_cursor = Mock() + over_limit_cursor.configure_mock(rowcount=LIMIT + 10) + return over_limit_cursor + + +@pytest.fixture(scope="module") +def low_count(): + low_count_cursor = Mock() + low_count_cursor.configure_mock(rowcount=1) + return low_count_cursor + + +def test_row_limit_with_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students LIMIT 1000" + + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_without_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students" + + result = cli._should_limit_output(stmt, over_limit) + assert result is True + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_on_non_select(over_limit): + cli = PGCli() + stmt = "UPDATE students SET name='Boby'" + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py new file mode 100644 index 0000000..5c9c9af --- /dev/null +++ b/tests/test_smart_completion_multiple_schemata.py @@ -0,0 +1,757 @@ +import itertools +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + schema, + table, + function, + wildcard_expansion, + column, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from utils import completions_to_set + +metadata = { + "tables": { + "public": { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status", "datestamp"], + "select": ["id", "localtime", "ABC"], + }, + "custom": { + "users": ["id", "phone_number"], + "Users": ["userid", "username"], + "products": ["id", "product_name", "price"], + "shipments": ["id", "address", "user_id"], + }, + "Custom": {"projects": ["projectid", "name"]}, + "blog": { + "entries": ["entryid", "entrytitle", "entrytext"], + "tags": ["tagid", "name"], + "entrytags": ["entryid", "tagid"], + "entacclog": ["entryid", "username", "datestamp"], + }, + }, + "functions": { + "public": [ + ["func1", [], [], [], "", False, False, False, False], + ["func2", [], [], [], "", False, False, False, False], + ], + "custom": [ + ["func3", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x"], + ["integer"], + ["o"], + "integer", + False, + False, + True, + False, + ], + ], + "Custom": [["func4", [], [], [], "", False, False, False, False]], + "blog": [ + [ + "extract_entry_symbols", + ["_entryid", "symbol"], + ["integer", "text"], + ["i", "o"], + "", + False, + False, + True, + False, + ], + [ + "enter_entry", + ["_title", "_text", "entryid"], + ["text", "text", "integer"], + ["i", "i", "o"], + "", + False, + False, + False, + False, + ], + ], + }, + "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]}, + "foreignkeys": { + "custom": [("public", "users", "id", "custom", "shipments", "user_id")], + "blog": [ + ("blog", "entries", "entryid", "blog", "entacclog", "entryid"), + ("blog", "entries", "entryid", "blog", "entrytags", "entryid"), + ("blog", "tags", "tagid", "blog", "entrytags", "tagid"), + ], + }, + "defaults": { + "public": { + ("orders", "id"): "nextval('orders_id_seq'::regclass)", + ("orders", "datestamp"): "now()", + ("orders", "status"): "'PENDING'::text", + } + }, +} + +testdata = MetaData(metadata) +cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')] +casing = ( + "SELECT", + "Orders", + "User_Emails", + "CUSTOM", + "Func1", + "Entries", + "Tags", + "EntryTags", + "EntAccLog", + "EntryID", + "EntryTitle", + "EntryText", +) +completers = testdata.get_completers(casing) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("table", ["users", '"users"']) +def test_suggested_column_names_from_shadowed_visible_table(completer, table): + result = get_result(completer, "SELECT FROM " + table, len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT from custom.users", + "WITH users as (SELECT 1 AS foo) SELECT from custom.users", + ], +) +def test_suggested_column_names_from_qualified_shadowed_table(completer, text): + result = get_result(completer, text, position=text.find(" ") + 1) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"]) +def test_suggested_column_names_from_cte(completer, text): + result = completions_to_set(get_result(completer, text, text.find(" ") + 1)) + assert result == completions_to_set( + [column("foo")] + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users JOIN custom.shipments ON ", + """SELECT * + FROM public.users + JOIN custom.shipments ON """, + ], +) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + alias("users"), + alias("shipments"), + name_join("shipments.id = users.id"), + fk_join("shipments.user_id = users.id"), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize( + ("query", "tbl"), + itertools.product( + ( + "SELECT * FROM public.{0} RIGHT OUTER JOIN ", + """SELECT * + FROM {0} + JOIN """, + ), + ("users", '"users"', "Users"), + ), +) +def test_suggested_joins(completer, query, tbl): + result = get_result(completer, query.format(tbl)) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_from_schema_qualifed_table(completer): + result = get_result(completer, "SELECT from custom.products", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize( + "text", + [ + "INSERT INTO orders(", + "INSERT INTO orders (", + "INSERT INTO public.orders(", + "INSERT INTO public.orders (", + ], +) +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_columns_with_insert(completer, text): + assert completions_to_set(get_result(completer, text)) == completions_to_set( + testdata.columns("orders") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result( + completer, "SELECT MAX( from custom.products", len("SELECT MAX(") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'], +) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("custom", start_position) + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ['SELECT * FROM "Custom".']) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot2( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("Custom", start_position) + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_column_names_with_qualified_alias(completer): + result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p.")) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result( + completer, "SELECT id, from custom.products", len("SELECT id, ") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ", + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id", + ], +) +def test_suggestions_after_on(completer, text): + position = len( + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("x"), + alias("y"), + name_join("y.price = x.price"), + name_join("y.product_name = x.product_name"), + name_join("y.id = x.id"), + ] + ) + + +@parametrize("completer", completers()) +def test_suggested_aliases_after_on_right_side(completer): + text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_table_names_after_from(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_schema_qualified_function_name(completer): + text = "SELECT custom.func" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("func3()", -len("func")), + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_schema_qualified_function_name_after_from(completer): + text = "SELECT * FROM custom.set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_not_returned(completer): + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_in_search_path(completer): + completer.search_path = ["public", "custom"] + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT 1::custom.", + "CREATE TABLE foo (bar custom.", + "CREATE FUNCTION foo (bar INT, baz custom.", + "ALTER TABLE foo ALTER COLUMN bar TYPE custom.", + ], +) +def test_schema_qualified_type_name(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(testdata.types("custom")) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from custom.set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", "custom", "functions") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM custom.set_returning_func()", + "SELECT * FROM Custom.set_returning_func()", + "SELECT * FROM Custom.Set_Returning_Func()", + ], +) +def test_wildcard_column_expansion_with_function(completer, text): + position = len("SELECT *") + + completions = get_result(completer, text, position) + + col_list = "x" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_alias_qualifier(completer): + text = "SELECT p.* FROM custom.products p" + position = len("SELECT p.*") + + completions = get_result(completer, text, position) + + col_list = "id, p.product_name, p.price" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + """ + SELECT count(1) FROM users; + CREATE FUNCTION foo(custom.products _products) returns custom.shipments + LANGUAGE SQL + AS $foo$ + SELECT 1 FROM custom.shipments; + INSERT INTO public.orders(*) values(-1, now(), 'preliminary'); + SELECT 2 FROM custom.users; + $foo$; + SELECT count(1) FROM custom.shipments; + """, + "INSERT INTO public.orders(*", + "INSERT INTO public.Orders(*", + "INSERT INTO public.orders (*", + "INSERT INTO public.Orders (*", + "INSERT INTO orders(*", + "INSERT INTO Orders(*", + "INSERT INTO orders (*", + "INSERT INTO Orders (*", + "INSERT INTO public.orders(*)", + "INSERT INTO public.Orders(*)", + "INSERT INTO public.orders (*)", + "INSERT INTO public.Orders (*)", + "INSERT INTO orders(*)", + "INSERT INTO Orders(*)", + "INSERT INTO orders (*)", + "INSERT INTO Orders (*)", + ], +) +def test_wildcard_column_expansion_with_insert(completer, text): + position = text.index("*") + 1 + completions = get_result(completer, text, position) + + expected = [wildcard_expansion("ordered_date, status")] + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_table_qualifier(completer): + text = 'SELECT "select".* FROM public."select"' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM public."select" JOIN custom.users ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select"."localtime", "select"."ABC", ' + "users.id, users.phone_number" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT U. FROM custom.Users U", + "SELECT U. FROM custom.USERS U", + "SELECT U. FROM custom.users U", + 'SELECT U. FROM "custom".Users U', + 'SELECT U. FROM "custom".USERS U', + 'SELECT U. FROM "custom".users U', + ], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'] +) +def test_suggest_columns_from_quoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("Users", "custom") + ) + + +texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "] + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize("text", texts) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(aliasing=True, casing=False, filtr=True)) +@parametrize("text", texts) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("users u"), + table("orders o" if text == "SELECT * FROM " else "orders o2"), + table('"select" s'), + function("func1() f"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=True, casing=True, filtr=True)) +@parametrize("text", texts) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users u"), + table("Orders O" if text == "SELECT * FROM " else "Orders O2"), + table('"select" s'), + function("Func1() F"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True, filtr=True)) +@parametrize("text", texts) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users"), + table("Orders"), + table('"select"'), + function("Func1()"), + function("func2()"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags ET", -2) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries E", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries E2", -1), + join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.Entries JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries", -1), + join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.Entries JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2) + + +@parametrize("completer", completers()) +def test_function_alias_search_without_aliases(completer): + text = "SELECT blog.ees" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -3 + assert first.text == "extract_entry_symbols()" + assert first.display_text == "extract_entry_symbols(_entryid)" + + +@parametrize("completer", completers()) +def test_function_alias_search_with_aliases(completer): + text = "SELECT blog.ee" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -2 + assert first.text == "enter_entry(_title := , _text := )" + assert first.display_text == "enter_entry(_title, _text)" + + +@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual)) +def test_column_alias_search(completer): + result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et")) + cols = ("EntryText", "EntryTitle", "EntryID") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=True)) +def test_column_alias_search_qualified(completer): + result = get_result( + completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei") + ) + cols = ("EntryID", "EntryTitle") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_schema_object_order(completer): + result = get_result(completer, "SELECT * FROM u") + assert result[:3] == [ + table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users") + ] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_all_schema_objects(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders", '"select"', "custom.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(filtr=False, aliasing=False, casing=True)) +def test_all_schema_objects_with_casing(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_all_schema_objects_with_aliases(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + + [function(x) for x in ("func2() f",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")] + ) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..db1fe0a --- /dev/null +++ b/tests/test_smart_completion_public_schema_only.py @@ -0,0 +1,1112 @@ +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + keyword, + schema, + table, + view, + function, + column, + wildcard_expansion, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from prompt_toolkit.completion import Completion +from utils import completions_to_set + + +metadata = { + "tables": { + "users": ["id", "parentid", "email", "first_name", "last_name"], + "Users": ["userid", "username"], + "orders": ["id", "ordered_date", "status", "email"], + "select": ["id", "insert", "ABC"], + }, + "views": {"user_emails": ["id", "email"], "functions": ["function"]}, + "functions": [ + ["custom_fun", [], [], [], "", False, False, False, False], + ["_custom_fun", [], [], [], "", False, False, False, False], + ["custom_func1", [], [], [], "", False, False, False, False], + ["custom_func2", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x", "y"], + ["integer", "integer"], + ["b", "b"], + "", + False, + False, + True, + False, + ], + ], + "datatypes": ["custom_type1", "custom_type2"], + "foreignkeys": [ + ("public", "users", "id", "public", "users", "parentid"), + ("public", "users", "id", "public", "Users", "userid"), + ], +} + +metadata = {k: {"public": v} for k, v in metadata.items()} + +testdata = MetaData(metadata) + +cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"] +cased_users2_col_names = ["UserID", "UserName"] +cased_func_names = [ + "Custom_Fun", + "_custom_fun", + "Custom_Func1", + "custom_func2", + "set_returning_func", +] +cased_tbls = ["Users", "Orders"] +cased_views = ["User_Emails", "Functions"] +casing = ( + ["SELECT", "PUBLIC"] + + cased_func_names + + cased_tbls + + cased_views + + cased_users_col_names + + cased_users2_col_names +) +# Lists for use in assertions +cased_funcs = [ + function(f) + for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()") +] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")] +cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])] +cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls +cased_users_cols = [column(c) for c in cased_users_col_names] +aliased_rels = ( + [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')] + + [view("user_emails ue"), view("functions f")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "custom_fun() cf", + "custom_func1() cf", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +cased_aliased_rels = ( + [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')] + + [view("User_Emails UE"), view("Functions F")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "Custom_Fun() CF", + "Custom_Func1() CF", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +completers = testdata.get_completers(casing) + + +# Just to make sure that this doesn't crash +@parametrize("completer", completers()) +def test_function_column_name(completer): + for l in range( + len("SELECT * FROM Functions WHERE function:"), + len("SELECT * FROM Functions WHERE function:text") + 1, + ): + assert [] == get_result( + completer, "SELECT * FROM Functions WHERE function:text"[:l] + ) + + +@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"]) +@parametrize("completer", completers()) +def test_drop_alter_function(completer, action): + assert get_result(completer, action + " FUNCTION set_ret") == [ + function("set_returning_func(x integer, y integer)", -len("set_ret")) + ] + + +@parametrize("completer", completers()) +def test_empty_string_completion(completer): + result = get_result(completer, "") + assert completions_to_set( + testdata.keywords() + testdata.specials() + ) == completions_to_set(result) + + +@parametrize("completer", completers()) +def test_select_keyword_completion(completer): + result = get_result(completer, "SEL") + assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)]) + + +@parametrize("completer", completers()) +def test_builtin_function_name_completion(completer): + result = get_result(completer, "SELECT MA") + assert completions_to_set(result) == completions_to_set( + [ + function("MAKE_DATE", -2), + function("MAKE_INTERVAL", -2), + function("MAKE_TIME", -2), + function("MAKE_TIMESTAMP", -2), + function("MAKE_TIMESTAMPTZ", -2), + function("MASKLEN", -2), + function("MAX", -2), + keyword("MAXEXTENTS", -2), + keyword("MATERIALIZED VIEW", -2), + ] + ) + + +@parametrize("completer", completers()) +def test_builtin_function_matches_only_at_start(completer): + text = "SELECT IN" + + result = [c.text for c in get_result(completer, text)] + + assert "MIN" not in result + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion(completer): + result = get_result(completer, "SELECT cu") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + function("CURRENT_DATE", -2), + function("CURRENT_TIMESTAMP", -2), + function("CUME_DIST", -2), + function("CURRENT_TIME", -2), + keyword("CURRENT", -2), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion_matches_anywhere(completer): + result = get_result(completer, "SELECT om") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + ] + ) + + +@parametrize("completer", completers(casing=True)) +def test_list_functions_for_special(completer): + result = get_result(completer, r"\df ") + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + [function(f) for f in cased_func_names] + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_from_visible_table(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=True, qualify=no_qual)) +def test_suggested_cased_column_names(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + cased_funcs + + cased_users_cols + + testdata.builtin_functions() + + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"]) +def test_suggested_auto_qualified_column_names(text, completer): + position = text.index(" ") + 1 + cols = [column(c.lower()) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=qual)) +@parametrize( + "text", + [ + 'SELECT from users U NATURAL JOIN "Users"', + 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"', + ], +) +def test_suggested_auto_qualified_column_names_two_tables(text, completer): + position = text.index(" ") + 1 + cols = [column("U." + c.lower()) for c in cased_users_col_names] + cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("]) +def test_no_column_qualification(text, completer): + cols = [column(c) for c in cased_users_col_names] + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(cols) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +def test_suggested_cased_always_qualified_column_names(completer): + text = "SELECT from users" + position = len("SELECT ") + cols = [column("users." + c) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cased_funcs + cols + testdata.builtin_functions() + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_table_dot(completer): + result = get_result(completer, "SELECT users. from users", len("SELECT users.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_alias(completer): + result = get_result(completer, "SELECT u. from users u", len("SELECT u.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=True)) +def test_suggested_cased_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(cased_users_cols) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_dot(completer): + result = get_result( + completer, + "SELECT users.id, users. from users u", + len("SELECT users.id, users."), + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_after_three_way_join(completer): + text = """SELECT * FROM users u1 + INNER JOIN users u2 ON u1.id = u2.id + INNER JOIN users u3 ON u2.id = u3.""" + result = get_result(completer, text) + assert column("id") in result + + +join_condition_texts = [ + 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ', + """INSERT INTO public.orders(orderid) + SELECT * FROM users U JOIN "Users" U2 ON """, + 'SELECT * FROM users U JOIN "Users" U2 ON ', + 'SELECT * FROM users U INNER join "Users" U2 ON ', + 'SELECT * FROM USERS U right JOIN "Users" U2 ON ', + 'SELECT * FROM users U LEFT JOIN "Users" U2 ON ', + 'SELECT * FROM Users U FULL JOIN "Users" U2 ON ', + 'SELECT * FROM users U right outer join "Users" U2 ON ', + 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ', + 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ', + """SELECT * + FROM users U + FULL OUTER JOIN "Users" U2 ON + """, +] + + +@parametrize("completer", completers(casing=False)) +@parametrize("text", join_condition_texts) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.userid = U.id")] + ) + + +@parametrize("completer", completers(casing=True)) +@parametrize("text", join_condition_texts) +def test_cased_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + """SELECT * + FROM users + CROSS JOIN "Users" + NATURAL JOIN users u + JOIN "Users" u2 ON + """ + ], +) +def test_suggested_join_conditions_with_same_table_twice(completer, text): + result = get_result(completer, text) + assert result == [ + fk_join("u2.userid = u.id"), + fk_join("u2.userid = users.id"), + name_join('u2.userid = "Users".userid'), + name_join('u2.username = "Users".username'), + alias("u"), + alias("u2"), + alias("users"), + alias('"Users"'), + ] + + +@parametrize("completer", completers()) +@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."]) +def test_suggested_join_conditions_with_invalid_qualifier(completer, text): + result = get_result(completer, text) + assert result == [] + + +@parametrize("completer", completers(casing=False)) +@parametrize( + ("text", "ref"), + [ + ("SELECT * FROM users JOIN NonTable on ", "NonTable"), + ("SELECT * FROM users JOIN nontable nt on ", "nt"), + ], +) +def test_suggested_join_conditions_with_invalid_table(completer, text, ref): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias(ref)] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM "Users" u JOIN u', + 'SELECT * FROM "Users" u JOIN uid', + 'SELECT * FROM "Users" u JOIN userid', + 'SELECT * FROM "Users" u JOIN id', + ], +) +def test_suggested_joins_fuzzy(completer, text): + result = get_result(completer, text) + last_word = text.split()[-1] + expected = join("users ON users.id = u.userid", -len(last_word)) + assert expected in result + + +join_texts = [ + "SELECT * FROM Users JOIN ", + """INSERT INTO "Users" + SELECT * + FROM Users + INNER JOIN """, + """INSERT INTO public."Users"(username) + SELECT * + FROM Users + INNER JOIN """, + """SELECT * + FROM Users + INNER JOIN """, +] + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", join_texts) +def test_suggested_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [ + join('"Users" ON "Users".userid = Users.id'), + join("users users2 ON users2.id = Users.parentid"), + join("users users2 ON users2.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", join_texts) +def test_cased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + + cased_rels + + [ + join('"Users" ON "Users".UserID = Users.ID'), + join("Users Users2 ON Users2.ID = Users.PARENTID"), + join("Users Users2 ON Users2.PARENTID = Users.ID"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", join_texts) +def test_aliased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + aliased_rels + + [ + join('"Users" U ON U.userid = Users.id'), + join("users u ON u.id = Users.parentid"), + join("users u ON u.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM public."Users" JOIN ', + 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', + """SELECT * + FROM public."Users" + LEFT JOIN """, + ], +) +def test_suggested_joins_quoted_schema_qualified_table(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join('public.users ON users.id = "Users".userid')] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON ", + "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("u"), + name_join("o.id = u.id"), + name_join("o.email = u.email"), + alias("o"), + ] + ) + + +@parametrize("completer", completers()) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ", + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on_right_side(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")]) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON ", + "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON", + ], +) +def test_suggested_tables_after_on(completer, text): + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + name_join("orders.id = users.id"), + name_join("orders.email = users.email"), + alias("users"), + alias("orders"), + ] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON", + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ", + ], +) +def test_suggested_tables_after_on_right_side(completer, text): + position = len( + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias("orders")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (", + "SELECT * FROM users INNER JOIN orders USING(", + ], +) +def test_join_using_suggests_common_columns(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()", + "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()", + "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)", + "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)", + ], +) +def test_join_using_suggests_from_last_table(completer, text): + position = text.index("()") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (id,", + "SELECT * FROM users INNER JOIN orders USING(id,", + ], +) +def test_join_using_suggests_columns_after_first_column(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + "SELECT * FROM ", + "SELECT * FROM users CROSS JOIN ", + "SELECT * FROM users natural join ", + ], +) +def test_table_names_after_from(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + assert [c.text for c in result] == [ + "public", + "orders", + '"select"', + "users", + '"Users"', + "functions", + "user_emails", + "_custom_fun()", + "custom_fun()", + "custom_func1()", + "custom_func2()", + "set_returning_func(x := , y := )", + ] + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_auto_escaped_col_names(completer): + result = get_result(completer, 'SELECT from "select"', len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("select") + ) + + +@parametrize("completer", completers(aliasing=False)) +def test_allow_leading_double_quote_in_last_word(completer): + result = get_result(completer, 'SELECT * from "sele') + + expected = table('"select"', -5) + + assert expected in result + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT 1::", + "CREATE TABLE foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "ALTER TABLE foo ALTER COLUMN bar TYPE ", + ], +) +def test_suggest_datatype(text, completer): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + testdata.types() + testdata.builtin_datatypes() + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_escaped_table_alias(completer): + result = get_result(completer, 'select * from "select" s where s.') + assert completions_to_set(result) == completions_to_set(testdata.columns("select")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_set_returning_function(completer): + result = get_result(completer, "select from set_returning_func()", len("select ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_using_suggests_common_columns(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 USING (""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_on_suggests_columns_and_join_conditions(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 ON f1.""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [name_join("y = f2.y"), name_join("x = f2.x")] + + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers()) +def test_learn_keywords(completer): + history = "CREATE VIEW v AS SELECT 1" + completer.extend_query_history(history) + + # Now that we've used `VIEW` once, it should be suggested ahead of other + # keywords starting with v. + text = "create v" + completions = get_result(completer, text) + assert completions[0].text == "VIEW" + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_learn_table_names(completer): + history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users" + completer.extend_query_history(history) + + text = "SELECT * FROM " + completions = get_result(completer, text) + + # `users` should be higher priority than `orders` (used more often) + users = table("users") + orders = table("orders") + + assert completions.index(users) < completions.index(orders) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_columns_before_keywords(completer): + text = "SELECT * FROM orders WHERE s" + completions = get_result(completer, text) + + col = column("status", -1) + kw = keyword("SELECT", -1) + + assert completions.index(col) < completions.index(kw) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM users", + "INSERT INTO users SELECT * FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT * + FROM users u""", + ], +) +def test_wildcard_column_expansion(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, parentid, email, first_name, last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.* FROM users u", + "INSERT INTO public.users SELECT u.* FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT u.* + FROM users u""", + ], +) +def test_wildcard_column_expansion_with_alias(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, u.parentid, u.email, u.first_name, u.last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text,expected", + [ + ( + "SELECT users.* FROM users", + "id, users.parentid, users.email, users.first_name, users.last_name", + ), + ( + "SELECT Users.* FROM Users", + "id, Users.parentid, Users.email, Users.first_name, Users.last_name", + ), + ], +) +def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected): + position = len("SELECT users.*") + + completions = get_result(completer, text, position) + + expected = [wildcard_expansion(expected)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM "select" JOIN users u ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select".insert, "select"."ABC", ' + "u.id, u.parentid, u.email, u.first_name, u.last_name" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM "select" JOIN users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select".insert, "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. FROM users U"], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_quoted_table(completer): + result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("Users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "]) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + aliased_rels + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("orders o2"), + table("users u"), + table('"Users" U'), + table('"select" s'), + view("user_emails ue"), + view("functions f"), + function("_custom_fun() cf"), + function("custom_fun() cf"), + function("custom_func1() cf"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + schema("PUBLIC"), + table("Orders O2"), + table("Users U"), + table('"Users" U'), + table('"select" s'), + view("User_Emails UE"), + view("Functions F"), + function("_custom_fun() cf"), + function("Custom_Fun() CF"), + function("Custom_Func1() CF"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_aliased_rels + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_rels + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "INSERT INTO users ()", + "INSERT INTO users()", + "INSERT INTO users () SELECT * FROM orders;", + "INSERT INTO users() SELECT * FROM users u cross join orders o", + ], +) +def test_insert(completer, text): + position = text.find("(") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_suggest_cte_names(completer): + text = """ + WITH cte1 AS (SELECT a, b, c FROM foo), + cte2 AS (SELECT d, e, f FROM bar) + SELECT * FROM + """ + result = get_result(completer, text) + expected = completions_to_set( + [ + Completion("cte1", 0, display_meta="table"), + Completion("cte2", 0, display_meta="table"), + ] + ) + assert expected <= completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_cte(completer): + result = get_result( + completer, + "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte", + len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "), + ) + expected = [ + Completion("foo", 0, display_meta="column"), + Completion("bar", 0, display_meta="column"), + ] + testdata.functions_and_keywords() + + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.", + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.", + ], +) +def test_cte_qualified_columns(completer, text): + result = get_result(completer, text) + expected = [Completion("foo", 0, display_meta="column")] + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize( + "keyword_casing,expected,texts", + [ + ("upper", "SELECT", ("", "s", "S", "Sel")), + ("lower", "select", ("", "s", "S", "Sel")), + ("auto", "SELECT", ("", "S", "SEL", "seL")), + ("auto", "select", ("s", "sel", "SEl")), + ], +) +def test_keyword_casing_upper(keyword_casing, expected, texts): + for text in texts: + completer = testdata.get_completer({"keyword_casing": keyword_casing}) + completions = get_result(completer, text) + assert expected in [cpl.text for cpl in completions] + + +@parametrize("completer", completers()) +def test_keyword_after_alter(completer): + text = "ALTER TABLE users ALTER " + expected = Completion("COLUMN", start_position=0, display_meta="keyword") + completions = get_result(completer, text) + assert expected in completions + + +@parametrize("completer", completers()) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + expected = completions_to_set([schema("'public'")]) + assert completions_to_set(result) == expected + + +@parametrize("completer", completers()) +def test_special_name_completion(completer): + result = get_result(completer, "\\t") + assert completions_to_set(result) == completions_to_set( + [ + Completion( + text="\\timing", + start_position=-2, + display_meta="Toggle timing of commands.", + ) + ] + ) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py new file mode 100644 index 0000000..1034bbe --- /dev/null +++ b/tests/test_sqlcompletion.py @@ -0,0 +1,964 @@ +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + Column, + View, + Keyword, + FromClauseItem, + Function, + Datatype, + Alias, + JoinCondition, + Join, +) +from pgcli.packages.parseutils.tables import TableReference +import pytest + + +def cols_etc( + table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None +): + """Returns the expected select-clause suggestions for a single-table + select.""" + return { + Column( + table_refs=(TableReference(schema, table, alias, is_function),), + qualifiable=True, + ), + Function(schema=parent), + Keyword(last_keyword), + } + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT") + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT") + + +def test_cte_does_not_crash(): + sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;" + for i in range(len(sql)): + suggestions = suggest_type(sql[: i + 1], sql[: i + 1]) + + +@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE ']) +def test_where_suggests_columns_functions_quoted_table(expression): + expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE") + suggestions = suggest_type(expression, expression) + assert expected == set(suggestions) + + +@pytest.mark.parametrize( + "expression", + [ + "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ", + "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "expression", + ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "], +) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = "SELECT * FROM tabl WHERE foo = ANY(" + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_lparen_suggests_cols_and_funcs(): + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert set(suggestion) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("("), + } + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type("SELECT ", "SELECT ") + assert set(suggestions) == { + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "] +) +def test_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM "]) +def test_suggest_tables_views_schemas_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ", + "SELECT * FROM foo JOIN bar USING (barid) JOIN ", + ], +) +def test_suggest_after_join_with_two_tables(expression): + suggestions = suggest_type(expression, expression) + tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(tables, None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"] +) +def test_suggest_after_join_with_one_table(expression): + suggestions = suggest_type(expression, expression) + tables = ((None, "foo", None, False),) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(((None, "foo", None, False),), None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."] +) +def test_suggest_qualified_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize("expression", ["UPDATE sch."]) +def test_suggest_qualified_aliasable_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + 'SELECT * FROM sch."', + 'SELECT * FROM sch."foo', + 'SELECT * FROM "sch".', + 'SELECT * FROM "sch"."', + ], +) +def test_suggest_qualified_tables_views_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema="sch")} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) +def test_suggest_qualified_tables_views_functions_and_joins(expression): + suggestions = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema="sch", table_refs=tbls), + Join(tbls, "sch"), + } + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert set(suggestions) == {Table(schema=None), Schema()} + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert set(suggestions) == {Table(schema="sch")} + + +@pytest.mark.parametrize( + "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "] +) +def test_distinct_suggests_cols(text): + suggestions = suggest_type(text, text) + assert set(suggestions) == { + Column(table_refs=(), local_tables=(), qualifiable=True), + Function(schema=None), + Keyword("DISTINCT"), + } + + +@pytest.mark.parametrize( + "text, text_before, last_keyword", + [ + ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "ORDER BY", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_aliases( + text, text_before, last_keyword +): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=( + TableReference(None, "tbl", "x", False), + TableReference(None, "tbl1", "y", False), + ), + local_tables=(), + qualifiable=True, + ), + Function(schema=None), + Keyword(last_keyword), + } + + +@pytest.mark.parametrize( + "text, text_before", + [ + ("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_function_arguments_with_alias_given(): + suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.") + + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert set(suggestions) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize( + "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"] +) +def test_insert_into_lparen_suggests_cols(text): + suggestions = suggest_type(text, "INSERT INTO abc (") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type( + "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" + ) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert set(suggestions) == { + Column(table_refs=((None, "tabl", None, False),)), + Table(schema="tabl"), + View(schema="tabl"), + Function(schema="tabl"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT t1. FROM tabl1 t1", + "SELECT t1. FROM tabl1 t1, tabl2 t2", + 'SELECT t1. FROM "tabl1" t1', + 'SELECT t1. FROM "tabl1" t1, "tabl2" t2', + ], +) +def test_dot_suggests_cols_of_an_alias(sql): + suggestions = suggest_type(sql, "SELECT t1.") + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM tabl1 t1 WHERE t1.", + "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.", + 'SELECT * FROM "tabl1" t1 WHERE t1.', + 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.', + ], +) +def test_dot_suggests_cols_of_an_alias_where(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type( + "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl2", "t2", False),)), + Table(schema="t2"), + View(schema="t2"), + Function(schema="t2"), + } + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + ], +) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." + suggestions = suggest_type(q, q) + assert set(suggestions) == { + Column(table_refs=((None, "foo", "f", False),)), + Table(schema="f"), + View(schema="f"), + Function(schema="f"), + } + + +@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "]) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert set(suggestion) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) +def test_sub_select_table_name_completion_with_outer_table(expression): + suggestion = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " + ) + assert set(suggestions) == { + Column(table_refs=((None, "abc", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " + ) + assert set(suggestions) == cols_etc("abc") + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl", "t", False),)), + Table(schema="t"), + View(schema="t"), + Function(schema="t"), + } + + +@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER")) +@pytest.mark.parametrize("tbl_alias", ("", "foo")) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " + suggestion = suggest_type(text, text) + tbls = tuple([(None, "abc", tbl_alias or None, False)]) + assert set(suggestion) == { + FromClauseItem(schema=None, table_refs=tbls), + Schema(), + Join(tbls, None), + } + + +def test_left_join_with_comma(): + text = "select * from foo f left join bar b," + suggestions = suggest_type(text, text) + # tbls should also include (None, 'bar', 'b', False) + # but there's a bug with commas + tbls = tuple([(None, "foo", "f", False)]) + assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "def", "d", False)) + assert set(suggestions) == { + Column(table_refs=((None, "abc", "a", False),)), + Table(schema="a"), + View(schema="a"), + Function(schema="a"), + JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) +def test_join_alias_dot_suggests_cols2(sql): + suggestion = suggest_type(sql, sql) + assert set(suggestion) == { + Column(table_refs=((None, "def", "d", False),)), + Table(schema="d"), + View(schema="d"), + Function(schema="d"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + """select a.x, b.y +from abc a +join bcd b on +""", + """select a.x, b.y +from abc a +join bcd b +on """, + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) +def test_on_suggests_aliases_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "bcd", "b", False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("a", "b")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == (Alias(aliases=("a", "b")),) + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions_right_side(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "text", + ( + "select * from abc inner join def using (", + "select * from abc inner join def using (col1, ", + "insert into hij select * from abc inner join def using (", + """insert into hij(x, y, z) + select * from abc inner join def using (col1, """, + """insert into hij (a,b,c) + select * from abc inner join def using (col1, """, + ), +) +def test_join_using_suggests_common_columns(text): + tables = ((None, "abc", None, False), (None, "def", None, False)) + assert set(suggest_type(text, text)) == { + Column(table_refs=tables, require_last_table=True) + } + + +def test_suggest_columns_after_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + suggestions = suggest_type(sql, sql) + assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions) + + +def test_2_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ", "select * from a; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b", "select * from a; select " + ) + assert set(suggestions) == { + Column(table_refs=((None, "b", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + # Should work even if first statement is invalid + suggestions = suggest_type( + "select * from; select * from ", "select * from; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_2_statements_1st_current(): + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type("select from a; select * from b", "select ") + assert set(suggestions) == cols_etc("a", last_keyword="SELECT") + + +def test_3_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ; select * from c", + "select * from a; select * from ", + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b; select * from c", "select * from a; select " + ) + assert set(suggestions) == cols_etc("b", last_keyword="SELECT") + + +@pytest.mark.parametrize( + "text", + [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT 3 FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +SELECT * FROM baz; +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 3 FROM bar; +SELECT FROM foo; +$func$ +SELECT * FROM qux; + """, + ], +) +def test_statements_in_function_body(text): + suggestions = suggest_type(text, text[: text.find(" ") + 1]) + assert set(suggestions) == { + Column(table_refs=((None, "foo", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +functions = [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT 1 FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """ +create function func2(int, varchar) +RETURNS text +language sql AS +' +SELECT 2 FROM bar; +SELECT 1 FROM foo; +'; + """, +] + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_after_function_body(text): + suggestions = suggest_type(text, text[: text.find("; ") + 1]) + assert set(suggestions) == {Keyword(), Special()} + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_before_function_body(text): + suggestions = suggest_type(text, "") + assert set(suggestions) == {Keyword(), Special()} + + +def test_create_db_with_template(): + suggestions = suggest_type( + "create database foo with template ", "create database foo with template " + ) + + assert set(suggestions) == {Database()} + + +@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n")) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert set(suggestions) == {Keyword(), Special()} + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = "DROP TABLE schema_name.table_name" + suggestions = suggest_type(text, text) + assert suggestions == (Table(schema="schema_name"),) + + +@pytest.mark.parametrize("text", (",", " ,", "sel ,")) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_drop_schema_suggests_schemas(): + sql = "DROP SCHEMA " + assert suggest_type(sql, sql) == (Schema(),) + + +@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"]) +def test_cast_operator_suggests_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] +) +def test_cast_operator_suggests_schema_qualified_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema="bar"), + Table(schema="bar"), + } + + +def test_alter_column_type_suggests_types(): + q = "ALTER TABLE foo ALTER COLUMN bar TYPE " + assert set(suggest_type(q, q)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "CREATE TABLE foo (bar ", + "CREATE TABLE foo (bar DOU", + "CREATE TABLE foo (bar INT, baz ", + "CREATE TABLE foo (bar INT, baz TEXT, qux ", + "CREATE FUNCTION foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "SELECT * FROM foo() AS bar (baz ", + "SELECT * FROM foo() AS bar (baz INT, qux ", + # make sure this doesn't trigger special completion + "CREATE TABLE foo (dt d", + ], +) +def test_identifier_suggests_types_in_parentheses(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "SELECT foo ", + "SELECT foo FROM bar ", + "SELECT foo AS bar ", + "SELECT foo bar ", + "SELECT * FROM foo AS bar ", + "SELECT * FROM foo bar ", + "SELECT foo FROM (SELECT bar ", + ], +) +def test_alias_suggests_keywords(text): + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +def test_invalid_sql(): + # issue 317 + text = "selt *" + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +@pytest.mark.parametrize( + "text", + ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "], +) +def test_suggest_where_keyword(text): + # https://github.com/dbcli/mycli/issues/135 + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("foo", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "text, before, expected", + [ + ( + "\\ns abc SELECT ", + "SELECT ", + [ + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ], + ), + ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)), + ( + "\\ns abc SELECT t1. FROM tabl1 t1", + "SELECT t1.", + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ], + ), + ], +) +def test_named_query_completion(text, before, expected): + suggestions = suggest_type(text, before) + assert set(expected) == set(suggestions) + + +def test_select_suggests_fields_from_function(): + suggestions = suggest_type("SELECT FROM func()", "SELECT ") + assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT") + + +@pytest.mark.parametrize("sql", ["("]) +def test_leading_parenthesis(sql): + # No assertion for now; just make sure it doesn't crash + suggest_type(sql, sql) + + +@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo']) +def test_ignore_leading_double_quotes(sql): + suggestions = suggest_type(sql, sql) + assert FromClauseItem(schema=None) in set(suggestions) + + +@pytest.mark.parametrize( + "sql", + [ + "ALTER TABLE foo ALTER COLUMN ", + "ALTER TABLE foo ALTER COLUMN bar", + "ALTER TABLE foo DROP COLUMN ", + "ALTER TABLE foo DROP COLUMN bar", + ], +) +def test_column_keyword_suggests_columns(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))} + + +def test_handle_unrecognized_kw_generously(): + sql = "SELECT * FROM sessions WHERE session = 1 AND " + suggestions = suggest_type(sql, sql) + expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True) + + assert expected in set(suggestions) + + +@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "]) +def test_keyword_after_alter(sql): + assert Keyword("ALTER") in set(suggest_type(sql, sql)) diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 0000000..ae865f4 --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,188 @@ +import os +from unittest.mock import patch, MagicMock, ANY + +import pytest +from configobj import ConfigObj +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock( + SSHTunnelForwarder, local_bind_ports=[1111], autospec=True + ) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host={db_params['host']} port={db_params['port']}" + ) + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + ) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object( + PGCli, "__init__", autospec=True, return_value=None + ) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url + + +def test_config( + tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + pgclirc = str(tmpdir.join("rcfile")) + + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "tunnel.host" + tunnel_port = 1022 + tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + + tunnel2_url = "tunnel2.host" + + config = ConfigObj() + config.filename = pgclirc + config["ssh tunnels"] = {} + config["ssh tunnels"][r"\.com$"] = tunnel_url + config["ssh tunnels"][r"^hello-"] = tunnel2_url + config.write() + + # Unmatched host + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="unmatched.host") + mock_ssh_tunnel_forwarder.assert_not_called() + + # Host matching first tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="matched.host.com") + mock_ssh_tunnel_forwarder.assert_called_once() + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching second tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22) + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching both tunnels (will use the first one matched) + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched.com") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..67d769f --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,92 @@ +import pytest +import psycopg +from pgcli.main import format_output, OutputSettings +from os import getenv + +POSTGRES_USER = getenv("PGUSER", "postgres") +POSTGRES_HOST = getenv("PGHOST", "localhost") +POSTGRES_PORT = getenv("PGPORT", 5432) +POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres") + + +def db_connection(dbname=None): + conn = psycopg.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dbname=dbname, + ) + conn.autocommit = True + return conn + + +try: + conn = db_connection() + CAN_CONNECT_TO_DB = True + SERVER_VERSION = conn.info.parameter_status("server_version") + JSON_AVAILABLE = True + JSONB_AVAILABLE = True +except Exception as x: + CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False + SERVER_VERSION = 0 + + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a postgres instance at localhost accessible by user 'postgres'", +) + + +requires_json = pytest.mark.skipif( + not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined" +) + + +requires_jsonb = pytest.mark.skipif( + not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined" +) + + +def create_db(dbname): + with db_connection().cursor() as cur: + try: + cur.execute("""CREATE DATABASE _test_db""") + except: + pass + + +def drop_tables(conn): + with conn.cursor() as cur: + cur.execute( + """ + DROP SCHEMA public CASCADE; + CREATE SCHEMA public; + DROP SCHEMA IF EXISTS schema1 CASCADE; + DROP SCHEMA IF EXISTS schema2 CASCADE""" + ) + + +def run( + executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None +): + "Return string output for the sql to be run" + + results = executor.run(sql, pgspecial, exception_formatter) + formatted = [] + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded + ) + for title, rows, headers, status, sql, success, is_special in results: + formatted.extend(format_output(title, rows, headers, status, settings)) + if join: + formatted = "\n".join(formatted) + + return formatted + + +def completions_to_set(completions): + return { + (completion.display_text, completion.display_meta_text) + for completion in completions + } @@ -0,0 +1,14 @@ +[tox] +envlist = py38, py39, py310, py311, py312 +[testenv] +deps = pytest>=2.7.0,<=3.0.7 + mock>=1.0.1 + behave>=1.2.4 + pexpect==3.3 + sshtunnel>=0.4.0 +commands = py.test + behave tests/features +passenv = PGHOST + PGPORT + PGUSER + PGPASSWORD |