-rw-r--r--  .github/workflows/ci.yml | 66
-rw-r--r--  .pre-commit-config.yaml | 2
-rw-r--r--  .travis.yml | 51
-rw-r--r--  AUTHORS | 4
-rw-r--r--  DEVELOP.rst | 2
-rw-r--r--  README.rst | 62
-rw-r--r--  Vagrantfile | 81
-rw-r--r--  changelog.rst | 34
-rw-r--r--  pgcli/__init__.py | 2
-rw-r--r--  pgcli/completion_refresher.py | 9
-rw-r--r--  pgcli/config.py | 45
-rw-r--r--  pgcli/magic.py | 8
-rw-r--r--  pgcli/main.py | 93
-rw-r--r--  pgcli/packages/parseutils/__init__.py | 32
-rw-r--r--  pgcli/packages/parseutils/meta.py | 2
-rw-r--r--  pgcli/packages/parseutils/tables.py | 3
-rw-r--r--  pgcli/packages/pgliterals/pgliterals.json | 1
-rw-r--r--  pgcli/packages/prioritization.py | 4
-rw-r--r--  pgcli/packages/prompt_utils.py | 4
-rw-r--r--  pgcli/packages/sqlcompletion.py | 2
-rw-r--r--  pgcli/pgclirc | 19
-rw-r--r--  pgcli/pgcompleter.py | 40
-rw-r--r--  pgcli/pgexecute.py | 178
-rw-r--r--  pgcli/pgtoolbar.py | 22
-rw-r--r--  requirements-dev.txt | 1
-rw-r--r--  tests/conftest.py | 2
-rw-r--r--  tests/features/db_utils.py | 4
-rw-r--r--  tests/features/environment.py | 13
-rw-r--r--  tests/features/fixture_utils.py | 2
-rw-r--r--  tests/features/steps/basic_commands.py | 13
-rw-r--r--  tests/features/steps/crud_database.py | 8
-rw-r--r--  tests/features/steps/expanded.py | 2
-rw-r--r--  tests/features/steps/iocommands.py | 6
-rw-r--r--  tests/features/steps/specials.py | 7
-rw-r--r--  tests/features/steps/wrappers.py | 10
-rw-r--r-- [-rwxr-xr-x]  tests/features/wrappager.py | 0
-rw-r--r--  tests/metadata.py | 6
-rw-r--r--  tests/parseutils/test_parseutils.py | 59
-rw-r--r--  tests/test_completion_refresher.py | 20
-rw-r--r--  tests/test_config.py | 19
-rw-r--r--  tests/test_main.py | 9
-rw-r--r--  tests/test_naive_completion.py | 2
-rw-r--r--  tests/test_pgexecute.py | 140
-rw-r--r--  tests/test_pgspecial.py | 24
-rw-r--r--  tests/test_prompt_utils.py | 2
-rw-r--r--  tests/test_rowlimit.py | 2
-rw-r--r--  tests/test_smart_completion_multiple_schemata.py | 32
-rw-r--r--  tests/test_smart_completion_public_schema_only.py | 6
-rw-r--r--  tests/test_sqlcompletion.py | 425
-rw-r--r--  tests/utils.py | 8
50 files changed, 953 insertions, 635 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..ce54d6f
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,66 @@
+name: pgcli
+
+on:
+ pull_request:
+ paths-ignore:
+ - '**.rst'
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+    strategy:
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9]
+
+    services:
+      postgres:
+        image: postgres:9.6
+        env:
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: postgres
+        ports:
+          - 5432:5432
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+
+    steps:
+      - uses: actions/checkout@v2
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install requirements
+        run: |
+          pip install -U pip setuptools
+          pip install --no-cache-dir .
+          pip install -r requirements-dev.txt
+          pip install keyrings.alt>=3.1
+
+      - name: Run unit tests
+        run: coverage run --source pgcli -m py.test
+
+      - name: Run integration tests
+        env:
+          PGUSER: postgres
+          PGPASSWORD: postgres
+
+        run: behave tests/features --no-capture
+
+      - name: Check changelog for ReST compliance
+        run: rst2html.py --halt=warning changelog.rst >/dev/null
+
+      - name: Run Black
+        run: pip install black && black --check .
+        if: matrix.python-version == '3.6'
+
+      - name: Coverage
+        run: |
+          coverage combine
+          coverage report
+          codecov
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b970ac5..9e27ab8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
- rev: stable
+ rev: 21.5b0
hooks:
- id: black
language_version: python3.7
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 8d50fbd..0000000
--- a/.travis.yml
+++ /dev/null
@@ -1,51 +0,0 @@
-dist: xenial
-
-sudo: required
-
-language: python
-
-python:
- - "3.6"
- - "3.7"
- - "3.8"
- - "3.9-dev"
-
-before_install:
- - which python
- - which pip
- - pip install -U setuptools
-
-install:
- - pip install --no-cache-dir .
- - pip install -r requirements-dev.txt
- - pip install keyrings.alt>=3.1
-
-script:
- - set -e
- - coverage run --source pgcli -m py.test
- - cd tests
- - behave --no-capture
- - cd ..
- # check for changelog ReST compliance
- - rst2html.py --halt=warning changelog.rst >/dev/null
- # check for black code compliance, 3.6 only
- - if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi
- - set +e
-
-after_success:
- - coverage combine
- - codecov
-
-notifications:
- webhooks:
- urls:
- - YOUR_WEBHOOK_URL
- on_success: change # options: [always|never|change] default: always
- on_failure: always # options: [always|never|change] default: always
- on_start: false # default: false
-
-services:
- - postgresql
-
-addons:
- postgresql: "9.6"
diff --git a/AUTHORS b/AUTHORS
index baaf758..bcfba6a 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -114,6 +114,10 @@ Contributors:
* Tom Caruso (tomplex)
* Jan Brun Rasmussen (janbrunrasmussen)
* Kevin Marsh (kevinmarsh)
+ * Eero Ruohola (ruohola)
+ * Miroslav Šedivý (eumiro)
+ * Eric R Young (ERYoung11)
+ * Paweł Sacawa (psacawa)
Creator:
--------
diff --git a/DEVELOP.rst b/DEVELOP.rst
index 18adf9c..e262823 100644
--- a/DEVELOP.rst
+++ b/DEVELOP.rst
@@ -170,7 +170,7 @@ Troubleshooting the integration tests
- Make sure postgres instance on localhost is running
- Check your ``pg_hba.conf`` file to verify local connections are enabled
- Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information.
-- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_.
+- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_.
Coding Style
------------
diff --git a/README.rst b/README.rst
index d593427..95137f7 100644
--- a/README.rst
+++ b/README.rst
@@ -1,7 +1,7 @@
A REPL for Postgres
-------------------
-|Build Status| |CodeCov| |PyPI| |Landscape| |Gitter|
+|Build Status| |CodeCov| |PyPI| |Landscape|
This is a postgres client that does auto-completion and syntax highlighting.
@@ -62,32 +62,32 @@ For more details:
Usage: pgcli [OPTIONS] [DBNAME] [USERNAME]
Options:
- -h, --host TEXT Host address of the postgres database.
- -p, --port INTEGER Port number at which the postgres instance is
- listening.
- -U, --username TEXT Username to connect to the postgres database.
- -u, --user TEXT Username to connect to the postgres database.
- -W, --password Force password prompt.
- -w, --no-password Never prompt for password.
- --single-connection Do not use a separate connection for completions.
- -v, --version Version of pgcli.
- -d, --dbname TEXT database name to connect to.
- --pgclirc PATH Location of pgclirc file.
- -D, --dsn TEXT Use DSN configured into the [alias_dsn] section of
- pgclirc file.
- --list-dsn list of DSN configured into the [alias_dsn] section
- of pgclirc file.
- --row-limit INTEGER Set threshold for row limit prompt. Use 0 to disable
- prompt.
- --less-chatty Skip intro on startup and goodbye on exit.
- --prompt TEXT Prompt format (Default: "\u@\h:\d> ").
- --prompt-dsn TEXT Prompt format for connections using DSN aliases
- (Default: "\u@\h:\d> ").
- -l, --list list available databases, then exit.
- --auto-vertical-output Automatically switch to vertical output mode if the
- result is wider than the terminal width.
- --warn / --no-warn Warn before running a destructive query.
- --help Show this message and exit.
+ -h, --host TEXT Host address of the postgres database.
+ -p, --port INTEGER Port number at which the postgres instance is
+ listening.
+ -U, --username TEXT Username to connect to the postgres database.
+ -u, --user TEXT Username to connect to the postgres database.
+ -W, --password Force password prompt.
+ -w, --no-password Never prompt for password.
+ --single-connection Do not use a separate connection for completions.
+ -v, --version Version of pgcli.
+ -d, --dbname TEXT database name to connect to.
+ --pgclirc FILE Location of pgclirc file.
+ -D, --dsn TEXT Use DSN configured into the [alias_dsn] section
+ of pgclirc file.
+ --list-dsn list of DSN configured into the [alias_dsn]
+ section of pgclirc file.
+ --row-limit INTEGER Set threshold for row limit prompt. Use 0 to
+ disable prompt.
+ --less-chatty Skip intro on startup and goodbye on exit.
+ --prompt TEXT Prompt format (Default: "\u@\h:\d> ").
+ --prompt-dsn TEXT Prompt format for connections using DSN aliases
+ (Default: "\u@\h:\d> ").
+ -l, --list list available databases, then exit.
+ --auto-vertical-output Automatically switch to vertical output mode if
+ the result is wider than the terminal width.
+ --warn [all|moderate|off] Warn before running a destructive query.
+ --help Show this message and exit.
``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``).
@@ -352,8 +352,8 @@ interface to Postgres database.
Thanks to all the beta testers and contributors for your time and patience. :)
-.. |Build Status| image:: https://api.travis-ci.org/dbcli/pgcli.svg?branch=master
- :target: https://travis-ci.org/dbcli/pgcli
+.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg
+ :target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli
.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
:target: https://codecov.io/gh/dbcli/pgcli
@@ -366,7 +366,3 @@ Thanks to all the beta testers and contributors for your time and patience. :)
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
:target: https://pypi.python.org/pypi/pgcli/
:alt: Latest Version
-
-.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg
- :target: https://gitter.im/dbcli/pgcli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge
- :alt: Gitter Chat
diff --git a/Vagrantfile b/Vagrantfile
index 0313520..297e70a 100644
--- a/Vagrantfile
+++ b/Vagrantfile
@@ -1,5 +1,7 @@
# -*- mode: ruby -*-
# vi: set ft=ruby :
+#
+#
Vagrant.configure(2) do |config|
@@ -9,20 +11,23 @@ Vagrant.configure(2) do |config|
pgcli_description = "Postgres CLI with autocompletion and syntax highlighting"
config.vm.define "debian" do |debian|
- debian.vm.box = "chef/debian-7.8"
+ debian.vm.box = "bento/debian-10.8"
debian.vm.provision "shell", inline: <<-SHELL
- echo "-> Building DEB on `lsb_release -s`"
+ echo "-> Building DEB on `lsb_release -d`"
sudo apt-get update
sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems
- sudo easy_install pip
- sudo pip install virtualenv virtualenv-tools
+ sudo apt install -y python3-pip
+ sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3
+ sudo apt-get install -y ruby-dev
+ sudo apt-get install -y git
+ sudo apt-get install -y rpm librpmbuild8
+
sudo gem install fpm
+
echo "-> Cleaning up old workspace"
- rm -rf build
+ sudo rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
- build/usr/share/pgcli/bin/pip install -U pip distribute
- build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
@@ -45,24 +50,59 @@ Vagrant.configure(2) do |config|
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
+
SHELL
end
+
+# This is considerably messier than the Debian section. I had to go off-standard to
+# update some packages to get this to work.
+
config.vm.define "centos" do |centos|
- centos.vm.box = "chef/centos-7.0"
+
+ centos.vm.box = "bento/centos-7.9"
+ centos.vm.box_version = "202012.21.0"
centos.vm.provision "shell", inline: <<-SHELL
#!/bin/bash
- echo "-> Building RPM on `lsb_release -s`"
- sudo yum install -y rpm-build gcc ruby-devel postgresql-devel python-devel rubygems
- sudo easy_install pip
- sudo pip install virtualenv virtualenv-tools
- sudo gem install fpm
+ echo "-> Building RPM on `hostnamectl | grep "Operating System"`"
+ export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin
+ echo "PATH -> " $PATH
+
+#####
+### get base updates
+
+ sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel
+
+######
+### install FPM; it needs an up-to-date version of ruby, which we also need for git
+
+ echo "-> Get FPM installed"
+ # import the necessary GPG keys
+ gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
+ sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB
+ # install RVM
+ sudo curl -sSL https://get.rvm.io | sudo bash -s stable
+ sudo usermod -aG rvm vagrant
+ sudo usermod -aG rvm root
+ sudo /usr/local/rvm/bin/rvm alias create default 2.6.3
+ source /etc/profile.d/rvm.sh
+
+ # install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git.
+ sudo yum install -y ruby-devel
+ sudo /usr/local/rvm/bin/rvm install 2.6.3
+
+ #
+  # yes, this gives an error about generating docs, but we don't need them.
+
+ /usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm
+
+######
+
+ sudo pip3 install virtualenv virtualenv-tools3
echo "-> Cleaning up old workspace"
rm -rf build
mkdir -p build/usr/share
virtualenv build/usr/share/pgcli
- build/usr/share/pgcli/bin/pip install -U pip distribute
- build/usr/share/pgcli/bin/pip uninstall -y distribute
build/usr/share/pgcli/bin/pip install /pgcli
echo "-> Cleaning Virtualenv"
@@ -74,9 +114,9 @@ Vagrant.configure(2) do |config|
find build -iname '*.pyc' -delete
find build -iname '*.pyo' -delete
+ cd /home/vagrant
echo "-> Creating PgCLI RPM"
- echo $PATH
- sudo /usr/local/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
+ /usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \
-a all \
-d postgresql-devel \
-d python-devel \
@@ -86,8 +126,13 @@ Vagrant.configure(2) do |config|
--url https://github.com/dbcli/pgcli \
--description "#{pgcli_description}" \
--license 'BSD'
- SHELL
+
+
+ SHELL
+
+
end
+
end
diff --git a/changelog.rst b/changelog.rst
index ec50635..d732088 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -1,3 +1,37 @@
+TBD
+=====
+
+Features:
+---------
+
+Bug fixes:
+----------
+
+3.2.0
+=====
+
+Release date: 2021/08/23
+
+Features:
+---------
+
+* Consider `update` queries destructive and issue a warning. Change the
+  `destructive_warning` setting to `all|moderate|off` instead of `true|false`. (#1239)
+* Skip the initial comment in .pg_service.conf even if it doesn't start with '#'
+* Include functions from schemas in search_path. (`Amjith Ramanujam`_)
+
+Bug fixes:
+----------
+
+* Fix issue where `syntax_style` config value would not have any effect. (#1212)
+* Fix crash caused by the missing `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6
+* Fix comments being lost in config when saving a named query. (#1240)
+* Fix IPython magic for ipython-sql >= 0.4.0
+* Fix pager not being used when output format is set to csv. (#1238)
+* Add function literals random, generate_series, generate_subscripts
+* Fix ANSI escape codes in the first line causing the cli to incorrectly choose expanded output
+* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
+
3.1.0
=====
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
index f5f41e5..1173108 100644
--- a/pgcli/__init__.py
+++ b/pgcli/__init__.py
@@ -1 +1 @@
-__version__ = "3.1.0"
+__version__ = "3.2.0"
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
index cf0879f..1039d51 100644
--- a/pgcli/completion_refresher.py
+++ b/pgcli/completion_refresher.py
@@ -3,10 +3,9 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
-from .pgexecute import PGExecute
-class CompletionRefresher(object):
+class CompletionRefresher:
refreshers = OrderedDict()
@@ -27,6 +26,10 @@ class CompletionRefresher(object):
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
+ if executor.is_virtual_database():
+ # do nothing
+ return [(None, None, None, "Auto-completion refresh can't be started.")]
+
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
@@ -141,7 +144,7 @@ def refresh_casing(completer, executor):
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
- with open(casing_file, "r") as f:
+ with open(casing_file) as f:
completer.extend_casing([line.strip() for line in f])
diff --git a/pgcli/config.py b/pgcli/config.py
index 0fc42dd..22f08dc 100644
--- a/pgcli/config.py
+++ b/pgcli/config.py
@@ -3,6 +3,8 @@ import shutil
import os
import platform
from os.path import expanduser, exists, dirname
+import re
+from typing import TextIO
from configobj import ConfigObj
@@ -16,11 +18,15 @@ def config_location():
def load_config(usr_cfg, def_cfg=None):
- cfg = ConfigObj()
- cfg.merge(ConfigObj(def_cfg, interpolation=False))
- cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+    # avoid config merges when possible. For writing, we need an unmerged config instance.
+ # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
+ if def_cfg:
+ cfg = ConfigObj()
+ cfg.merge(ConfigObj(def_cfg, interpolation=False))
+ cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ else:
+ cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
cfg.filename = expanduser(usr_cfg)
-
return cfg
@@ -44,12 +50,16 @@ def upgrade_config(config, def_config):
cfg.write()
+def get_config_filename(pgclirc_file=None):
+ return pgclirc_file or "%sconfig" % config_location()
+
+
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
- pgclirc_file = pgclirc_file or "%sconfig" % config_location()
+ pgclirc_file = get_config_filename(pgclirc_file)
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)
@@ -62,3 +72,28 @@ def get_casing_file(config):
if casing_file == "default":
casing_file = config_location() + "casing"
return casing_file
+
+
+def skip_initial_comment(f_stream: TextIO) -> int:
+ """
+ Initial comment in ~/.pg_service.conf is not always marked with '#'
+ which crashes the parser. This function takes a file object and
+ "rewinds" it to the beginning of the first section,
+ from where on it can be parsed safely
+
+ :return: number of skipped lines
+ """
+ section_regex = r"\s*\["
+ pos = f_stream.tell()
+ lines_skipped = 0
+ while True:
+ line = f_stream.readline()
+ if line == "":
+ break
+ if re.match(section_regex, line) is not None:
+ f_stream.seek(pos)
+ break
+ else:
+ pos += len(line)
+ lines_skipped += 1
+ return lines_skipped
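
A minimal usage sketch (not part of the diff) of the new skip_initial_comment() helper: it advances a text stream past an unmarked leading comment and leaves it positioned at the first "[section]" line, so ConfigObj can parse the rest of ~/.pg_service.conf.

    import io
    from pgcli.config import skip_initial_comment

    sample = io.StringIO(
        "written by some other tool, note the missing '#'\n"
        "[mydb]\n"
        "host=localhost\n"
    )
    skipped = skip_initial_comment(sample)  # returns 1
    print(skipped, sample.readline())       # -> 1 [mydb]
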
diff --git a/pgcli/magic.py b/pgcli/magic.py
index f58f415..6e58f28 100644
--- a/pgcli/magic.py
+++ b/pgcli/magic.py
@@ -25,7 +25,11 @@ def pgcli_line_magic(line):
if hasattr(sql.connection.Connection, "get"):
conn = sql.connection.Connection.get(parsed["connection"])
else:
- conn = sql.connection.Connection.set(parsed["connection"])
+ try:
+ conn = sql.connection.Connection.set(parsed["connection"])
+ # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql
+ except TypeError:
+ conn = sql.connection.Connection.set(parsed["connection"], False)
try:
# A corresponding pgcli object already exists
@@ -43,7 +47,7 @@ def pgcli_line_magic(line):
conn._pgcli = pgcli
# For convenience, print the connection alias
- print("Connected: {}".format(conn.name))
+ print(f"Connected: {conn.name}")
try:
pgcli.run_cli()
diff --git a/pgcli/main.py b/pgcli/main.py
index b146898..5135f6f 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -2,8 +2,9 @@ import platform
import warnings
from os.path import expanduser
-from configobj import ConfigObj
+from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
+from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
@@ -20,12 +21,12 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
-from codecs import open
keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+from cli_helpers.utils import strip_ansi
import click
try:
@@ -62,6 +63,7 @@ from .config import (
config_location,
ensure_dir_exists,
get_config,
+ get_config_filename,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
@@ -122,7 +124,7 @@ class PgCliQuitError(Exception):
pass
-class PGCli(object):
+class PGCli:
default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30
@@ -175,7 +177,11 @@ class PGCli(object):
# Load config.
c = self.config = get_config(pgclirc_file)
- NamedQueries.instance = NamedQueries.from_config(self.config)
+ # at this point, config should be written to pgclirc_file if it did not exist. Read it.
+ self.config_writer = load_config(get_config_filename(pgclirc_file))
+
+ # make sure to use self.config_writer, not self.config
+ NamedQueries.instance = NamedQueries.from_config(self.config_writer)
self.logger = logging.getLogger(__name__)
self.initialize_logging()
@@ -201,8 +207,11 @@ class PGCli(object):
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
- c_dest_warning = c["main"].as_bool("destructive_warning")
- self.destructive_warning = c_dest_warning if warn is None else warn
+ self.destructive_warning = warn or c["main"]["destructive_warning"]
+ # also handle boolean format of destructive warning
+ self.destructive_warning = {"true": "all", "false": "off"}.get(
+ self.destructive_warning.lower(), self.destructive_warning
+ )
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = (
@@ -325,11 +334,11 @@ class PGCli(object):
if pattern not in TabularOutputFormatter().supported_formats:
raise ValueError()
self.table_format = pattern
- yield (None, None, None, "Changed table format to {}".format(pattern))
+ yield (None, None, None, f"Changed table format to {pattern}")
except ValueError:
- msg = "Table format {} not recognized. Allowed formats:".format(pattern)
+ msg = f"Table format {pattern} not recognized. Allowed formats:"
for table_type in TabularOutputFormatter().supported_formats:
- msg += "\n\t{}".format(table_type)
+ msg += f"\n\t{table_type}"
msg += "\nCurrently set to: %s" % self.table_format
yield (None, None, None, msg)
@@ -386,10 +395,13 @@ class PGCli(object):
try:
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
query = f.read()
- except IOError as e:
+ except OSError as e:
return [(None, None, None, str(e), "", False, True)]
- if self.destructive_warning and confirm_destructive_query(query) is False:
+ if (
+ self.destructive_warning != "off"
+ and confirm_destructive_query(query, self.destructive_warning) is False
+ ):
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
@@ -407,7 +419,7 @@ class PGCli(object):
if not os.path.isfile(filename):
try:
open(filename, "w").close()
- except IOError as e:
+ except OSError as e:
self.output_file = None
message = str(e) + "\nFile output disabled"
return [(None, None, None, message, "", False, True)]
@@ -479,7 +491,7 @@ class PGCli(object):
service_config, file = parse_service_info(service)
if service_config is None:
click.secho(
- "service '%s' was not found in %s" % (service, file), err=True, fg="red"
+ f"service '{service}' was not found in {file}", err=True, fg="red"
)
exit(1)
self.connect(
@@ -515,7 +527,7 @@ class PGCli(object):
passwd = os.environ.get("PGPASSWORD", "")
# Find password from store
- key = "%s@%s" % (user, host)
+ key = f"{user}@{host}"
keyring_error_message = dedent(
"""\
{}
@@ -644,8 +656,10 @@ class PGCli(object):
query = MetaQuery(query=text, successful=False)
try:
- if self.destructive_warning:
- destroy = confirm = confirm_destructive_query(text)
+ if self.destructive_warning != "off":
+ destroy = confirm = confirm_destructive_query(
+ text, self.destructive_warning
+ )
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
@@ -677,7 +691,7 @@ class PGCli(object):
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
- except IOError as e:
+ except OSError as e:
click.secho(str(e), err=True, fg="red")
else:
if output:
@@ -729,7 +743,6 @@ class PGCli(object):
if not self.less_chatty:
print("Server: PostgreSQL", self.pgexecute.server_version)
print("Version:", __version__)
- print("Chat: https://gitter.im/dbcli/pgcli")
print("Home: http://pgcli.com")
try:
@@ -753,11 +766,7 @@ class PGCli(object):
while self.watch_command:
try:
query = self.execute_command(self.watch_command)
- click.echo(
- "Waiting for {0} seconds before repeating".format(
- timing
- )
- )
+ click.echo(f"Waiting for {timing} seconds before repeating")
sleep(timing)
except KeyboardInterrupt:
self.watch_command = None
@@ -979,16 +988,13 @@ class PGCli(object):
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
- self.completion_refresher.refresh(
+ return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
- return [
- (None, None, None, "Auto-completion refresh started in the background.")
- ]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
@@ -1049,7 +1055,7 @@ class PGCli(object):
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
)
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
- string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
+ string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
string = string.replace("\\n", "\n")
return string
@@ -1075,9 +1081,10 @@ class PGCli(object):
def echo_via_pager(self, text, color=None):
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
click.echo(text, color=color)
- elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
- click.echo_via_pager(text, color)
- elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
+ elif (
+ self.pgspecial.pager_config == PAGER_LONG_OUTPUT
+ and self.table_format != "csv"
+ ):
lines = text.split("\n")
# The last 4 lines are reserved for the pgcli menu and padding
@@ -1192,7 +1199,10 @@ class PGCli(object):
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
- "--warn/--no-warn", default=None, help="Warn before running a destructive query."
+ "--warn",
+ default=None,
+ type=click.Choice(["all", "moderate", "off"]),
+ help="Warn before running a destructive query.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
@@ -1384,7 +1394,7 @@ def is_mutating(status):
if not status:
return False
- mutating = set(["insert", "update", "delete"])
+ mutating = {"insert", "update", "delete"}
return status.split(None, 1)[0].lower() in mutating
@@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings):
formatted = iter(formatted.splitlines())
first_line = next(formatted)
formatted = itertools.chain([first_line], formatted)
- if not expanded and max_width and len(first_line) > max_width and headers:
+ if (
+ not expanded
+ and max_width
+ and len(strip_ansi(first_line)) > max_width
+ and headers
+ ):
formatted = formatter.format_output(
cur, headers, format_name="vertical", column_types=None, **output_kwargs
)
@@ -1502,10 +1517,16 @@ def parse_service_info(service):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
service_file = expanduser("~/.pg_service.conf")
- if not service:
+ if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
- service_file_config = ConfigObj(service_file)
+ with open(service_file, newline="") as f:
+ skipped_lines = skip_initial_comment(f)
+ try:
+ service_file_config = ConfigObj(f)
+ except ParseError as err:
+ err.line_number += skipped_lines
+ raise err
if service not in service_file_config:
return None, service_file
service_conf = service_file_config.get(service)
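
An illustrative sketch (not part of the diff) of why format_output above now measures the first line with strip_ansi(): ANSI color codes inflate len(), which previously made --auto-vertical-output switch to expanded mode even when the visible row fit the terminal.

    from cli_helpers.utils import strip_ansi

    first_line = "\x1b[31mid\x1b[0m | name"
    len(first_line)              # 18, escape sequences included
    len(strip_ansi(first_line))  # 9, the visible width that is compared to max_width
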
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
index a11e7bf..1acc008 100644
--- a/pgcli/packages/parseutils/__init__.py
+++ b/pgcli/packages/parseutils/__init__.py
@@ -1,22 +1,34 @@
import sqlparse
-def query_starts_with(query, prefixes):
+def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
- formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
-def queries_start_with(queries, prefixes):
- """Check if any queries start with any item from *prefixes*."""
- for query in sqlparse.split(queries):
- if query and query_starts_with(query, prefixes) is True:
- return True
- return False
+def query_is_unconditional_update(formatted_sql):
+ """Check if the query starts with UPDATE and contains no WHERE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update" and "where" not in tokens
+
+def query_is_simple_update(formatted_sql):
+ """Check if the query starts with UPDATE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update"
-def is_destructive(queries):
+
+def is_destructive(queries, warning_level="all"):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
- return queries_start_with(queries, keywords)
+ for query in sqlparse.split(queries):
+ if query:
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+ if query_starts_with(formatted_sql, keywords):
+ return True
+ if query_is_unconditional_update(formatted_sql):
+ return True
+ if warning_level == "all" and query_is_simple_update(formatted_sql):
+ return True
+ return False
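
A short sketch (not part of the diff) of how the warning levels classify queries with the is_destructive() defined above:

    from pgcli.packages.parseutils import is_destructive

    is_destructive("update t set x = 1 where id = 2", warning_level="all")       # True: any UPDATE warns
    is_destructive("update t set x = 1 where id = 2", warning_level="moderate")  # False: has a WHERE clause
    is_destructive("update t set x = 1", warning_level="moderate")               # True: unconditional UPDATE
    is_destructive("drop table t", warning_level="moderate")                     # True: keyword match at any level
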
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
index 108c01a..333cab5 100644
--- a/pgcli/packages/parseutils/meta.py
+++ b/pgcli/packages/parseutils/meta.py
@@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
yield current
-class FunctionMetadata(object):
+class FunctionMetadata:
def __init__(
self,
schema_name,
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
index 0ec3e69..aaa676c 100644
--- a/pgcli/packages/parseutils/tables.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
- for x in extract_from_part(item, stop_at_punctuation):
- yield x
+ yield from extract_from_part(item, stop_at_punctuation)
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a
diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json
index c7b74b5..df00817 100644
--- a/pgcli/packages/pgliterals/pgliterals.json
+++ b/pgcli/packages/pgliterals/pgliterals.json
@@ -392,6 +392,7 @@
"QUOTE_NULLABLE",
"RADIANS",
"RADIUS",
+ "RANDOM",
"RANK",
"REGEXP_MATCH",
"REGEXP_MATCHES",
diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py
index e92dcbb..f5a9cb5 100644
--- a/pgcli/packages/prioritization.py
+++ b/pgcli/packages/prioritization.py
@@ -16,10 +16,10 @@ def _compile_regex(keyword):
keywords = get_literals("keywords")
-keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
+keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
-class PrevalenceCounter(object):
+class PrevalenceCounter:
def __init__(self):
self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int)
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
index 3c58490..e8589de 100644
--- a/pgcli/packages/prompt_utils.py
+++ b/pgcli/packages/prompt_utils.py
@@ -3,7 +3,7 @@ import click
from .parseutils import is_destructive
-def confirm_destructive_query(queries):
+def confirm_destructive_query(queries, warning_level):
"""Check if the query is destructive and prompts the user to confirm.
Returns:
@@ -15,7 +15,7 @@ def confirm_destructive_query(queries):
prompt_text = (
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
)
- if is_destructive(queries) and sys.stdin.isatty():
+ if is_destructive(queries, warning_level) and sys.stdin.isatty():
return prompt(prompt_text, type=bool)
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 6ef8859..6305301 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
-class SqlStatement(object):
+class SqlStatement:
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
diff --git a/pgcli/pgclirc b/pgcli/pgclirc
index e97afda..15c10f5 100644
--- a/pgcli/pgclirc
+++ b/pgcli/pgclirc
@@ -23,9 +23,13 @@ multi_line = False
multi_line_mode = psql
# Destructive warning mode will alert you before executing a sql statement
-# that may cause harm to the database such as "drop table", "drop database"
-# or "shutdown".
-destructive_warning = True
+# that may cause harm to the database such as "drop table", "drop database",
+# "shutdown", "delete", or "update".
+# Possible values:
+# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE
+# "moderate" - skip warning on UPDATE statements, except for unconditional updates
+# "off" - skip all warnings
+destructive_warning = all
# Enables expand mode, which is similar to `\x` in psql.
expand = False
@@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold'
arg-toolbar.text = 'nobold'
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
-literal.string = '#ba2121'
-literal.number = '#666666'
-keyword = 'bold #008000'
+# These three values can be used to further refine the syntax highlighting.
+# They are commented out by default, since they have priority over the theme set
+# with the `syntax_style` setting and overriding its behavior can be confusing.
+# literal.string = '#ba2121'
+# literal.number = '#666666'
+# keyword = 'bold #008000'
# style classes for colored table output
output.header = "#00ff5f bold"
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 9c95a01..227e25c 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -83,7 +83,7 @@ class PGCompleter(Completer):
reserved_words = set(get_literals("reserved"))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
- super(PGCompleter, self).__init__()
+ super().__init__()
self.smart_completion = smart_completion
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
@@ -140,7 +140,7 @@ class PGCompleter(Completer):
return "'{}'".format(self.unescape_name(name))
def unescape_name(self, name):
- """ Unquote a string."""
+ """Unquote a string."""
if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
@@ -177,7 +177,7 @@ class PGCompleter(Completer):
:return:
"""
# casing should be a dict {lowercasename:PreferredCasingName}
- self.casing = dict((word.lower(), word) for word in words)
+ self.casing = {word.lower(): word for word in words}
def extend_relations(self, data, kind):
"""extend metadata for tables or views.
@@ -279,8 +279,8 @@ class PGCompleter(Completer):
fk = ForeignKey(
parentschema, parenttable, parcol, childschema, childtable, childcol
)
- childcolmeta.foreignkeys.append((fk))
- parcolmeta.foreignkeys.append((fk))
+ childcolmeta.foreignkeys.append(fk)
+ parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data):
@@ -424,7 +424,7 @@ class PGCompleter(Completer):
# the same priority as unquoted names.
lexical_priority = (
tuple(
- 0 if c in (" _") else -ord(c)
+ 0 if c in " _" else -ord(c)
for c in self.unescape_name(item.lower())
)
+ (1,)
@@ -517,9 +517,9 @@ class PGCompleter(Completer):
# require_last_table is used for 'tbl1 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref
- other_tbl_cols = set(
+ other_tbl_cols = {
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
- )
+ }
scoped_cols = {
t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items()
@@ -574,7 +574,7 @@ class PGCompleter(Completer):
tbls - TableReference iterable of tables already in query
"""
tbl = self.case(tbl)
- tbls = set(normalize_ref(t.ref) for t in tbls)
+ tbls = {normalize_ref(t.ref) for t in tbls}
if self.generate_aliases:
tbl = generate_alias(self.unescape_name(tbl))
if normalize_ref(tbl) not in tbls:
@@ -589,10 +589,10 @@ class PGCompleter(Completer):
tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
- qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
- ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
- refs = set(normalize_ref(t.ref) for t in tbls)
- other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
+ qualified = {normalize_ref(t.ref): t.schema for t in tbls}
+ ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
+ refs = {normalize_ref(t.ref) for t in tbls}
+ other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = (
@@ -667,7 +667,7 @@ class PGCompleter(Completer):
return d
# Tables that are closer to the cursor get higher prio
- ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
+ ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
@@ -703,7 +703,11 @@ class PGCompleter(Completer):
not f.is_aggregate
and not f.is_window
and not f.is_extension
- and (f.is_public or f.schema_name == suggestion.schema)
+ and (
+ f.is_public
+ or f.schema_name in self.search_path
+ or f.schema_name == suggestion.schema
+ )
)
else:
@@ -721,9 +725,7 @@ class PGCompleter(Completer):
# Function overloading means we may have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
- funcs = set(
- self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
- )
+ funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
matches = self.find_matches(word_before_cursor, funcs, meta="function")
@@ -953,7 +955,7 @@ class PGCompleter(Completer):
:return: {TableReference:{colname:ColumnMetaData}}
"""
- ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
+ ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
columns = OrderedDict()
meta = self.dbmetadata
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index d34bf26..a013b55 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,13 +1,15 @@
-import traceback
import logging
+import select
+import traceback
+
+import pgspecial as special
import psycopg2
-import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
+import psycopg2.extras
import sqlparse
-import pgspecial as special
-import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
+
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
+_wait_callback_is_set = False
def _wait_select(conn):
@@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except select.error as e:
- errno = e.args[0]
- if errno != 4:
- raise
+ try:
+ while 1:
+ try:
+ state = conn.poll()
+ if state == POLL_OK:
+ break
+ elif state == POLL_READ:
+ select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
+ elif state == POLL_WRITE:
+ select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
+ else:
+ raise conn.OperationalError("bad state from poll: %s" % state)
+ except KeyboardInterrupt:
+ conn.cancel()
+ # the loop will be broken by a server error
+ continue
+ except OSError as e:
+ errno = e.args[0]
+ if errno != 4:
+ raise
+ except psycopg2.OperationalError:
+ pass
-# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
-# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
-# See also https://github.com/psycopg/psycopg2/issues/468
-ext.set_wait_callback(_wait_select)
+def _set_wait_callback(is_virtual_database):
+ global _wait_callback_is_set
+ if _wait_callback_is_set:
+ return
+ _wait_callback_is_set = True
+ if is_virtual_database:
+ return
+ # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
+ # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
+ # See also https://github.com/psycopg/psycopg2/issues/468
+ ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
+ if cursor.description is None:
+ return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
- except psycopg2.ProgrammingError:
+ except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@@ -127,7 +142,39 @@ def register_hstore_typecaster(conn):
pass
-class PGExecute(object):
+class ProtocolSafeCursor(psycopg2.extensions.cursor):
+ def __init__(self, *args, **kwargs):
+ self.protocol_error = False
+ self.protocol_message = ""
+ super().__init__(*args, **kwargs)
+
+ def __iter__(self):
+ if self.protocol_error:
+ raise StopIteration
+ return super().__iter__()
+
+ def fetchall(self):
+ if self.protocol_error:
+ return [(self.protocol_message,)]
+ return super().fetchall()
+
+ def fetchone(self):
+ if self.protocol_error:
+ return (self.protocol_message,)
+ return super().fetchone()
+
+ def execute(self, sql, args=None):
+ try:
+ psycopg2.extensions.cursor.execute(self, sql, args)
+ self.protocol_error = False
+ self.protocol_message = ""
+ except psycopg2.errors.ProtocolViolation as ex:
+ self.protocol_error = True
+ self.protocol_message = ex.pgerror
+ _logger.debug("%s: %s" % (ex.__class__.__name__, ex))
+
+
+class PGExecute:
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
@@ -190,8 +237,6 @@ class PGExecute(object):
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
- version_query = "SELECT version();"
-
def __init__(
self,
database=None,
@@ -203,6 +248,7 @@ class PGExecute(object):
**kwargs,
):
self._conn_params = {}
+ self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@@ -214,6 +260,11 @@ class PGExecute(object):
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
+ def is_virtual_database(self):
+ if self._is_virtual_database is None:
+ self._is_virtual_database = self.is_protocol_error()
+ return self._is_virtual_database
+
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@@ -250,9 +301,9 @@ class PGExecute(object):
)
conn_params.update({k: v for k, v in new_params.items() if v})
+ conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
- cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@@ -293,16 +344,22 @@ class PGExecute(object):
self.extra_args = kwargs
if not self.host:
- self.host = self.get_socket_directory()
+ self.host = (
+ "pgbouncer"
+ if self.is_virtual_database()
+ else self.get_socket_directory()
+ )
- pid = self._select_one(cursor, "select pg_backend_pid()")[0]
- self.pid = pid
+ self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version")
+ self.server_version = conn.get_parameter_status("server_version") or ""
+
+ _set_wait_callback(self.is_virtual_database())
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ if not self.is_virtual_database():
+ register_date_typecasters(conn)
+ register_json_typecasters(self.conn, self._json_typecaster)
+ register_hstore_typecaster(self.conn)
@property
def short_host(self):
@@ -395,7 +452,13 @@ class PGExecute(object):
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
- for result in pgspecial.execute(cur, sql):
+ response = pgspecial.execute(cur, sql)
+ if cur and cur.protocol_error:
+ yield None, None, None, cur.protocol_message, statement, False, False
+ # this would close connection. We should reconnect.
+ self.connect()
+ continue
+ for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@@ -453,6 +516,9 @@ class PGExecute(object):
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
+ elif cur.protocol_error:
+ _logger.debug("Protocol error, unsupported command.")
+ return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@@ -485,7 +551,7 @@ class PGExecute(object):
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
- raise RuntimeError("View {} does not exist.".format(spec))
+ raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
@@ -501,7 +567,7 @@ class PGExecute(object):
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
- raise RuntimeError("Function {} does not exist.".format(spec))
+ raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
"""Returns a list of schema names in the database"""
@@ -527,21 +593,18 @@ class PGExecute(object):
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
- for row in cur:
- yield row
+ yield from cur
def tables(self):
"""Yields (schema_name, table_name) tuples"""
- for row in self._relations(kinds=["r", "p", "f"]):
- yield row
+ yield from self._relations(kinds=["r", "p", "f"])
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and materialized views
"""
- for row in self._relations(kinds=["v", "m"]):
- yield row
+ yield from self._relations(kinds=["v", "m"])
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
@@ -599,16 +662,13 @@ class PGExecute(object):
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
- for row in cur:
- yield row
+ yield from cur
def table_columns(self):
- for row in self._columns(kinds=["r", "p", "f"]):
- yield row
+ yield from self._columns(kinds=["r", "p", "f"])
def view_columns(self):
- for row in self._columns(kinds=["v", "m"]):
- yield row
+ yield from self._columns(kinds=["v", "m"])
def databases(self):
with self.conn.cursor() as cur:
@@ -623,6 +683,13 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
+ def is_protocol_error(self):
+ query = "SELECT 1"
+ with self.conn.cursor() as cur:
+ _logger.debug("Simple Query. sql: %r", query)
+ cur.execute(query)
+ return bool(cur.protocol_error)
+
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
@@ -804,8 +871,7 @@ class PGExecute(object):
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
- for row in cur:
- yield row
+ yield from cur
def casing(self):
"""Yields the most common casing for names used in db functions"""
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
index f4289a1..41f903d 100644
--- a/pgcli/pgtoolbar.py
+++ b/pgcli/pgtoolbar.py
@@ -1,15 +1,23 @@
+from pkg_resources import packaging
+
+import prompt_toolkit
from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.application import get_app
+parse_version = packaging.version.parse
+
+vi_modes = {
+ InputMode.INSERT: "I",
+ InputMode.NAVIGATION: "N",
+ InputMode.REPLACE: "R",
+ InputMode.INSERT_MULTIPLE: "M",
+}
+if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"):
+ vi_modes[InputMode.REPLACE_SINGLE] = "R"
+
def _get_vi_mode():
- return {
- InputMode.INSERT: "I",
- InputMode.NAVIGATION: "N",
- InputMode.REPLACE: "R",
- InputMode.REPLACE_SINGLE: "R",
- InputMode.INSERT_MULTIPLE: "M",
- }[get_app().vi_state.input_mode]
+ return vi_modes[get_app().vi_state.input_mode]
def create_toolbar_tokens_func(pgcli):
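
Illustrative of the version gate above (not part of the diff): packaging's version parse compares release numbers rather than strings, so REPLACE_SINGLE (added in prompt_toolkit 3.0.6) is only mapped when the installed version actually defines it.

    from pkg_resources import packaging

    parse_version = packaging.version.parse
    parse_version("3.0.5") >= parse_version("3.0.6")   # False -> key left out
    parse_version("3.0.19") >= parse_version("3.0.6")  # True, even though "3.0.19" < "3.0.6" as plain strings
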
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 80e8f43..84fa6bf 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,4 @@
pytest>=2.7.0
-mock>=1.0.1
tox>=1.9.2
behave>=1.2.4
pexpect==3.3
diff --git a/tests/conftest.py b/tests/conftest.py
index 2a715b1..33cddf2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -12,7 +12,7 @@ from utils import (
import pgcli.pgexecute
-@pytest.yield_fixture(scope="function")
+@pytest.fixture(scope="function")
def connection():
create_db("_test_db")
connection = db_connection("_test_db")
diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py
index f57bc3b..6898394 100644
--- a/tests/features/db_utils.py
+++ b/tests/features/db_utils.py
@@ -44,7 +44,7 @@ def create_cn(hostname, password, username, dbname, port):
host=hostname, user=username, database=dbname, password=password, port=port
)
- print("Created connection: {0}.".format(cn.dsn))
+ print(f"Created connection: {cn.dsn}.")
return cn
@@ -75,4 +75,4 @@ def close_cn(cn=None):
"""
if cn:
cn.close()
- print("Closed connection: {0}.".format(cn.dsn))
+ print(f"Closed connection: {cn.dsn}.")
diff --git a/tests/features/environment.py b/tests/features/environment.py
index 049c2f2..215c85c 100644
--- a/tests/features/environment.py
+++ b/tests/features/environment.py
@@ -38,7 +38,7 @@ def before_all(context):
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
- db_name_full = "{0}_{1}".format(db_name, vi)
+ db_name_full = f"{db_name}_{vi}"
# Store get params from config.
context.conf = {
@@ -63,7 +63,7 @@ def before_all(context):
"import coverage",
"coverage.process_startup()",
"import pgcli.main",
- "pgcli.main.cli()",
+ "pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
]
),
)
@@ -102,6 +102,7 @@ def before_all(context):
else:
if "PGPASSWORD" in os.environ:
del os.environ["PGPASSWORD"]
+ os.environ["BEHAVE_WARN"] = "moderate"
context.cn = dbutils.create_db(
context.conf["host"],
@@ -122,12 +123,12 @@ def before_all(context):
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print("--- os.environ changed values: ---")
- all_keys = set(list(env_old.keys()) + list(env_new.keys()))
+ all_keys = env_old.keys() | env_new.keys()
for k in sorted(all_keys):
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
- print('{}="{}"'.format(k, new_value))
+ print(f'{k}="{new_value}"')
print("-" * 20)
@@ -173,13 +174,13 @@ def after_scenario(context, scenario):
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
- context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
+ context.cli.expect_exact(f"{dbname}> ", timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
- print("--- after_scenario {}: kill cli".format(scenario.name))
+ print(f"--- after_scenario {scenario.name}: kill cli")
context.cli.kill(signal.SIGKILL)
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()
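
A small sketch (simplified, not part of the diff) of the click mechanism these two hunks rely on: with auto_envvar_prefix="BEHAVE", the BEHAVE_WARN environment variable set in before_all() supplies the value of pgcli's --warn option for the behave runs.

    import click

    @click.command()
    @click.option("--warn", type=click.Choice(["all", "moderate", "off"]), default=None)
    def cli(warn):
        click.echo(f"warn={warn}")

    if __name__ == "__main__":
        # BEHAVE_WARN=moderate python example.py  ->  warn=moderate
        cli(auto_envvar_prefix="BEHAVE")
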
diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py
index 16f123a..70b603d 100644
--- a/tests/features/fixture_utils.py
+++ b/tests/features/fixture_utils.py
@@ -18,7 +18,7 @@ def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, "fixture_data/")
- print("reading fixture data: {}".format(fixture_dir))
+ print(f"reading fixture data: {fixture_dir}")
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in [".", ".."]:
diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py
index 1069662..07e9ec1 100644
--- a/tests/features/steps/basic_commands.py
+++ b/tests/features/steps/basic_commands.py
@@ -65,19 +65,20 @@ def step_ctrl_d(context):
Send Ctrl + D to hopefully exit.
"""
# turn off pager before exiting
- context.cli.sendline("\pset pager off")
+ context.cli.sendcontrol("c")
+ context.cli.sendline(r"\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=15)
context.exit_sent = True
-@when('we send "\?" command')
+@when(r'we send "\?" command')
def step_send_help(context):
- """
+ r"""
Send \? to see help.
"""
- context.cli.sendline("\?")
+ context.cli.sendline(r"\?")
@when("we send partial select command")
@@ -96,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
- context.tmpfile_sql_help.write(b"\?")
+ context.tmpfile_sql_help.write(br"\?")
context.tmpfile_sql_help.flush()
- context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
+ context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py
index 3fd8b7a..3f5d0e7 100644
--- a/tests/features/steps/crud_database.py
+++ b/tests/features/steps/crud_database.py
@@ -14,7 +14,7 @@ def step_db_create(context):
"""
Send create database.
"""
- context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
+ context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
context.response = {"database_name": context.conf["dbname_tmp"]}
@@ -24,7 +24,7 @@ def step_db_drop(context):
"""
Send drop database.
"""
- context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
+ context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
@when("we connect to test database")
@@ -33,7 +33,7 @@ def step_db_connect_test(context):
Send connect to database.
"""
db_name = context.conf["dbname"]
- context.cli.sendline("\\connect {0}".format(db_name))
+ context.cli.sendline(f"\\connect {db_name}")
@when("we connect to dbserver")
@@ -59,7 +59,7 @@ def step_see_prompt(context):
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf["dbname"])
- wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
+ wrappers.expect_exact(context, f"{db_name}> ", timeout=5)
context.atprompt = True
diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py
index f34fcf0..265ea39 100644
--- a/tests/features/steps/expanded.py
+++ b/tests/features/steps/expanded.py
@@ -31,7 +31,7 @@ def step_prepare_data(context):
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
- context.cli.sendline("\\" + "x {}".format(mode))
+ context.cli.sendline("\\" + f"x {mode}")
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)
diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py
index 613aeb2..a614490 100644
--- a/tests/features/steps/iocommands.py
+++ b/tests/features/steps/iocommands.py
@@ -13,7 +13,7 @@ def step_edit_file(context):
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
- context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
+ context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
@@ -53,7 +53,7 @@ def step_tee_ouptut(context):
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
- context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
+ context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@@ -67,7 +67,7 @@ def step_query_select_123456(context):
@when("we stop teeing output")
def step_notee_output(context):
- context.cli.sendline("\o")
+ context.cli.sendline(r"\o")
wrappers.expect_exact(context, "Time", timeout=5)
diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py
index 813292c..a85f371 100644
--- a/tests/features/steps/specials.py
+++ b/tests/features/steps/specials.py
@@ -22,5 +22,10 @@ def step_see_refresh_started(context):
Wait to see refresh output.
"""
wrappers.expect_pager(
- context, "Auto-completion refresh started in the background.\r\n", timeout=2
+ context,
+ [
+ "Auto-completion refresh started in the background.\r\n",
+ "Auto-completion refresh restarted.\r\n",
+ ],
+ timeout=2,
)
diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py
index e0f5a20..0ca8366 100644
--- a/tests/features/steps/wrappers.py
+++ b/tests/features/steps/wrappers.py
@@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
def expect_pager(context, expected, timeout):
+ formatted = expected if isinstance(expected, list) else [expected]
+ formatted = [
+ f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
+ for t in formatted
+ ]
+
expect_exact(
context,
- "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
+ formatted,
timeout=timeout,
)
@@ -57,7 +63,7 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf["dbname"]
- context.cli.sendline("\pset pager always")
+ context.cli.sendline(r"\pset pager always")
if prompt_check:
wait_prompt(context)
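
Note: with the wrappers.py change above, expect_pager accepts either a single expected body or a list of alternatives (the refresh step now tolerates both the "started" and "restarted" messages). Each candidate is wrapped in the pager boundary markers and the whole list is ultimately handed to pexpect's expect_exact, which accepts a list and matches whichever candidate appears first. A standalone sketch of the wrapping, with an illustrative boundary value rather than the suite's real one:

    def wrap_with_boundary(expected, boundary="---boundary---"):
        # mirrors expect_pager's normalisation: str or list in,
        # list of boundary-delimited candidates out
        candidates = expected if isinstance(expected, list) else [expected]
        return [f"{boundary}\r\n{c}{boundary}\r\n" for c in candidates]

    print(wrap_with_boundary([
        "Auto-completion refresh started in the background.\r\n",
        "Auto-completion refresh restarted.\r\n",
    ]))
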
diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py
index 51d4909..51d4909 100755..100644
--- a/tests/features/wrappager.py
+++ b/tests/features/wrappager.py
diff --git a/tests/metadata.py b/tests/metadata.py
index 2f89ea2..4ebcccd 100644
--- a/tests/metadata.py
+++ b/tests/metadata.py
@@ -3,7 +3,7 @@ from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
-from mock import Mock
+from unittest.mock import Mock
import pytest
parametrize = pytest.mark.parametrize
@@ -59,7 +59,7 @@ def wildcard_expansion(cols, pos=-1):
return Completion(cols, start_position=pos, display_meta="columns", display="*")
-class MetaData(object):
+class MetaData:
def __init__(self, metadata):
self.metadata = metadata
@@ -128,7 +128,7 @@ class MetaData(object):
]
def schemas(self, pos=0):
- schemas = set(sch for schs in self.metadata.values() for sch in schs)
+ schemas = {sch for schs in self.metadata.values() for sch in schs}
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent="public", pos=0):
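
Note: the test helpers drop the third-party mock backport in favour of the standard library (Mock has shipped in unittest.mock since Python 3.3) and replace set(...) calls with set comprehensions; both are drop-in changes on the interpreters the suite targets. For example, with an illustrative metadata shape:

    from unittest.mock import Mock  # no external "mock" dependency needed

    m = Mock(return_value=42)
    assert m() == 42

    metadata = {"tables": {"public": ["users"]}, "views": {"public": ["d"]}}
    # same result as set(sch for schs in metadata.values() for sch in schs)
    assert {sch for schs in metadata.values() for sch in schs} == {"public"}
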
diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py
index 50bc889..5a375d7 100644
--- a/tests/parseutils/test_parseutils.py
+++ b/tests/parseutils/test_parseutils.py
@@ -1,4 +1,5 @@
import pytest
+from pgcli.packages.parseutils import is_destructive
from pgcli.packages.parseutils.tables import extract_tables
from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
@@ -34,12 +35,12 @@ def test_simple_select_single_table_double_quoted():
def test_simple_select_multiple_tables():
tables = extract_tables("select * from abc, def")
- assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
- assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
+ assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
def test_simple_select_single_table_deouble_quoted_aliased():
@@ -49,14 +50,12 @@ def test_simple_select_single_table_deouble_quoted_aliased():
def test_simple_select_multiple_tables_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
- assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
+ assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables("select * from abc.def, ghi.jkl")
- assert set(tables) == set(
- [("abc", "def", None, False), ("ghi", "jkl", None, False)]
- )
+ assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
def test_simple_select_with_cols_single_table():
@@ -71,14 +70,12 @@ def test_simple_select_with_cols_single_table_schema_qualified():
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables("select a,b from abc, def")
- assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables("select a,b from abc.def, def.ghi")
- assert set(tables) == set(
- [("abc", "def", None, False), ("def", "ghi", None, False)]
- )
+ assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
def test_select_with_hanging_comma_single_table():
@@ -88,14 +85,12 @@ def test_select_with_hanging_comma_single_table():
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables("select a, from abc, def")
- assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
- assert set(tables) == set(
- [(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
- )
+ assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
def test_simple_insert_single_table():
@@ -126,14 +121,14 @@ def test_simple_update_table_with_schema():
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
- sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
+ sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
tables = extract_tables(sql)
- assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
+ assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
def test_join_table_schema_qualified():
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
- assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
+ assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
def test_incomplete_join_clause():
@@ -177,25 +172,25 @@ def test_extract_no_tables(text):
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
- tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
+ tables = extract_tables(f"SELECT * FROM foo({arg_list})")
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
- tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
+ tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
- tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
+ tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables("SELECT * FROM foo JOIN bar()")
- assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
+ assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
def test_complex_table_and_function():
@@ -203,9 +198,7 @@ def test_complex_table_and_function():
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
- assert set(tables) == set(
- [("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
- )
+ assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
def test_find_prev_keyword_using():
@@ -267,3 +260,21 @@ def test_is_open_quote__closed(sql):
)
def test_is_open_quote__open(sql):
assert is_open_quote(sql)
+
+
+@pytest.mark.parametrize(
+ ("sql", "warning_level", "expected"),
+ [
+ ("update abc set x = 1", "all", True),
+ ("update abc set x = 1 where y = 2", "all", True),
+ ("update abc set x = 1", "moderate", True),
+ ("update abc set x = 1 where y = 2", "moderate", False),
+ ("select x, y, z from abc", "all", False),
+ ("drop abc", "all", True),
+ ("alter abc", "all", True),
+ ("delete abc", "all", True),
+ ("truncate abc", "all", True),
+ ],
+)
+def test_is_destructive(sql, warning_level, expected):
+ assert is_destructive(sql, warning_level=warning_level) == expected
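
Note: the new parametrized test pins down is_destructive's warning_level contract: at "all", any statement opening with a destructive keyword is flagged, while "moderate" lets an UPDATE through when it carries a WHERE clause. A minimal sketch of that contract follows; the keyword list is taken only from the cases above, and the shipped implementation in pgcli/packages/parseutils/__init__.py may parse statements more carefully.

    # keywords taken from the test cases above; the real list may be longer
    DESTRUCTIVE_KEYWORDS = ("drop", "alter", "delete", "truncate", "update")

    def is_destructive_sketch(sql, warning_level="all"):
        """Flag statements that start with a destructive keyword, except that
        at the "moderate" level an UPDATE with a WHERE clause passes."""
        for statement in sql.split(";"):
            words = statement.strip().lower().split()
            if not words or words[0] not in DESTRUCTIVE_KEYWORDS:
                continue
            if warning_level == "moderate" and words[0] == "update" and "where" in words:
                continue
            return True
        return False

    assert is_destructive_sketch("update abc set x = 1 where y = 2", "moderate") is False
    assert is_destructive_sketch("update abc set x = 1", "moderate") is True
    assert is_destructive_sketch("select x, y, z from abc", "all") is False
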
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
index 6a916a8..a5529d6 100644
--- a/tests/test_completion_refresher.py
+++ b/tests/test_completion_refresher.py
@@ -1,6 +1,6 @@
import time
import pytest
-from mock import Mock, patch
+from unittest.mock import Mock, patch
@pytest.fixture
@@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
:return:
"""
callbacks = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
@@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
"""
callbacks = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
def dummy_bg_refresh(*args):
@@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher):
:param refresher:
"""
callbacks = [Mock()]
- pgexecute_class = Mock()
- pgexecute = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
pgexecute.extra_args = {}
special = Mock()
- with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
- # Set refreshers to 0: we're not testing refresh logic here
- refresher.refreshers = {}
- refresher.refresh(pgexecute, special, callbacks)
- time.sleep(1) # Wait for the thread to work.
- assert callbacks[0].call_count == 1
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
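
Note: the refresher tests now build the executor mock with a dotted keyword so that is_virtual_database() returns a real False instead of another Mock (which would be truthy and make the refresher treat the fake executor as a pgbouncer-style virtual database). unittest.mock supports this directly:

    from unittest.mock import Mock

    # dotted names in the constructor configure child mocks in one call
    pgexecute = Mock(**{"is_virtual_database.return_value": False})
    assert pgexecute.is_virtual_database() is False

    plain = Mock()
    assert bool(plain.is_virtual_database())  # an unconfigured child mock is truthy
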
diff --git a/tests/test_config.py b/tests/test_config.py
index 1c023e0..08fe74e 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,9 +1,10 @@
+import io
import os
import stat
import pytest
-from pgcli.config import ensure_dir_exists
+from pgcli.config import ensure_dir_exists, skip_initial_comment
def test_ensure_file_parent(tmpdir):
@@ -20,11 +21,23 @@ def test_ensure_existing_dir(tmpdir):
def test_ensure_other_create_error(tmpdir):
- subdir = tmpdir.join("subdir")
+ subdir = tmpdir.join('subdir"')
rcfile = subdir.join("rcfile")
- # trigger an oserror that isn't "directory already exists"
+ # trigger an oserror that isn't "directory already exists"
os.chmod(str(tmpdir), stat.S_IREAD)
with pytest.raises(OSError):
ensure_dir_exists(str(rcfile))
+
+
+@pytest.mark.parametrize(
+ "text, skipped_lines",
+ (
+ ("abc\n", 1),
+ ("#[section]\ndef\n[section]", 2),
+ ("[section]", 0),
+ ),
+)
+def test_skip_initial_comment(text, skipped_lines):
+ assert skip_initial_comment(io.StringIO(text)) == skipped_lines
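
Note: skip_initial_comment is new in pgcli.config; the parametrized cases say it should count (and consume) every line that precedes the first "[section]" header, which is what lets the service-file handling cope with files that open with free-form text (see the test_main.py change below). A sketch of that contract, not necessarily the shipped implementation:

    import io

    def skip_initial_comment_sketch(file_obj):
        """Advance the stream to the first section header and return how many
        lines were skipped along the way."""
        skipped = 0
        while True:
            position = file_obj.tell()
            line = file_obj.readline()
            if not line or line.lstrip().startswith("["):
                file_obj.seek(position)  # leave the stream at the header
                return skipped
            skipped += 1

    assert skip_initial_comment_sketch(io.StringIO("abc\n")) == 1
    assert skip_initial_comment_sketch(io.StringIO("#[section]\ndef\n[section]")) == 2
    assert skip_initial_comment_sketch(io.StringIO("[section]")) == 0
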
diff --git a/tests/test_main.py b/tests/test_main.py
index 9b85a34..c48accb 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,6 +1,6 @@
import os
import platform
-import mock
+from unittest import mock
import pytest
@@ -288,7 +288,12 @@ def test_pg_service_file(tmpdir):
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
service_conf.write(
- """[myservice]
+ """File begins with a comment
+ that is not a comment
+ # or maybe a comment after all
+ because psql is crazy
+
+ [myservice]
host=a_host
user=a_user
port=5433
diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py
index a6c80a7..5b93661 100644
--- a/tests/test_naive_completion.py
+++ b/tests/test_naive_completion.py
@@ -13,7 +13,7 @@ def completer():
@pytest.fixture
def complete_event():
- from mock import Mock
+ from unittest.mock import Mock
return Mock()
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index 9273be9..109674c 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -2,7 +2,7 @@ from textwrap import dedent
import psycopg2
import pytest
-from mock import patch, MagicMock
+from unittest.mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb
@@ -89,7 +89,7 @@ def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
- results = run(executor, """select * from test \G""", pgspecial=pgspecial)
+ results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
assert pgspecial.expanded_output == False
@@ -105,31 +105,35 @@ def test_schemata_table_views_and_columns_query(executor):
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
- assert set(executor.schemata()) >= set(
- ["public", "pg_catalog", "information_schema", "schema1", "schema2"]
- )
+ assert set(executor.schemata()) >= {
+ "public",
+ "pg_catalog",
+ "information_schema",
+ "schema1",
+ "schema2",
+ }
assert executor.search_path() == ["pg_catalog", "public"]
# tables
- assert set(executor.tables()) >= set(
- [("public", "a"), ("public", "b"), ("schema1", "c")]
- )
-
- assert set(executor.table_columns()) >= set(
- [
- ("public", "a", "x", "text", False, None),
- ("public", "a", "y", "text", False, None),
- ("public", "b", "z", "text", False, None),
- ("schema1", "c", "w", "text", True, "'meow'::text"),
- ]
- )
+ assert set(executor.tables()) >= {
+ ("public", "a"),
+ ("public", "b"),
+ ("schema1", "c"),
+ }
+
+ assert set(executor.table_columns()) >= {
+ ("public", "a", "x", "text", False, None),
+ ("public", "a", "y", "text", False, None),
+ ("public", "b", "z", "text", False, None),
+ ("schema1", "c", "w", "text", True, "'meow'::text"),
+ }
# views
- assert set(executor.views()) >= set([("public", "d")])
+ assert set(executor.views()) >= {("public", "d")}
- assert set(executor.view_columns()) >= set(
- [("public", "d", "e", "integer", False, None)]
- )
+ assert set(executor.view_columns()) >= {
+ ("public", "d", "e", "integer", False, None)
+ }
@dbtest
@@ -142,9 +146,9 @@ def test_foreign_key_query(executor):
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
- assert set(executor.foreignkeys()) >= set(
- [("schema1", "parent", "parentid", "schema2", "child", "motherid")]
- )
+ assert set(executor.foreignkeys()) >= {
+ ("schema1", "parent", "parentid", "schema2", "child", "motherid")
+ }
@dbtest
@@ -175,30 +179,28 @@ def test_functions_query(executor):
)
funcs = set(executor.functions())
- assert funcs >= set(
- [
- function_meta_data(func_name="func1", return_type="integer"),
- function_meta_data(
- func_name="func3",
- arg_names=["x", "y"],
- arg_types=["integer", "integer"],
- arg_modes=["t", "t"],
- return_type="record",
- is_set_returning=True,
- ),
- function_meta_data(
- schema_name="public",
- func_name="func4",
- arg_names=("x",),
- arg_types=("integer",),
- return_type="integer",
- is_set_returning=True,
- ),
- function_meta_data(
- schema_name="schema1", func_name="func2", return_type="integer"
- ),
- ]
- )
+ assert funcs >= {
+ function_meta_data(func_name="func1", return_type="integer"),
+ function_meta_data(
+ func_name="func3",
+ arg_names=["x", "y"],
+ arg_types=["integer", "integer"],
+ arg_modes=["t", "t"],
+ return_type="record",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="public",
+ func_name="func4",
+ arg_names=("x",),
+ arg_types=("integer",),
+ return_type="integer",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="schema1", func_name="func2", return_type="integer"
+ ),
+ }
@dbtest
@@ -257,8 +259,8 @@ def test_not_is_special(executor, pgspecial):
@dbtest
def test_execute_from_file_no_arg(executor, pgspecial):
- """\i without a filename returns an error."""
- result = list(executor.run("\i", pgspecial=pgspecial))
+ r"""\i without a filename returns an error."""
+ result = list(executor.run(r"\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
assert success == False
@@ -268,12 +270,12 @@ def test_execute_from_file_no_arg(executor, pgspecial):
@dbtest
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
- """\i with an io_error returns an error."""
- # Inject an IOError.
- os.path.expanduser.side_effect = IOError("test")
+ r"""\i with an os_error returns an error."""
+ # Inject an OSError.
+ os.path.expanduser.side_effect = OSError("test")
# Check the result.
- result = list(executor.run("\i test", pgspecial=pgspecial))
+ result = list(executor.run(r"\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
assert success == False
@@ -290,7 +292,7 @@ def test_multiple_queries_same_line(executor):
@dbtest
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
- result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
+ result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
assert len(result) == 11 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
# This is a lame check. :(
@@ -408,7 +410,7 @@ def test_date_time_types(executor):
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
- run(executor, "insert into numbertest (a) values ({0})".format(value))
+ run(executor, f"insert into numbertest (a) values ({value})")
assert value in run(executor, "select * from numbertest", join=True)
@@ -511,13 +513,28 @@ def test_short_host(executor):
assert executor.short_host == "localhost1"
-class BrokenConnection(object):
+class BrokenConnection:
"""Mock a connection that failed."""
def cursor(self):
raise psycopg2.InterfaceError("I'm broken!")
+class VirtualCursor:
+ """Mock a cursor to virtual database like pgbouncer."""
+
+ def __init__(self):
+ self.protocol_error = False
+ self.protocol_message = ""
+ self.description = None
+ self.status = None
+ self.statusmessage = "Error"
+
+ def execute(self, *args, **kwargs):
+ self.protocol_error = True
+ self.protocol_message = "Command not supported"
+
+
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
@@ -540,3 +557,12 @@ def test_exit_without_active_connection(executor):
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)
+
+
+@dbtest
+def test_virtual_database(executor):
+ virtual_connection = MagicMock()
+ virtual_connection.cursor.return_value = VirtualCursor()
+ with patch.object(executor, "conn", virtual_connection):
+ result = run(executor, "select 1")
+ assert "Command not supported" in result
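
Note: test_virtual_database swaps the live psycopg2 connection for a mock whose cursor behaves like pgbouncer's admin console: execute() never raises, it only flags a protocol error, and the executor is expected to surface protocol_message instead of rows. One way a caller could consume such a cursor is sketched below; this is illustrative only, the real handling lives in pgcli/pgexecute.py.

    class VirtualCursor:
        # trimmed copy of the test helper defined in the diff above
        def __init__(self):
            self.protocol_error = False
            self.protocol_message = ""

        def execute(self, *args, **kwargs):
            self.protocol_error = True
            self.protocol_message = "Command not supported"

    def run_or_report(cursor, sql):
        cursor.execute(sql)
        if getattr(cursor, "protocol_error", False):
            return cursor.protocol_message  # e.g. "Command not supported"
        return cursor.fetchall()

    assert run_or_report(VirtualCursor(), "select 1") == "Command not supported"
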
diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py
index eaeaf12..cd99e32 100644
--- a/tests/test_pgspecial.py
+++ b/tests/test_pgspecial.py
@@ -13,12 +13,12 @@ from pgcli.packages.sqlcompletion import (
def test_slash_suggests_special():
suggestions = suggest_type("\\", "\\")
- assert set(suggestions) == set([Special()])
+ assert set(suggestions) == {Special()}
def test_slash_d_suggests_special():
suggestions = suggest_type("\\d", "\\d")
- assert set(suggestions) == set([Special()])
+ assert set(suggestions) == {Special()}
def test_dn_suggests_schemata():
@@ -30,24 +30,24 @@ def test_dn_suggests_schemata():
def test_d_suggests_tables_views_and_schemas():
- suggestions = suggest_type("\d ", "\d ")
- assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
+ suggestions = suggest_type(r"\d ", r"\d ")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
- suggestions = suggest_type("\d xxx", "\d xxx")
- assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
+ suggestions = suggest_type(r"\d xxx", r"\d xxx")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
def test_d_dot_suggests_schema_qualified_tables_or_views():
- suggestions = suggest_type("\d myschema.", "\d myschema.")
- assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
+ suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
- suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
- assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
+ suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
def test_df_suggests_schema_or_function():
suggestions = suggest_type("\\df xxx", "\\df xxx")
- assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
+ assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
@@ -63,7 +63,7 @@ def test_leading_whitespace_ok():
def test_dT_suggests_schema_or_datatypes():
text = "\\dT "
suggestions = suggest_type(text, text)
- assert set(suggestions) == set([Schema(), Datatype(schema=None)])
+ assert set(suggestions) == {Schema(), Datatype(schema=None)}
def test_schema_qualified_dT_suggests_datatypes():
diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py
index c1f8a16..a8a3a1e 100644
--- a/tests/test_prompt_utils.py
+++ b/tests/test_prompt_utils.py
@@ -7,4 +7,4 @@ def test_confirm_destructive_query_notty():
stdin = click.get_text_stream("stdin")
if not stdin.isatty():
sql = "drop database foo;"
- assert confirm_destructive_query(sql) is None
+ assert confirm_destructive_query(sql, "all") is None
diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py
index e76ea04..947fc80 100644
--- a/tests/test_rowlimit.py
+++ b/tests/test_rowlimit.py
@@ -1,5 +1,5 @@
import pytest
-from mock import Mock
+from unittest.mock import Mock
from pgcli.main import PGCli
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 805b727..5c9c9af 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -193,7 +193,7 @@ def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
- + [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
+ + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
)
@@ -350,6 +350,36 @@ def test_schema_qualified_function_name(completer):
)
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_schema_qualified_function_name_after_from(completer):
+ text = "SELECT * FROM custom.set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_not_returned(completer):
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set([])
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_in_search_path(completer):
+ completer.search_path = ["public", "custom"]
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
@parametrize("completer", completers(filtr=True, casing=False))
@parametrize(
"text",
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index b935709..db1fe0a 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -53,7 +53,7 @@ metadata = {
],
}
-metadata = dict((k, {"public": v}) for k, v in metadata.items())
+metadata = {k: {"public": v} for k, v in metadata.items()}
testdata = MetaData(metadata)
@@ -296,7 +296,7 @@ def test_suggested_cased_always_qualified_column_names(completer):
def test_suggested_column_names_in_function(completer):
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
assert completions_to_set(result) == completions_to_set(
- (testdata.columns_functions_and_keywords("users"))
+ testdata.columns_functions_and_keywords("users")
)
@@ -316,7 +316,7 @@ def test_suggested_column_names_with_alias(completer):
def test_suggested_multiple_column_names(completer):
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
assert completions_to_set(result) == completions_to_set(
- (testdata.columns_functions_and_keywords("users"))
+ testdata.columns_functions_and_keywords("users")
)
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
index 3cbad0a..744fadb 100644
--- a/tests/test_sqlcompletion.py
+++ b/tests/test_sqlcompletion.py
@@ -23,16 +23,14 @@ def cols_etc(
):
"""Returns the expected select-clause suggestions for a single-table
select."""
- return set(
- [
- Column(
- table_refs=(TableReference(schema, table, alias, is_function),),
- qualifiable=True,
- ),
- Function(schema=parent),
- Keyword(last_keyword),
- ]
- )
+ return {
+ Column(
+ table_refs=(TableReference(schema, table, alias, is_function),),
+ qualifiable=True,
+ ),
+ Function(schema=parent),
+ Keyword(last_keyword),
+ }
def test_select_suggests_cols_with_visible_table_scope():
@@ -103,24 +101,20 @@ def test_where_equals_any_suggests_columns_or_keywords():
def test_lparen_suggests_cols_and_funcs():
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
- assert set(suggestion) == set(
- [
- Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
- Function(schema=None),
- Keyword("("),
- ]
- )
+ assert set(suggestion) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("("),
+ }
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type("SELECT ", "SELECT ")
- assert set(suggestions) == set(
- [
- Column(table_refs=(), qualifiable=True),
- Function(schema=None),
- Keyword("SELECT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
@pytest.mark.parametrize(
@@ -128,13 +122,13 @@ def test_select_suggests_cols_and_funcs():
)
def test_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
- assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()])
+ assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
def test_suggest_tables_views_schemas_and_functions(expression):
suggestions = suggest_type(expression, expression)
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@@ -147,9 +141,11 @@ def test_suggest_tables_views_schemas_and_functions(expression):
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
- assert set(suggestions) == set(
- [FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()]
- )
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(tables, None),
+ Schema(),
+ }
@pytest.mark.parametrize(
@@ -158,13 +154,11 @@ def test_suggest_after_join_with_two_tables(expression):
def test_suggest_after_join_with_one_table(expression):
suggestions = suggest_type(expression, expression)
tables = ((None, "foo", None, False),)
- assert set(suggestions) == set(
- [
- FromClauseItem(schema=None, table_refs=tables),
- Join(((None, "foo", None, False),), None),
- Schema(),
- ]
- )
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(((None, "foo", None, False),), None),
+ Schema(),
+ }
@pytest.mark.parametrize(
@@ -172,13 +166,13 @@ def test_suggest_after_join_with_one_table(expression):
)
def test_suggest_qualified_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
- assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize("expression", ["UPDATE sch."])
def test_suggest_qualified_aliasable_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
- assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize(
@@ -193,26 +187,27 @@ def test_suggest_qualified_aliasable_tables_and_views(expression):
)
def test_suggest_qualified_tables_views_and_functions(expression):
suggestions = suggest_type(expression, expression)
- assert set(suggestions) == set([FromClauseItem(schema="sch")])
+ assert set(suggestions) == {FromClauseItem(schema="sch")}
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
- assert set(suggestions) == set(
- [FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")]
- )
+ assert set(suggestions) == {
+ FromClauseItem(schema="sch", table_refs=tbls),
+ Join(tbls, "sch"),
+ }
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
- assert set(suggestions) == set([Table(schema=None), Schema()])
+ assert set(suggestions) == {Table(schema=None), Schema()}
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
- assert set(suggestions) == set([Table(schema="sch")])
+ assert set(suggestions) == {Table(schema="sch")}
@pytest.mark.parametrize(
@@ -220,13 +215,11 @@ def test_truncate_suggests_qualified_tables():
)
def test_distinct_suggests_cols(text):
suggestions = suggest_type(text, text)
- assert set(suggestions) == set(
- [
- Column(table_refs=(), local_tables=(), qualifiable=True),
- Function(schema=None),
- Keyword("DISTINCT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=(), local_tables=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("DISTINCT"),
+ }
@pytest.mark.parametrize(
@@ -244,20 +237,18 @@ def test_distinct_and_order_by_suggestions_with_aliases(
text, text_before, last_keyword
):
suggestions = suggest_type(text, text_before)
- assert set(suggestions) == set(
- [
- Column(
- table_refs=(
- TableReference(None, "tbl", "x", False),
- TableReference(None, "tbl1", "y", False),
- ),
- local_tables=(),
- qualifiable=True,
+ assert set(suggestions) == {
+ Column(
+ table_refs=(
+ TableReference(None, "tbl", "x", False),
+ TableReference(None, "tbl1", "y", False),
),
- Function(schema=None),
- Keyword(last_keyword),
- ]
- )
+ local_tables=(),
+ qualifiable=True,
+ ),
+ Function(schema=None),
+ Keyword(last_keyword),
+ }
@pytest.mark.parametrize(
@@ -272,56 +263,50 @@ def test_distinct_and_order_by_suggestions_with_aliases(
)
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
suggestions = suggest_type(text, text_before)
- assert set(suggestions) == set(
- [
- Column(
- table_refs=(TableReference(None, "tbl", "x", False),),
- local_tables=(),
- qualifiable=False,
- ),
- Table(schema="x"),
- View(schema="x"),
- Function(schema="x"),
- ]
- )
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
def test_function_arguments_with_alias_given():
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
- assert set(suggestions) == set(
- [
- Column(
- table_refs=(TableReference(None, "tbl", "x", False),),
- local_tables=(),
- qualifiable=False,
- ),
- Table(schema="x"),
- View(schema="x"),
- Function(schema="x"),
- ]
- )
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
def test_col_comma_suggests_cols():
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
- Function(schema=None),
- Keyword("SELECT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
- assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()])
+ assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize(
@@ -357,14 +342,12 @@ def test_partially_typed_col_name_suggests_col_names():
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "tabl", None, False),)),
- Table(schema="tabl"),
- View(schema="tabl"),
- Function(schema="tabl"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", None, False),)),
+ Table(schema="tabl"),
+ View(schema="tabl"),
+ Function(schema="tabl"),
+ }
@pytest.mark.parametrize(
@@ -378,14 +361,12 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
)
def test_dot_suggests_cols_of_an_alias(sql):
suggestions = suggest_type(sql, "SELECT t1.")
- assert set(suggestions) == set(
- [
- Table(schema="t1"),
- View(schema="t1"),
- Column(table_refs=((None, "tabl1", "t1", False),)),
- Function(schema="t1"),
- ]
- )
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
@pytest.mark.parametrize(
@@ -399,28 +380,24 @@ def test_dot_suggests_cols_of_an_alias(sql):
)
def test_dot_suggests_cols_of_an_alias_where(sql):
suggestions = suggest_type(sql, sql)
- assert set(suggestions) == set(
- [
- Table(schema="t1"),
- View(schema="t1"),
- Column(table_refs=((None, "tabl1", "t1", False),)),
- Function(schema="t1"),
- ]
- )
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type(
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
)
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "tabl2", "t2", False),)),
- Table(schema="t2"),
- View(schema="t2"),
- Function(schema="t2"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl2", "t2", False),)),
+ Table(schema="t2"),
+ View(schema="t2"),
+ Function(schema="t2"),
+ }
@pytest.mark.parametrize(
@@ -452,20 +429,18 @@ def test_sub_select_partial_text_suggests_keyword(expression):
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "foo", "f", False),)),
- Table(schema="f"),
- View(schema="f"),
- Function(schema="f"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", "f", False),)),
+ Table(schema="f"),
+ View(schema="f"),
+ Function(schema="f"),
+ }
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
- assert set(suggestion) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@@ -478,22 +453,18 @@ def test_sub_select_table_name_completion(expression):
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
- assert set(suggestion) == set(
- [FromClauseItem(schema=None, table_refs=tbls), Schema()]
- )
+ assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
def test_sub_select_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
)
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "abc", None, False),), qualifiable=True),
- Function(schema=None),
- Keyword("SELECT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
@pytest.mark.xfail
@@ -508,25 +479,25 @@ def test_sub_select_dot_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
)
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "tabl", "t", False),)),
- Table(schema="t"),
- View(schema="t"),
- Function(schema="t"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", "t", False),)),
+ Table(schema="t"),
+ View(schema="t"),
+ Function(schema="t"),
+ }
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
- text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
+ text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
suggestion = suggest_type(text, text)
tbls = tuple([(None, "abc", tbl_alias or None, False)])
- assert set(suggestion) == set(
- [FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)]
- )
+ assert set(suggestion) == {
+ FromClauseItem(schema=None, table_refs=tbls),
+ Schema(),
+ Join(tbls, None),
+ }
def test_left_join_with_comma():
@@ -535,9 +506,7 @@ def test_left_join_with_comma():
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
tbls = tuple([(None, "foo", "f", False)])
- assert set(suggestions) == set(
- [FromClauseItem(schema=None, table_refs=tbls), Schema()]
- )
+ assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
@pytest.mark.parametrize(
@@ -550,15 +519,13 @@ def test_left_join_with_comma():
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "def", "d", False))
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "abc", "a", False),)),
- Table(schema="a"),
- View(schema="a"),
- Function(schema="a"),
- JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", "a", False),)),
+ Table(schema="a"),
+ View(schema="a"),
+ Function(schema="a"),
+ JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
+ }
@pytest.mark.parametrize(
@@ -570,14 +537,12 @@ def test_join_alias_dot_suggests_cols1(sql):
)
def test_join_alias_dot_suggests_cols2(sql):
suggestion = suggest_type(sql, sql)
- assert set(suggestion) == set(
- [
- Column(table_refs=((None, "def", "d", False),)),
- Table(schema="d"),
- View(schema="d"),
- Function(schema="d"),
- ]
- )
+ assert set(suggestion) == {
+ Column(table_refs=((None, "def", "d", False),)),
+ Table(schema="d"),
+ View(schema="d"),
+ Function(schema="d"),
+ }
@pytest.mark.parametrize(
@@ -598,9 +563,10 @@ on """,
def test_on_suggests_aliases_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
- assert set(suggestions) == set(
- (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b")))
- )
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("a", "b")),
+ }
@pytest.mark.parametrize(
@@ -613,9 +579,10 @@ def test_on_suggests_aliases_and_join_conditions(sql):
def test_on_suggests_tables_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
- assert set(suggestions) == set(
- (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
- )
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
@pytest.mark.parametrize(
@@ -640,9 +607,10 @@ def test_on_suggests_aliases_right_side(sql):
def test_on_suggests_tables_and_join_conditions_right_side(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
- assert set(suggestions) == set(
- (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
- )
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
@pytest.mark.parametrize(
@@ -659,9 +627,9 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql):
)
def test_join_using_suggests_common_columns(text):
tables = ((None, "abc", None, False), (None, "def", None, False))
- assert set(suggest_type(text, text)) == set(
- [Column(table_refs=tables, require_last_table=True)]
- )
+ assert set(suggest_type(text, text)) == {
+ Column(table_refs=tables, require_last_table=True)
+ }
def test_suggest_columns_after_multiple_joins():
@@ -678,29 +646,27 @@ def test_2_statements_2nd_current():
suggestions = suggest_type(
"select * from a; select * from ", "select * from a; select * from "
)
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b", "select * from a; select "
)
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "b", None, False),), qualifiable=True),
- Function(schema=None),
- Keyword("SELECT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "b", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
# Should work even if first statement is invalid
suggestions = suggest_type(
"select * from; select * from ", "select * from; select * from "
)
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_2_statements_1st_current():
suggestions = suggest_type("select * from ; select * from b", "select * from ")
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type("select from a; select * from b", "select ")
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
@@ -711,7 +677,7 @@ def test_3_statements_2nd_current():
"select * from a; select * from ; select * from c",
"select * from a; select * from ",
)
- assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b; select * from c", "select * from a; select "
@@ -768,13 +734,11 @@ SELECT * FROM qux;
)
def test_statements_in_function_body(text):
suggestions = suggest_type(text, text[: text.find(" ") + 1])
- assert set(suggestions) == set(
- [
- Column(table_refs=((None, "foo", None, False),), qualifiable=True),
- Function(schema=None),
- Keyword("SELECT"),
- ]
- )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
functions = [
@@ -799,13 +763,13 @@ SELECT 1 FROM foo;
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_after_function_body(text):
suggestions = suggest_type(text, text[: text.find("; ") + 1])
- assert set(suggestions) == set([Keyword(), Special()])
+ assert set(suggestions) == {Keyword(), Special()}
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_before_function_body(text):
suggestions = suggest_type(text, "")
- assert set(suggestions) == set([Keyword(), Special()])
+ assert set(suggestions) == {Keyword(), Special()}
def test_create_db_with_template():
@@ -813,14 +777,14 @@ def test_create_db_with_template():
"create database foo with template ", "create database foo with template "
)
- assert set(suggestions) == set((Database(),))
+ assert set(suggestions) == {Database()}
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
- assert set(suggestions) == set([Keyword(), Special()])
+ assert set(suggestions) == {Keyword(), Special()}
def test_drop_schema_qualified_table_suggests_only_tables():
@@ -843,25 +807,30 @@ def test_drop_schema_suggests_schemas():
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
def test_cast_operator_suggests_types(text):
- assert set(suggest_type(text, text)) == set(
- [Datatype(schema=None), Table(schema=None), Schema()]
- )
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
@pytest.mark.parametrize(
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
)
def test_cast_operator_suggests_schema_qualified_types(text):
- assert set(suggest_type(text, text)) == set(
- [Datatype(schema="bar"), Table(schema="bar")]
- )
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema="bar"),
+ Table(schema="bar"),
+ }
def test_alter_column_type_suggests_types():
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
- assert set(suggest_type(q, q)) == set(
- [Datatype(schema=None), Table(schema=None), Schema()]
- )
+ assert set(suggest_type(q, q)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
@pytest.mark.parametrize(
@@ -880,9 +849,11 @@ def test_alter_column_type_suggests_types():
],
)
def test_identifier_suggests_types_in_parentheses(text):
- assert set(suggest_type(text, text)) == set(
- [Datatype(schema=None), Table(schema=None), Schema()]
- )
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
@pytest.mark.parametrize(
@@ -977,7 +948,7 @@ def test_ignore_leading_double_quotes(sql):
)
def test_column_keyword_suggests_columns(sql):
suggestions = suggest_type(sql, sql)
- assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))])
+ assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
def test_handle_unrecognized_kw_generously():
diff --git a/tests/utils.py b/tests/utils.py
index 2427c30..460ea46 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,7 +8,7 @@ from os import getenv
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
POSTGRES_PORT = getenv("PGPORT", 5432)
-POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
+POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")
def db_connection(dbname=None):
@@ -73,7 +73,7 @@ def drop_tables(conn):
def run(
executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
- " Return string output for the sql to be run "
+ "Return string output for the sql to be run"
results = executor.run(sql, pgspecial, exception_formatter)
formatted = []
@@ -89,7 +89,7 @@ def run(
def completions_to_set(completions):
- return set(
+ return {
(completion.display_text, completion.display_meta_text)
for completion in completions
- )
+ }
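
Note: completions_to_set now builds its set with a comprehension instead of passing a generator to set(); the result is identical, the comprehension is simply the more idiomatic spelling. For example, with made-up completion tuples:

    completions = [("users", "table"), ("id", "column"), ("users", "table")]
    as_builtin = set((text, meta) for text, meta in completions)
    as_literal = {(text, meta) for text, meta in completions}
    assert as_builtin == as_literal == {("users", "table"), ("id", "column")}
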