diff options
50 files changed, 953 insertions, 635 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ce54d6f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,66 @@ +name: pgcli + +on: + pull_request: + paths-ignore: + - '**.rst' + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + + services: + postgres: + image: postgres:9.6 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install requirements + run: | + pip install -U pip setuptools + pip install --no-cache-dir . + pip install -r requirements-dev.txt + pip install keyrings.alt>=3.1 + + - name: Run unit tests + run: coverage run --source pgcli -m py.test + + - name: Run integration tests + env: + PGUSER: postgres + PGPASSWORD: postgres + + run: behave tests/features --no-capture + + - name: Check changelog for ReST compliance + run: rst2html.py --halt=warning changelog.rst >/dev/null + + - name: Run Black + run: pip install black && black --check . + if: matrix.python-version == '3.6' + + - name: Coverage + run: | + coverage combine + coverage report + codecov diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b970ac5..9e27ab8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: stable + rev: 21.5b0 hooks: - id: black language_version: python3.7 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 8d50fbd..0000000 --- a/.travis.yml +++ /dev/null @@ -1,51 +0,0 @@ -dist: xenial - -sudo: required - -language: python - -python: - - "3.6" - - "3.7" - - "3.8" - - "3.9-dev" - -before_install: - - which python - - which pip - - pip install -U setuptools - -install: - - pip install --no-cache-dir . - - pip install -r requirements-dev.txt - - pip install keyrings.alt>=3.1 - -script: - - set -e - - coverage run --source pgcli -m py.test - - cd tests - - behave --no-capture - - cd .. - # check for changelog ReST compliance - - rst2html.py --halt=warning changelog.rst >/dev/null - # check for black code compliance, 3.6 only - - if [[ "$TRAVIS_PYTHON_VERSION" == "3.6" ]]; then pip install black && black --check . ; else echo "Skipping black for $TRAVIS_PYTHON_VERSION"; fi - - set +e - -after_success: - - coverage combine - - codecov - -notifications: - webhooks: - urls: - - YOUR_WEBHOOK_URL - on_success: change # options: [always|never|change] default: always - on_failure: always # options: [always|never|change] default: always - on_start: false # default: false - -services: - - postgresql - -addons: - postgresql: "9.6" @@ -114,6 +114,10 @@ Contributors: * Tom Caruso (tomplex) * Jan Brun Rasmussen (janbrunrasmussen) * Kevin Marsh (kevinmarsh) + * Eero Ruohola (ruohola) + * Miroslav Šedivý (eumiro) + * Eric R Young (ERYoung11) + * Paweł Sacawa (psacawa) Creator: -------- diff --git a/DEVELOP.rst b/DEVELOP.rst index 18adf9c..e262823 100644 --- a/DEVELOP.rst +++ b/DEVELOP.rst @@ -170,7 +170,7 @@ Troubleshooting the integration tests - Make sure postgres instance on localhost is running - Check your ``pg_hba.conf`` file to verify local connections are enabled - Check `this issue <https://github.com/dbcli/pgcli/issues/945>`_ for relevant information. -- Contact us on `gitter <https://gitter.im/dbcli/pgcli/>`_ or `file an issue <https://github.com/dbcli/pgcli/issues/new>`_. +- `File an issue <https://github.com/dbcli/pgcli/issues/new>`_. Coding Style ------------ @@ -1,7 +1,7 @@ A REPL for Postgres ------------------- -|Build Status| |CodeCov| |PyPI| |Landscape| |Gitter| +|Build Status| |CodeCov| |PyPI| |Landscape| This is a postgres client that does auto-completion and syntax highlighting. @@ -62,32 +62,32 @@ For more details: Usage: pgcli [OPTIONS] [DBNAME] [USERNAME] Options: - -h, --host TEXT Host address of the postgres database. - -p, --port INTEGER Port number at which the postgres instance is - listening. - -U, --username TEXT Username to connect to the postgres database. - -u, --user TEXT Username to connect to the postgres database. - -W, --password Force password prompt. - -w, --no-password Never prompt for password. - --single-connection Do not use a separate connection for completions. - -v, --version Version of pgcli. - -d, --dbname TEXT database name to connect to. - --pgclirc PATH Location of pgclirc file. - -D, --dsn TEXT Use DSN configured into the [alias_dsn] section of - pgclirc file. - --list-dsn list of DSN configured into the [alias_dsn] section - of pgclirc file. - --row-limit INTEGER Set threshold for row limit prompt. Use 0 to disable - prompt. - --less-chatty Skip intro on startup and goodbye on exit. - --prompt TEXT Prompt format (Default: "\u@\h:\d> "). - --prompt-dsn TEXT Prompt format for connections using DSN aliases - (Default: "\u@\h:\d> "). - -l, --list list available databases, then exit. - --auto-vertical-output Automatically switch to vertical output mode if the - result is wider than the terminal width. - --warn / --no-warn Warn before running a destructive query. - --help Show this message and exit. + -h, --host TEXT Host address of the postgres database. + -p, --port INTEGER Port number at which the postgres instance is + listening. + -U, --username TEXT Username to connect to the postgres database. + -u, --user TEXT Username to connect to the postgres database. + -W, --password Force password prompt. + -w, --no-password Never prompt for password. + --single-connection Do not use a separate connection for completions. + -v, --version Version of pgcli. + -d, --dbname TEXT database name to connect to. + --pgclirc FILE Location of pgclirc file. + -D, --dsn TEXT Use DSN configured into the [alias_dsn] section + of pgclirc file. + --list-dsn list of DSN configured into the [alias_dsn] + section of pgclirc file. + --row-limit INTEGER Set threshold for row limit prompt. Use 0 to + disable prompt. + --less-chatty Skip intro on startup and goodbye on exit. + --prompt TEXT Prompt format (Default: "\u@\h:\d> "). + --prompt-dsn TEXT Prompt format for connections using DSN aliases + (Default: "\u@\h:\d> "). + -l, --list list available databases, then exit. + --auto-vertical-output Automatically switch to vertical output mode if + the result is wider than the terminal width. + --warn [all|moderate|off] Warn before running a destructive query. + --help Show this message and exit. ``pgcli`` also supports many of the same `environment variables`_ as ``psql`` for login options (e.g. ``PGHOST``, ``PGPORT``, ``PGUSER``, ``PGPASSWORD``, ``PGDATABASE``). @@ -352,8 +352,8 @@ interface to Postgres database. Thanks to all the beta testers and contributors for your time and patience. :) -.. |Build Status| image:: https://api.travis-ci.org/dbcli/pgcli.svg?branch=master - :target: https://travis-ci.org/dbcli/pgcli +.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg + :target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli .. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg :target: https://codecov.io/gh/dbcli/pgcli @@ -366,7 +366,3 @@ Thanks to all the beta testers and contributors for your time and patience. :) .. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg :target: https://pypi.python.org/pypi/pgcli/ :alt: Latest Version - -.. |Gitter| image:: https://badges.gitter.im/Join%20Chat.svg - :target: https://gitter.im/dbcli/pgcli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge - :alt: Gitter Chat diff --git a/Vagrantfile b/Vagrantfile index 0313520..297e70a 100644 --- a/Vagrantfile +++ b/Vagrantfile @@ -1,5 +1,7 @@ # -*- mode: ruby -*- # vi: set ft=ruby : +# +# Vagrant.configure(2) do |config| @@ -9,20 +11,23 @@ Vagrant.configure(2) do |config| pgcli_description = "Postgres CLI with autocompletion and syntax highlighting" config.vm.define "debian" do |debian| - debian.vm.box = "chef/debian-7.8" + debian.vm.box = "bento/debian-10.8" debian.vm.provision "shell", inline: <<-SHELL - echo "-> Building DEB on `lsb_release -s`" + echo "-> Building DEB on `lsb_release -d`" sudo apt-get update sudo apt-get install -y libpq-dev python-dev python-setuptools rubygems - sudo easy_install pip - sudo pip install virtualenv virtualenv-tools + sudo apt install -y python3-pip + sudo pip3 install --no-cache-dir virtualenv virtualenv-tools3 + sudo apt-get install -y ruby-dev + sudo apt-get install -y git + sudo apt-get install -y rpm librpmbuild8 + sudo gem install fpm + echo "-> Cleaning up old workspace" - rm -rf build + sudo rm -rf build mkdir -p build/usr/share virtualenv build/usr/share/pgcli - build/usr/share/pgcli/bin/pip install -U pip distribute - build/usr/share/pgcli/bin/pip uninstall -y distribute build/usr/share/pgcli/bin/pip install /pgcli echo "-> Cleaning Virtualenv" @@ -45,24 +50,59 @@ Vagrant.configure(2) do |config| --url https://github.com/dbcli/pgcli \ --description "#{pgcli_description}" \ --license 'BSD' + SHELL end + +# This is considerably more messy than the debian section. I had to go off-standard to update +# some packages to get this to work. + config.vm.define "centos" do |centos| - centos.vm.box = "chef/centos-7.0" + + centos.vm.box = "bento/centos-7.9" + centos.vm.box_version = "202012.21.0" centos.vm.provision "shell", inline: <<-SHELL #!/bin/bash - echo "-> Building RPM on `lsb_release -s`" - sudo yum install -y rpm-build gcc ruby-devel postgresql-devel python-devel rubygems - sudo easy_install pip - sudo pip install virtualenv virtualenv-tools - sudo gem install fpm + echo "-> Building RPM on `hostnamectl | grep "Operating System"`" + export PATH=/usr/local/rvm/gems/ruby-2.6.3/bin:/usr/local/rvm/gems/ruby-2.6.3@global/bin:/usr/local/rvm/rubies/ruby-2.6.3/bin:/usr/local/sbin:/usr/local/bin:/sbin:/bin:/usr/sbin:/usr/bin:/usr/local/rvm/bin:/root/bin + echo "PATH -> " $PATH + +##### +### get base updates + + sudo yum install -y rpm-build gcc postgresql-devel python-devel python3-pip git python3-devel + +###### +### install FPM, which we need to install to get an up-to-date version of ruby, which we need for git + + echo "-> Get FPM installed" + # import the necessary GPG keys + gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + sudo gpg --keyserver hkp://pool.sks-keyservers.net --recv-keys 409B6B1796C275462A1703113804BB82D39DC0E3 7D2BAF1CF37B13E2069D6956105BD0E739499BDB + # install RVM + sudo curl -sSL https://get.rvm.io | sudo bash -s stable + sudo usermod -aG rvm vagrant + sudo usermod -aG rvm root + sudo /usr/local/rvm/bin/rvm alias create default 2.6.3 + source /etc/profile.d/rvm.sh + + # install a newer version of ruby. centos7 only comes with ruby2.0.0, which isn't good enough for git. + sudo yum install -y ruby-devel + sudo /usr/local/rvm/bin/rvm install 2.6.3 + + # + # yes,this gives an error about generating doc but we don't need the doc. + + /usr/local/rvm/gems/ruby-2.6.3/wrappers/gem install fpm + +###### + + sudo pip3 install virtualenv virtualenv-tools3 echo "-> Cleaning up old workspace" rm -rf build mkdir -p build/usr/share virtualenv build/usr/share/pgcli - build/usr/share/pgcli/bin/pip install -U pip distribute - build/usr/share/pgcli/bin/pip uninstall -y distribute build/usr/share/pgcli/bin/pip install /pgcli echo "-> Cleaning Virtualenv" @@ -74,9 +114,9 @@ Vagrant.configure(2) do |config| find build -iname '*.pyc' -delete find build -iname '*.pyo' -delete + cd /home/vagrant echo "-> Creating PgCLI RPM" - echo $PATH - sudo /usr/local/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \ + /usr/local/rvm/gems/ruby-2.6.3/gems/fpm-1.12.0/bin/fpm -t rpm -s dir -C build -n pgcli -v #{pgcli_version} \ -a all \ -d postgresql-devel \ -d python-devel \ @@ -86,8 +126,13 @@ Vagrant.configure(2) do |config| --url https://github.com/dbcli/pgcli \ --description "#{pgcli_description}" \ --license 'BSD' - SHELL + + + SHELL + + end + end diff --git a/changelog.rst b/changelog.rst index ec50635..d732088 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,3 +1,37 @@ +TBD +===== + +Features: +--------- + +Bug fixes: +---------- + +3.2.0 +===== + +Release date: 2021/08/23 + +Features: +--------- + +* Consider `update` queries destructive and issue a warning. Change + `destructive_warning` setting to `all|moderate|off`, vs `true|false`. (#1239) +* Skip initial comment in .pg_session even if it doesn't start with '#' +* Include functions from schemas in search_path. (`Amjith Ramanujam`_) + +Bug fixes: +---------- + +* Fix issue where `syntax_style` config value would not have any effect. (#1212) +* Fix crash because of not found `InputMode.REPLACE_SINGLE` with prompt-toolkit < 3.0.6 +* Fix comments being lost in config when saving a named query. (#1240) +* Fix IPython magic for ipython-sql >= 0.4.0 +* Fix pager not being used when output format is set to csv. (#1238) +* Add function literals random, generate_series, generate_subscripts +* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly +* Fix pgcli crashing with virtual `pgbouncer` database. (#1093) + 3.1.0 ===== diff --git a/pgcli/__init__.py b/pgcli/__init__.py index f5f41e5..1173108 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "3.2.0" diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index cf0879f..1039d51 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -3,10 +3,9 @@ import os from collections import OrderedDict from .pgcompleter import PGCompleter -from .pgexecute import PGExecute -class CompletionRefresher(object): +class CompletionRefresher: refreshers = OrderedDict() @@ -27,6 +26,10 @@ class CompletionRefresher(object): has completed the refresh. The newly created completion object will be passed in as an argument to each callback. """ + if executor.is_virtual_database(): + # do nothing + return [(None, None, None, "Auto-completion refresh can't be started.")] + if self.is_refreshing(): self._restart_refresh.set() return [(None, None, None, "Auto-completion refresh restarted.")] @@ -141,7 +144,7 @@ def refresh_casing(completer, executor): with open(casing_file, "w") as f: f.write(casing_prefs) if os.path.isfile(casing_file): - with open(casing_file, "r") as f: + with open(casing_file) as f: completer.extend_casing([line.strip() for line in f]) diff --git a/pgcli/config.py b/pgcli/config.py index 0fc42dd..22f08dc 100644 --- a/pgcli/config.py +++ b/pgcli/config.py @@ -3,6 +3,8 @@ import shutil import os import platform from os.path import expanduser, exists, dirname +import re +from typing import TextIO from configobj import ConfigObj @@ -16,11 +18,15 @@ def config_location(): def load_config(usr_cfg, def_cfg=None): - cfg = ConfigObj() - cfg.merge(ConfigObj(def_cfg, interpolation=False)) - cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + # avoid config merges when possible. For writing, we need an umerged config instance. + # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171 + if def_cfg: + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + else: + cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8") cfg.filename = expanduser(usr_cfg) - return cfg @@ -44,12 +50,16 @@ def upgrade_config(config, def_config): cfg.write() +def get_config_filename(pgclirc_file=None): + return pgclirc_file or "%sconfig" % config_location() + + def get_config(pgclirc_file=None): from pgcli import __file__ as package_root package_root = os.path.dirname(package_root) - pgclirc_file = pgclirc_file or "%sconfig" % config_location() + pgclirc_file = get_config_filename(pgclirc_file) default_config = os.path.join(package_root, "pgclirc") write_default_config(default_config, pgclirc_file) @@ -62,3 +72,28 @@ def get_casing_file(config): if casing_file == "default": casing_file = config_location() + "casing" return casing_file + + +def skip_initial_comment(f_stream: TextIO) -> int: + """ + Initial comment in ~/.pg_service.conf is not always marked with '#' + which crashes the parser. This function takes a file object and + "rewinds" it to the beginning of the first section, + from where on it can be parsed safely + + :return: number of skipped lines + """ + section_regex = r"\s*\[" + pos = f_stream.tell() + lines_skipped = 0 + while True: + line = f_stream.readline() + if line == "": + break + if re.match(section_regex, line) is not None: + f_stream.seek(pos) + break + else: + pos += len(line) + lines_skipped += 1 + return lines_skipped diff --git a/pgcli/magic.py b/pgcli/magic.py index f58f415..6e58f28 100644 --- a/pgcli/magic.py +++ b/pgcli/magic.py @@ -25,7 +25,11 @@ def pgcli_line_magic(line): if hasattr(sql.connection.Connection, "get"): conn = sql.connection.Connection.get(parsed["connection"]) else: - conn = sql.connection.Connection.set(parsed["connection"]) + try: + conn = sql.connection.Connection.set(parsed["connection"]) + # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql + except TypeError: + conn = sql.connection.Connection.set(parsed["connection"], False) try: # A corresponding pgcli object already exists @@ -43,7 +47,7 @@ def pgcli_line_magic(line): conn._pgcli = pgcli # For convenience, print the connection alias - print("Connected: {}".format(conn.name)) + print(f"Connected: {conn.name}") try: pgcli.run_cli() diff --git a/pgcli/main.py b/pgcli/main.py index b146898..5135f6f 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -2,8 +2,9 @@ import platform import warnings from os.path import expanduser -from configobj import ConfigObj +from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries +from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") @@ -20,12 +21,12 @@ import datetime as dt import itertools import platform from time import time, sleep -from codecs import open keyring = None # keyring will be loaded later from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.utils import strip_ansi import click try: @@ -62,6 +63,7 @@ from .config import ( config_location, ensure_dir_exists, get_config, + get_config_filename, ) from .key_bindings import pgcli_bindings from .packages.prompt_utils import confirm_destructive_query @@ -122,7 +124,7 @@ class PgCliQuitError(Exception): pass -class PGCli(object): +class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -175,7 +177,11 @@ class PGCli(object): # Load config. c = self.config = get_config(pgclirc_file) - NamedQueries.instance = NamedQueries.from_config(self.config) + # at this point, config should be written to pgclirc_file if it did not exist. Read it. + self.config_writer = load_config(get_config_filename(pgclirc_file)) + + # make sure to use self.config_writer, not self.config + NamedQueries.instance = NamedQueries.from_config(self.config_writer) self.logger = logging.getLogger(__name__) self.initialize_logging() @@ -201,8 +207,11 @@ class PGCli(object): self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - c_dest_warning = c["main"].as_bool("destructive_warning") - self.destructive_warning = c_dest_warning if warn is None else warn + self.destructive_warning = warn or c["main"]["destructive_warning"] + # also handle boolean format of destructive warning + self.destructive_warning = {"true": "all", "false": "off"}.get( + self.destructive_warning.lower(), self.destructive_warning + ) self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.null_string = c["main"].get("null_string", "<null>") self.prompt_format = ( @@ -325,11 +334,11 @@ class PGCli(object): if pattern not in TabularOutputFormatter().supported_formats: raise ValueError() self.table_format = pattern - yield (None, None, None, "Changed table format to {}".format(pattern)) + yield (None, None, None, f"Changed table format to {pattern}") except ValueError: - msg = "Table format {} not recognized. Allowed formats:".format(pattern) + msg = f"Table format {pattern} not recognized. Allowed formats:" for table_type in TabularOutputFormatter().supported_formats: - msg += "\n\t{}".format(table_type) + msg += f"\n\t{table_type}" msg += "\nCurrently set to: %s" % self.table_format yield (None, None, None, msg) @@ -386,10 +395,13 @@ class PGCli(object): try: with open(os.path.expanduser(pattern), encoding="utf-8") as f: query = f.read() - except IOError as e: + except OSError as e: return [(None, None, None, str(e), "", False, True)] - if self.destructive_warning and confirm_destructive_query(query) is False: + if ( + self.destructive_warning != "off" + and confirm_destructive_query(query, self.destructive_warning) is False + ): message = "Wise choice. Command execution stopped." return [(None, None, None, message)] @@ -407,7 +419,7 @@ class PGCli(object): if not os.path.isfile(filename): try: open(filename, "w").close() - except IOError as e: + except OSError as e: self.output_file = None message = str(e) + "\nFile output disabled" return [(None, None, None, message, "", False, True)] @@ -479,7 +491,7 @@ class PGCli(object): service_config, file = parse_service_info(service) if service_config is None: click.secho( - "service '%s' was not found in %s" % (service, file), err=True, fg="red" + f"service '{service}' was not found in {file}", err=True, fg="red" ) exit(1) self.connect( @@ -515,7 +527,7 @@ class PGCli(object): passwd = os.environ.get("PGPASSWORD", "") # Find password from store - key = "%s@%s" % (user, host) + key = f"{user}@{host}" keyring_error_message = dedent( """\ {} @@ -644,8 +656,10 @@ class PGCli(object): query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning: - destroy = confirm = confirm_destructive_query(text) + if self.destructive_warning != "off": + destroy = confirm = confirm_destructive_query( + text, self.destructive_warning + ) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt @@ -677,7 +691,7 @@ class PGCli(object): click.echo(text, file=f) click.echo("\n".join(output), file=f) click.echo("", file=f) # extra newline - except IOError as e: + except OSError as e: click.secho(str(e), err=True, fg="red") else: if output: @@ -729,7 +743,6 @@ class PGCli(object): if not self.less_chatty: print("Server: PostgreSQL", self.pgexecute.server_version) print("Version:", __version__) - print("Chat: https://gitter.im/dbcli/pgcli") print("Home: http://pgcli.com") try: @@ -753,11 +766,7 @@ class PGCli(object): while self.watch_command: try: query = self.execute_command(self.watch_command) - click.echo( - "Waiting for {0} seconds before repeating".format( - timing - ) - ) + click.echo(f"Waiting for {timing} seconds before repeating") sleep(timing) except KeyboardInterrupt: self.watch_command = None @@ -979,16 +988,13 @@ class PGCli(object): callback = functools.partial( self._on_completions_refreshed, persist_priorities=persist_priorities ) - self.completion_refresher.refresh( + return self.completion_refresher.refresh( self.pgexecute, self.pgspecial, callback, history=history, settings=self.settings, ) - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] def _on_completions_refreshed(self, new_completer, persist_priorities): self._swap_completer_objects(new_completer, persist_priorities) @@ -1049,7 +1055,7 @@ class PGCli(object): str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", ) string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") - string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") string = string.replace("\\n", "\n") return string @@ -1075,9 +1081,10 @@ class PGCli(object): def echo_via_pager(self, text, color=None): if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: click.echo(text, color=color) - elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv": - click.echo_via_pager(text, color) - elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: + elif ( + self.pgspecial.pager_config == PAGER_LONG_OUTPUT + and self.table_format != "csv" + ): lines = text.split("\n") # The last 4 lines are reserved for the pgcli menu and padding @@ -1192,7 +1199,10 @@ class PGCli(object): help="Automatically switch to vertical output mode if the result is wider than the terminal width.", ) @click.option( - "--warn/--no-warn", default=None, help="Warn before running a destructive query." + "--warn", + default=None, + type=click.Choice(["all", "moderate", "off"]), + help="Warn before running a destructive query.", ) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) @@ -1384,7 +1394,7 @@ def is_mutating(status): if not status: return False - mutating = set(["insert", "update", "delete"]) + mutating = {"insert", "update", "delete"} return status.split(None, 1)[0].lower() in mutating @@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings): formatted = iter(formatted.splitlines()) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) - if not expanded and max_width and len(first_line) > max_width and headers: + if ( + not expanded + and max_width + and len(strip_ansi(first_line)) > max_width + and headers + ): formatted = formatter.format_output( cur, headers, format_name="vertical", column_types=None, **output_kwargs ) @@ -1502,10 +1517,16 @@ def parse_service_info(service): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: service_file = expanduser("~/.pg_service.conf") - if not service: + if not service or not os.path.exists(service_file): # nothing to do return None, service_file - service_file_config = ConfigObj(service_file) + with open(service_file, newline="") as f: + skipped_lines = skip_initial_comment(f) + try: + service_file_config = ConfigObj(f) + except ParseError as err: + err.line_number += skipped_lines + raise err if service not in service_file_config: return None, service_file service_conf = service_file_config.get(service) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index a11e7bf..1acc008 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,22 +1,34 @@ import sqlparse -def query_starts_with(query, prefixes): +def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + +def query_is_simple_update(formatted_sql): + """Check if the query starts with UPDATE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" -def is_destructive(queries): + +def is_destructive(queries, warning_level="all"): """Returns if any of the queries in *queries* is destructive.""" keywords = ("drop", "shutdown", "delete", "truncate", "alter") - return queries_start_with(queries, keywords) + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if query_starts_with(formatted_sql, keywords): + return True + if query_is_unconditional_update(formatted_sql): + return True + if warning_level == "all" and query_is_simple_update(formatted_sql): + return True + return False diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 108c01a..333cab5 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -50,7 +50,7 @@ def parse_defaults(defaults_string): yield current -class FunctionMetadata(object): +class FunctionMetadata: def __init__( self, schema_name, diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index 0ec3e69..aaa676c 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): for item in parsed.tokens: if tbl_prefix_seen: if is_subselect(item): - for x in extract_from_part(item, stop_at_punctuation): - yield x + yield from extract_from_part(item, stop_at_punctuation) elif stop_at_punctuation and item.ttype is Punctuation: return # An incomplete nested select won't be recognized correctly as a diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json index c7b74b5..df00817 100644 --- a/pgcli/packages/pgliterals/pgliterals.json +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -392,6 +392,7 @@ "QUOTE_NULLABLE", "RADIANS", "RADIUS", + "RANDOM", "RANK", "REGEXP_MATCH", "REGEXP_MATCHES", diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py index e92dcbb..f5a9cb5 100644 --- a/pgcli/packages/prioritization.py +++ b/pgcli/packages/prioritization.py @@ -16,10 +16,10 @@ def _compile_regex(keyword): keywords = get_literals("keywords") -keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) +keyword_regexs = {kw: _compile_regex(kw) for kw in keywords} -class PrevalenceCounter(object): +class PrevalenceCounter: def __init__(self): self.keyword_counts = defaultdict(int) self.name_counts = defaultdict(int) diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py index 3c58490..e8589de 100644 --- a/pgcli/packages/prompt_utils.py +++ b/pgcli/packages/prompt_utils.py @@ -3,7 +3,7 @@ import click from .parseutils import is_destructive -def confirm_destructive_query(queries): +def confirm_destructive_query(queries, warning_level): """Check if the query is destructive and prompts the user to confirm. Returns: @@ -15,7 +15,7 @@ def confirm_destructive_query(queries): prompt_text = ( "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" ) - if is_destructive(queries) and sys.stdin.isatty(): + if is_destructive(queries, warning_level) and sys.stdin.isatty(): return prompt(prompt_text, type=bool) diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 6ef8859..6305301 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"]) Path = namedtuple("Path", []) -class SqlStatement(object): +class SqlStatement: def __init__(self, full_text, text_before_cursor): self.identifier = None self.word_before_cursor = word_before_cursor = last_word( diff --git a/pgcli/pgclirc b/pgcli/pgclirc index e97afda..15c10f5 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -23,9 +23,13 @@ multi_line = False multi_line_mode = psql # Destructive warning mode will alert you before executing a sql statement -# that may cause harm to the database such as "drop table", "drop database" -# or "shutdown". -destructive_warning = True +# that may cause harm to the database such as "drop table", "drop database", +# "shutdown", "delete", or "update". +# Possible values: +# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE +# "moderate" - skip warning on UPDATE statements, except for unconditional updates +# "off" - skip all warnings +destructive_warning = all # Enables expand mode, which is similar to `\x` in psql. expand = False @@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold' arg-toolbar.text = 'nobold' bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' -literal.string = '#ba2121' -literal.number = '#666666' -keyword = 'bold #008000' +# These three values can be used to further refine the syntax highlighting. +# They are commented out by default, since they have priority over the theme set +# with the `syntax_style` setting and overriding its behavior can be confusing. +# literal.string = '#ba2121' +# literal.number = '#666666' +# keyword = 'bold #008000' # style classes for colored table output output.header = "#00ff5f bold" diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 9c95a01..227e25c 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -83,7 +83,7 @@ class PGCompleter(Completer): reserved_words = set(get_literals("reserved")) def __init__(self, smart_completion=True, pgspecial=None, settings=None): - super(PGCompleter, self).__init__() + super().__init__() self.smart_completion = smart_completion self.pgspecial = pgspecial self.prioritizer = PrevalenceCounter() @@ -140,7 +140,7 @@ class PGCompleter(Completer): return "'{}'".format(self.unescape_name(name)) def unescape_name(self, name): - """ Unquote a string.""" + """Unquote a string.""" if name and name[0] == '"' and name[-1] == '"': name = name[1:-1] @@ -177,7 +177,7 @@ class PGCompleter(Completer): :return: """ # casing should be a dict {lowercasename:PreferredCasingName} - self.casing = dict((word.lower(), word) for word in words) + self.casing = {word.lower(): word for word in words} def extend_relations(self, data, kind): """extend metadata for tables or views. @@ -279,8 +279,8 @@ class PGCompleter(Completer): fk = ForeignKey( parentschema, parenttable, parcol, childschema, childtable, childcol ) - childcolmeta.foreignkeys.append((fk)) - parcolmeta.foreignkeys.append((fk)) + childcolmeta.foreignkeys.append(fk) + parcolmeta.foreignkeys.append(fk) def extend_datatypes(self, type_data): @@ -424,7 +424,7 @@ class PGCompleter(Completer): # the same priority as unquoted names. lexical_priority = ( tuple( - 0 if c in (" _") else -ord(c) + 0 if c in " _" else -ord(c) for c in self.unescape_name(item.lower()) ) + (1,) @@ -517,9 +517,9 @@ class PGCompleter(Completer): # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should # suggest only columns that appear in the last table and one more ltbl = tables[-1].ref - other_tbl_cols = set( + other_tbl_cols = { c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs - ) + } scoped_cols = { t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() @@ -574,7 +574,7 @@ class PGCompleter(Completer): tbls - TableReference iterable of tables already in query """ tbl = self.case(tbl) - tbls = set(normalize_ref(t.ref) for t in tbls) + tbls = {normalize_ref(t.ref) for t in tbls} if self.generate_aliases: tbl = generate_alias(self.unescape_name(tbl)) if normalize_ref(tbl) not in tbls: @@ -589,10 +589,10 @@ class PGCompleter(Completer): tbls = suggestion.table_refs cols = self.populate_scoped_cols(tbls) # Set up some data structures for efficient access - qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) - ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) - refs = set(normalize_ref(t.ref) for t in tbls) - other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) + qualified = {normalize_ref(t.ref): t.schema for t in tbls} + ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)} + refs = {normalize_ref(t.ref) for t in tbls} + other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]} joins = [] # Iterate over FKs in existing tables to find potential joins fks = ( @@ -667,7 +667,7 @@ class PGCompleter(Completer): return d # Tables that are closer to the cursor get higher prio - ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) + ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)} # Map (schema, table, col) to tables coldict = list_dict( ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref @@ -703,7 +703,11 @@ class PGCompleter(Completer): not f.is_aggregate and not f.is_window and not f.is_extension - and (f.is_public or f.schema_name == suggestion.schema) + and ( + f.is_public + or f.schema_name in self.search_path + or f.schema_name == suggestion.schema + ) ) else: @@ -721,9 +725,7 @@ class PGCompleter(Completer): # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only all_functions = self.populate_functions(suggestion.schema, filt) - funcs = set( - self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions - ) + funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions} matches = self.find_matches(word_before_cursor, funcs, meta="function") @@ -953,7 +955,7 @@ class PGCompleter(Completer): :return: {TableReference:{colname:ColumnMetaData}} """ - ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) + ctes = {normalize_ref(t.name): t.columns for t in local_tbls} columns = OrderedDict() meta = self.dbmetadata diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index d34bf26..a013b55 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,13 +1,15 @@ -import traceback import logging +import select +import traceback + +import pgspecial as special import psycopg2 -import psycopg2.extras import psycopg2.errorcodes import psycopg2.extensions as ext +import psycopg2.extras import sqlparse -import pgspecial as special -import select from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn + from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) # TODO: Get default timeout from pgclirc? _WAIT_SELECT_TIMEOUT = 1 +_wait_callback_is_set = False def _wait_select(conn): @@ -34,31 +37,41 @@ def _wait_select(conn): copy-pasted from psycopg2.extras.wait_select the default implementation doesn't define a timeout in the select calls """ - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except select.error as e: - errno = e.args[0] - if errno != 4: - raise + try: + while 1: + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + except OSError as e: + errno = e.args[0] + if errno != 4: + raise + except psycopg2.OperationalError: + pass -# When running a query, make pressing CTRL+C raise a KeyboardInterrupt -# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ -# See also https://github.com/psycopg/psycopg2/issues/468 -ext.set_wait_callback(_wait_select) +def _set_wait_callback(is_virtual_database): + global _wait_callback_is_set + if _wait_callback_is_set: + return + _wait_callback_is_set = True + if is_virtual_database: + return + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt + # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ + # See also https://github.com/psycopg/psycopg2/issues/468 + ext.set_wait_callback(_wait_select) def register_date_typecasters(connection): @@ -72,6 +85,8 @@ def register_date_typecasters(connection): cursor = connection.cursor() cursor.execute("SELECT NULL::date") + if cursor.description is None: + return date_oid = cursor.description[0][1] cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] @@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn): try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) - except psycopg2.ProgrammingError: + except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): pass return available @@ -127,7 +142,39 @@ def register_hstore_typecaster(conn): pass -class PGExecute(object): +class ProtocolSafeCursor(psycopg2.extensions.cursor): + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + def execute(self, sql, args=None): + try: + psycopg2.extensions.cursor.execute(self, sql, args) + self.protocol_error = False + self.protocol_message = "" + except psycopg2.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = ex.pgerror + _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) + + +class PGExecute: # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog @@ -190,8 +237,6 @@ class PGExecute(object): SELECT pg_catalog.pg_get_functiondef(f.f_oid) FROM f""" - version_query = "SELECT version();" - def __init__( self, database=None, @@ -203,6 +248,7 @@ class PGExecute(object): **kwargs, ): self._conn_params = {} + self._is_virtual_database = None self.conn = None self.dbname = None self.user = None @@ -214,6 +260,11 @@ class PGExecute(object): self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + def copy(self): """Returns a clone of the current executor.""" return self.__class__(**self._conn_params) @@ -250,9 +301,9 @@ class PGExecute(object): ) conn_params.update({k: v for k, v in new_params.items() if v}) + conn_params["cursor_factory"] = ProtocolSafeCursor conn = psycopg2.connect(**conn_params) - cursor = conn.cursor() conn.set_client_encoding("utf8") self._conn_params = conn_params @@ -293,16 +344,22 @@ class PGExecute(object): self.extra_args = kwargs if not self.host: - self.host = self.get_socket_directory() + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) - pid = self._select_one(cursor, "select pg_backend_pid()")[0] - self.pid = pid + self.pid = conn.get_backend_pid() self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") + self.server_version = conn.get_parameter_status("server_version") or "" + + _set_wait_callback(self.is_virtual_database()) - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + if not self.is_virtual_database(): + register_date_typecasters(conn) + register_json_typecasters(self.conn, self._json_typecaster) + register_hstore_typecaster(self.conn) @property def short_host(self): @@ -395,7 +452,13 @@ class PGExecute(object): # See https://github.com/dbcli/pgcli/issues/1014. cur = None try: - for result in pgspecial.execute(cur, sql): + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: # e.g. execute_from_file already appends these if len(result) < 7: yield result + (sql, True, True) @@ -453,6 +516,9 @@ class PGExecute(object): if cur.description: headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message else: _logger.debug("No rows in result.") return title, None, None, cur.statusmessage @@ -485,7 +551,7 @@ class PGExecute(object): try: cur.execute(sql, (spec,)) except psycopg2.ProgrammingError: - raise RuntimeError("View {} does not exist.".format(spec)) + raise RuntimeError(f"View {spec} does not exist.") result = cur.fetchone() view_type = "MATERIALIZED" if result[2] == "m" else "" return template.format(*result + (view_type,)) @@ -501,7 +567,7 @@ class PGExecute(object): result = cur.fetchone() return result[0] except psycopg2.ProgrammingError: - raise RuntimeError("Function {} does not exist.".format(spec)) + raise RuntimeError(f"Function {spec} does not exist.") def schemata(self): """Returns a list of schema names in the database""" @@ -527,21 +593,18 @@ class PGExecute(object): sql = cur.mogrify(self.tables_query, [kinds]) _logger.debug("Tables Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def tables(self): """Yields (schema_name, table_name) tuples""" - for row in self._relations(kinds=["r", "p", "f"]): - yield row + yield from self._relations(kinds=["r", "p", "f"]) def views(self): """Yields (schema_name, view_name) tuples. Includes both views and and materialized views """ - for row in self._relations(kinds=["v", "m"]): - yield row + yield from self._relations(kinds=["v", "m"]) def _columns(self, kinds=("r", "p", "f", "v", "m")): """Get column metadata for tables and views @@ -599,16 +662,13 @@ class PGExecute(object): sql = cur.mogrify(columns_query, [kinds]) _logger.debug("Columns Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def table_columns(self): - for row in self._columns(kinds=["r", "p", "f"]): - yield row + yield from self._columns(kinds=["r", "p", "f"]) def view_columns(self): - for row in self._columns(kinds=["v", "m"]): - yield row + yield from self._columns(kinds=["v", "m"]) def databases(self): with self.conn.cursor() as cur: @@ -623,6 +683,13 @@ class PGExecute(object): headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + def get_socket_directory(self): with self.conn.cursor() as cur: _logger.debug( @@ -804,8 +871,7 @@ class PGExecute(object): """ _logger.debug("Datatypes Query. sql: %r", query) cur.execute(query) - for row in cur: - yield row + yield from cur def casing(self): """Yields the most common casing for names used in db functions""" diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py index f4289a1..41f903d 100644 --- a/pgcli/pgtoolbar.py +++ b/pgcli/pgtoolbar.py @@ -1,15 +1,23 @@ +from pkg_resources import packaging + +import prompt_toolkit from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.application import get_app +parse_version = packaging.version.parse + +vi_modes = { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", +} +if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"): + vi_modes[InputMode.REPLACE_SINGLE] = "R" + def _get_vi_mode(): - return { - InputMode.INSERT: "I", - InputMode.NAVIGATION: "N", - InputMode.REPLACE: "R", - InputMode.REPLACE_SINGLE: "R", - InputMode.INSERT_MULTIPLE: "M", - }[get_app().vi_state.input_mode] + return vi_modes[get_app().vi_state.input_mode] def create_toolbar_tokens_func(pgcli): diff --git a/requirements-dev.txt b/requirements-dev.txt index 80e8f43..84fa6bf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,4 @@ pytest>=2.7.0 -mock>=1.0.1 tox>=1.9.2 behave>=1.2.4 pexpect==3.3 diff --git a/tests/conftest.py b/tests/conftest.py index 2a715b1..33cddf2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ from utils import ( import pgcli.pgexecute -@pytest.yield_fixture(scope="function") +@pytest.fixture(scope="function") def connection(): create_db("_test_db") connection = db_connection("_test_db") diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py index f57bc3b..6898394 100644 --- a/tests/features/db_utils.py +++ b/tests/features/db_utils.py @@ -44,7 +44,7 @@ def create_cn(hostname, password, username, dbname, port): host=hostname, user=username, database=dbname, password=password, port=port ) - print("Created connection: {0}.".format(cn.dsn)) + print(f"Created connection: {cn.dsn}.") return cn @@ -75,4 +75,4 @@ def close_cn(cn=None): """ if cn: cn.close() - print("Closed connection: {0}.".format(cn.dsn)) + print(f"Closed connection: {cn.dsn}.") diff --git a/tests/features/environment.py b/tests/features/environment.py index 049c2f2..215c85c 100644 --- a/tests/features/environment.py +++ b/tests/features/environment.py @@ -38,7 +38,7 @@ def before_all(context): vi = "_".join([str(x) for x in sys.version_info[:3]]) db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests") - db_name_full = "{0}_{1}".format(db_name, vi) + db_name_full = f"{db_name}_{vi}" # Store get params from config. context.conf = { @@ -63,7 +63,7 @@ def before_all(context): "import coverage", "coverage.process_startup()", "import pgcli.main", - "pgcli.main.cli()", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", ] ), ) @@ -102,6 +102,7 @@ def before_all(context): else: if "PGPASSWORD" in os.environ: del os.environ["PGPASSWORD"] + os.environ["BEHAVE_WARN"] = "moderate" context.cn = dbutils.create_db( context.conf["host"], @@ -122,12 +123,12 @@ def before_all(context): def show_env_changes(env_old, env_new): """Print out all test-specific env values.""" print("--- os.environ changed values: ---") - all_keys = set(list(env_old.keys()) + list(env_new.keys())) + all_keys = env_old.keys() | env_new.keys() for k in sorted(all_keys): old_value = env_old.get(k, "") new_value = env_new.get(k, "") if new_value and old_value != new_value: - print('{}="{}"'.format(k, new_value)) + print(f'{k}="{new_value}"') print("-" * 20) @@ -173,13 +174,13 @@ def after_scenario(context, scenario): # Quit nicely. if not context.atprompt: dbname = context.currentdb - context.cli.expect_exact("{0}> ".format(dbname), timeout=15) + context.cli.expect_exact(f"{dbname}> ", timeout=15) context.cli.sendcontrol("c") context.cli.sendcontrol("d") try: context.cli.expect_exact(pexpect.EOF, timeout=15) except pexpect.TIMEOUT: - print("--- after_scenario {}: kill cli".format(scenario.name)) + print(f"--- after_scenario {scenario.name}: kill cli") context.cli.kill(signal.SIGKILL) if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: context.tmpfile_sql_help.close() diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py index 16f123a..70b603d 100644 --- a/tests/features/fixture_utils.py +++ b/tests/features/fixture_utils.py @@ -18,7 +18,7 @@ def read_fixture_files(): """Read all files inside fixture_data directory.""" current_dir = os.path.dirname(__file__) fixture_dir = os.path.join(current_dir, "fixture_data/") - print("reading fixture data: {}".format(fixture_dir)) + print(f"reading fixture data: {fixture_dir}") fixture_dict = {} for filename in os.listdir(fixture_dir): if filename not in [".", ".."]: diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 1069662..07e9ec1 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -65,19 +65,20 @@ def step_ctrl_d(context): Send Ctrl + D to hopefully exit. """ # turn off pager before exiting - context.cli.sendline("\pset pager off") + context.cli.sendcontrol("c") + context.cli.sendline(r"\pset pager off") wrappers.wait_prompt(context) context.cli.sendcontrol("d") context.cli.expect(pexpect.EOF, timeout=15) context.exit_sent = True -@when('we send "\?" command') +@when(r'we send "\?" command') def step_send_help(context): - """ + r""" Send \? to see help. """ - context.cli.sendline("\?") + context.cli.sendline(r"\?") @when("we send partial select command") @@ -96,9 +97,9 @@ def step_see_error_message(context): @when("we send source command") def step_send_source_command(context): context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") - context.tmpfile_sql_help.write(b"\?") + context.tmpfile_sql_help.write(br"\?") context.tmpfile_sql_help.flush() - context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name)) + context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}") wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py index 3fd8b7a..3f5d0e7 100644 --- a/tests/features/steps/crud_database.py +++ b/tests/features/steps/crud_database.py @@ -14,7 +14,7 @@ def step_db_create(context): """ Send create database. """ - context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"])) + context.cli.sendline("create database {};".format(context.conf["dbname_tmp"])) context.response = {"database_name": context.conf["dbname_tmp"]} @@ -24,7 +24,7 @@ def step_db_drop(context): """ Send drop database. """ - context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"])) + context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"])) @when("we connect to test database") @@ -33,7 +33,7 @@ def step_db_connect_test(context): Send connect to database. """ db_name = context.conf["dbname"] - context.cli.sendline("\\connect {0}".format(db_name)) + context.cli.sendline(f"\\connect {db_name}") @when("we connect to dbserver") @@ -59,7 +59,7 @@ def step_see_prompt(context): Wait to see the prompt. """ db_name = getattr(context, "currentdb", context.conf["dbname"]) - wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5) + wrappers.expect_exact(context, f"{db_name}> ", timeout=5) context.atprompt = True diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py index f34fcf0..265ea39 100644 --- a/tests/features/steps/expanded.py +++ b/tests/features/steps/expanded.py @@ -31,7 +31,7 @@ def step_prepare_data(context): @when("we set expanded {mode}") def step_set_expanded(context, mode): """Set expanded to mode.""" - context.cli.sendline("\\" + "x {}".format(mode)) + context.cli.sendline("\\" + f"x {mode}") wrappers.expect_exact(context, "Expanded display is", timeout=2) wrappers.wait_prompt(context) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py index 613aeb2..a614490 100644 --- a/tests/features/steps/iocommands.py +++ b/tests/features/steps/iocommands.py @@ -13,7 +13,7 @@ def step_edit_file(context): ) if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) - context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name))) + context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name))) wrappers.expect_exact( context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 ) @@ -53,7 +53,7 @@ def step_tee_ouptut(context): ) if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) - context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name))) + context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name))) wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) wrappers.expect_exact(context, "Writing to file", timeout=5) wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) @@ -67,7 +67,7 @@ def step_query_select_123456(context): @when("we stop teeing output") def step_notee_output(context): - context.cli.sendline("\o") + context.cli.sendline(r"\o") wrappers.expect_exact(context, "Time", timeout=5) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py index 813292c..a85f371 100644 --- a/tests/features/steps/specials.py +++ b/tests/features/steps/specials.py @@ -22,5 +22,10 @@ def step_see_refresh_started(context): Wait to see refresh output. """ wrappers.expect_pager( - context, "Auto-completion refresh started in the background.\r\n", timeout=2 + context, + [ + "Auto-completion refresh started in the background.\r\n", + "Auto-completion refresh restarted.\r\n", + ], + timeout=2, ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py index e0f5a20..0ca8366 100644 --- a/tests/features/steps/wrappers.py +++ b/tests/features/steps/wrappers.py @@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout): def expect_pager(context, expected, timeout): + formatted = expected if isinstance(expected, list) else [expected] + formatted = [ + f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" + for t in formatted + ] + expect_exact( context, - "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), + formatted, timeout=timeout, ) @@ -57,7 +63,7 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None): context.cli.logfile = context.logfile context.exit_sent = False context.currentdb = currentdb or context.conf["dbname"] - context.cli.sendline("\pset pager always") + context.cli.sendline(r"\pset pager always") if prompt_check: wait_prompt(context) diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py index 51d4909..51d4909 100755..100644 --- a/tests/features/wrappager.py +++ b/tests/features/wrappager.py diff --git a/tests/metadata.py b/tests/metadata.py index 2f89ea2..4ebcccd 100644 --- a/tests/metadata.py +++ b/tests/metadata.py @@ -3,7 +3,7 @@ from itertools import product from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document -from mock import Mock +from unittest.mock import Mock import pytest parametrize = pytest.mark.parametrize @@ -59,7 +59,7 @@ def wildcard_expansion(cols, pos=-1): return Completion(cols, start_position=pos, display_meta="columns", display="*") -class MetaData(object): +class MetaData: def __init__(self, metadata): self.metadata = metadata @@ -128,7 +128,7 @@ class MetaData(object): ] def schemas(self, pos=0): - schemas = set(sch for schs in self.metadata.values() for sch in schs) + schemas = {sch for schs in self.metadata.values() for sch in schs} return [schema(escape(s), pos=pos) for s in schemas] def functions_and_keywords(self, parent="public", pos=0): diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py index 50bc889..5a375d7 100644 --- a/tests/parseutils/test_parseutils.py +++ b/tests/parseutils/test_parseutils.py @@ -1,4 +1,5 @@ import pytest +from pgcli.packages.parseutils import is_destructive from pgcli.packages.parseutils.tables import extract_tables from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote @@ -34,12 +35,12 @@ def test_simple_select_single_table_double_quoted(): def test_simple_select_multiple_tables(): tables = extract_tables("select * from abc, def") - assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} def test_simple_select_multiple_tables_double_quoted(): tables = extract_tables('select * from "Abc", "Def"') - assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)]) + assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)} def test_simple_select_single_table_deouble_quoted_aliased(): @@ -49,14 +50,12 @@ def test_simple_select_single_table_deouble_quoted_aliased(): def test_simple_select_multiple_tables_deouble_quoted_aliased(): tables = extract_tables('select * from "Abc" a, "Def" d') - assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)]) + assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)} def test_simple_select_multiple_tables_schema_qualified(): tables = extract_tables("select * from abc.def, ghi.jkl") - assert set(tables) == set( - [("abc", "def", None, False), ("ghi", "jkl", None, False)] - ) + assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)} def test_simple_select_with_cols_single_table(): @@ -71,14 +70,12 @@ def test_simple_select_with_cols_single_table_schema_qualified(): def test_simple_select_with_cols_multiple_tables(): tables = extract_tables("select a,b from abc, def") - assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} def test_simple_select_with_cols_multiple_qualified_tables(): tables = extract_tables("select a,b from abc.def, def.ghi") - assert set(tables) == set( - [("abc", "def", None, False), ("def", "ghi", None, False)] - ) + assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)} def test_select_with_hanging_comma_single_table(): @@ -88,14 +85,12 @@ def test_select_with_hanging_comma_single_table(): def test_select_with_hanging_comma_multiple_tables(): tables = extract_tables("select a, from abc, def") - assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)]) + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} def test_select_with_hanging_period_multiple_tables(): tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") - assert set(tables) == set( - [(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)] - ) + assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)} def test_simple_insert_single_table(): @@ -126,14 +121,14 @@ def test_simple_update_table_with_schema(): @pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) def test_join_table(join_type): - sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type) + sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num" tables = extract_tables(sql) - assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)]) + assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)} def test_join_table_schema_qualified(): tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") - assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)]) + assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)} def test_incomplete_join_clause(): @@ -177,25 +172,25 @@ def test_extract_no_tables(text): @pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_function_as_table(arg_list): - tables = extract_tables("SELECT * FROM foo({0})".format(arg_list)) + tables = extract_tables(f"SELECT * FROM foo({arg_list})") assert tables == ((None, "foo", None, True),) @pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_schema_qualified_function_as_table(arg_list): - tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list)) + tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})") assert tables == (("foo", "bar", None, True),) @pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) def test_simple_aliased_function_as_table(arg_list): - tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list)) + tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar") assert tables == ((None, "foo", "bar", True),) def test_simple_table_and_function(): tables = extract_tables("SELECT * FROM foo JOIN bar()") - assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)]) + assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)} def test_complex_table_and_function(): @@ -203,9 +198,7 @@ def test_complex_table_and_function(): """SELECT * FROM foo.bar baz JOIN bar.qux(x, y, z) quux""" ) - assert set(tables) == set( - [("foo", "bar", "baz", False), ("bar", "qux", "quux", True)] - ) + assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)} def test_find_prev_keyword_using(): @@ -267,3 +260,21 @@ def test_is_open_quote__closed(sql): ) def test_is_open_quote__open(sql): assert is_open_quote(sql) + + +@pytest.mark.parametrize( + ("sql", "warning_level", "expected"), + [ + ("update abc set x = 1", "all", True), + ("update abc set x = 1 where y = 2", "all", True), + ("update abc set x = 1", "moderate", True), + ("update abc set x = 1 where y = 2", "moderate", False), + ("select x, y, z from abc", "all", False), + ("drop abc", "all", True), + ("alter abc", "all", True), + ("delete abc", "all", True), + ("truncate abc", "all", True), + ], +) +def test_is_destructive(sql, warning_level, expected): + assert is_destructive(sql, warning_level=warning_level) == expected diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index 6a916a8..a5529d6 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -1,6 +1,6 @@ import time import pytest -from mock import Mock, patch +from unittest.mock import Mock, patch @pytest.fixture @@ -37,7 +37,7 @@ def test_refresh_called_once(refresher): :return: """ callbacks = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) special = Mock() with patch.object(refresher, "_bg_refresh") as bg_refresh: @@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher): """ callbacks = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) special = Mock() def dummy_bg_refresh(*args): @@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher): :param refresher: """ callbacks = [Mock()] - pgexecute_class = Mock() - pgexecute = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) pgexecute.extra_args = {} special = Mock() - with patch("pgcli.completion_refresher.PGExecute", pgexecute_class): - # Set refreshers to 0: we're not testing refresh logic here - refresher.refreshers = {} - refresher.refresh(pgexecute, special, callbacks) - time.sleep(1) # Wait for the thread to work. - assert callbacks[0].call_count == 1 + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_config.py b/tests/test_config.py index 1c023e0..08fe74e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,10 @@ +import io import os import stat import pytest -from pgcli.config import ensure_dir_exists +from pgcli.config import ensure_dir_exists, skip_initial_comment def test_ensure_file_parent(tmpdir): @@ -20,11 +21,23 @@ def test_ensure_existing_dir(tmpdir): def test_ensure_other_create_error(tmpdir): - subdir = tmpdir.join("subdir") + subdir = tmpdir.join('subdir"') rcfile = subdir.join("rcfile") - # trigger an oserror that isn't "directory already exists" + # trigger an oserror that isn't "directory already exists" os.chmod(str(tmpdir), stat.S_IREAD) with pytest.raises(OSError): ensure_dir_exists(str(rcfile)) + + +@pytest.mark.parametrize( + "text, skipped_lines", + ( + ("abc\n", 1), + ("#[section]\ndef\n[section]", 2), + ("[section]", 0), + ), +) +def test_skip_initial_comment(text, skipped_lines): + assert skip_initial_comment(io.StringIO(text)) == skipped_lines diff --git a/tests/test_main.py b/tests/test_main.py index 9b85a34..c48accb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,6 @@ import os import platform -import mock +from unittest import mock import pytest @@ -288,7 +288,12 @@ def test_pg_service_file(tmpdir): cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: service_conf.write( - """[myservice] + """File begins with a comment + that is not a comment + # or maybe a comment after all + because psql is crazy + + [myservice] host=a_host user=a_user port=5433 diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py index a6c80a7..5b93661 100644 --- a/tests/test_naive_completion.py +++ b/tests/test_naive_completion.py @@ -13,7 +13,7 @@ def completer(): @pytest.fixture def complete_event(): - from mock import Mock + from unittest.mock import Mock return Mock() diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 9273be9..109674c 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -2,7 +2,7 @@ from textwrap import dedent import psycopg2 import pytest -from mock import patch, MagicMock +from unittest.mock import patch, MagicMock from pgspecial.main import PGSpecial, NO_QUERY from utils import run, dbtest, requires_json, requires_jsonb @@ -89,7 +89,7 @@ def test_expanded_slash_G(executor, pgspecial): # Tests whether we reset the expanded output after a \G. run(executor, """create table test(a boolean)""") run(executor, """insert into test values(True)""") - results = run(executor, """select * from test \G""", pgspecial=pgspecial) + results = run(executor, r"""select * from test \G""", pgspecial=pgspecial) assert pgspecial.expanded_output == False @@ -105,31 +105,35 @@ def test_schemata_table_views_and_columns_query(executor): # schemata # don't enforce all members of the schemas since they may include postgres # temporary schemas - assert set(executor.schemata()) >= set( - ["public", "pg_catalog", "information_schema", "schema1", "schema2"] - ) + assert set(executor.schemata()) >= { + "public", + "pg_catalog", + "information_schema", + "schema1", + "schema2", + } assert executor.search_path() == ["pg_catalog", "public"] # tables - assert set(executor.tables()) >= set( - [("public", "a"), ("public", "b"), ("schema1", "c")] - ) - - assert set(executor.table_columns()) >= set( - [ - ("public", "a", "x", "text", False, None), - ("public", "a", "y", "text", False, None), - ("public", "b", "z", "text", False, None), - ("schema1", "c", "w", "text", True, "'meow'::text"), - ] - ) + assert set(executor.tables()) >= { + ("public", "a"), + ("public", "b"), + ("schema1", "c"), + } + + assert set(executor.table_columns()) >= { + ("public", "a", "x", "text", False, None), + ("public", "a", "y", "text", False, None), + ("public", "b", "z", "text", False, None), + ("schema1", "c", "w", "text", True, "'meow'::text"), + } # views - assert set(executor.views()) >= set([("public", "d")]) + assert set(executor.views()) >= {("public", "d")} - assert set(executor.view_columns()) >= set( - [("public", "d", "e", "integer", False, None)] - ) + assert set(executor.view_columns()) >= { + ("public", "d", "e", "integer", False, None) + } @dbtest @@ -142,9 +146,9 @@ def test_foreign_key_query(executor): "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", ) - assert set(executor.foreignkeys()) >= set( - [("schema1", "parent", "parentid", "schema2", "child", "motherid")] - ) + assert set(executor.foreignkeys()) >= { + ("schema1", "parent", "parentid", "schema2", "child", "motherid") + } @dbtest @@ -175,30 +179,28 @@ def test_functions_query(executor): ) funcs = set(executor.functions()) - assert funcs >= set( - [ - function_meta_data(func_name="func1", return_type="integer"), - function_meta_data( - func_name="func3", - arg_names=["x", "y"], - arg_types=["integer", "integer"], - arg_modes=["t", "t"], - return_type="record", - is_set_returning=True, - ), - function_meta_data( - schema_name="public", - func_name="func4", - arg_names=("x",), - arg_types=("integer",), - return_type="integer", - is_set_returning=True, - ), - function_meta_data( - schema_name="schema1", func_name="func2", return_type="integer" - ), - ] - ) + assert funcs >= { + function_meta_data(func_name="func1", return_type="integer"), + function_meta_data( + func_name="func3", + arg_names=["x", "y"], + arg_types=["integer", "integer"], + arg_modes=["t", "t"], + return_type="record", + is_set_returning=True, + ), + function_meta_data( + schema_name="public", + func_name="func4", + arg_names=("x",), + arg_types=("integer",), + return_type="integer", + is_set_returning=True, + ), + function_meta_data( + schema_name="schema1", func_name="func2", return_type="integer" + ), + } @dbtest @@ -257,8 +259,8 @@ def test_not_is_special(executor, pgspecial): @dbtest def test_execute_from_file_no_arg(executor, pgspecial): - """\i without a filename returns an error.""" - result = list(executor.run("\i", pgspecial=pgspecial)) + r"""\i without a filename returns an error.""" + result = list(executor.run(r"\i", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] assert "missing required argument" in status assert success == False @@ -268,12 +270,12 @@ def test_execute_from_file_no_arg(executor, pgspecial): @dbtest @patch("pgcli.main.os") def test_execute_from_file_io_error(os, executor, pgspecial): - """\i with an io_error returns an error.""" - # Inject an IOError. - os.path.expanduser.side_effect = IOError("test") + r"""\i with an os_error returns an error.""" + # Inject an OSError. + os.path.expanduser.side_effect = OSError("test") # Check the result. - result = list(executor.run("\i test", pgspecial=pgspecial)) + result = list(executor.run(r"\i test", pgspecial=pgspecial)) status, sql, success, is_special = result[0][3:] assert status == "test" assert success == False @@ -290,7 +292,7 @@ def test_multiple_queries_same_line(executor): @dbtest def test_multiple_queries_with_special_command_same_line(executor, pgspecial): - result = run(executor, "select 'foo'; \d", pgspecial=pgspecial) + result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial) assert len(result) == 11 # 2 * (output+status) * 3 lines assert "foo" in result[3] # This is a lame check. :( @@ -408,7 +410,7 @@ def test_date_time_types(executor): @pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"]) def test_large_numbers_render_directly(executor, value): run(executor, "create table numbertest(a numeric)") - run(executor, "insert into numbertest (a) values ({0})".format(value)) + run(executor, f"insert into numbertest (a) values ({value})") assert value in run(executor, "select * from numbertest", join=True) @@ -511,13 +513,28 @@ def test_short_host(executor): assert executor.short_host == "localhost1" -class BrokenConnection(object): +class BrokenConnection: """Mock a connection that failed.""" def cursor(self): raise psycopg2.InterfaceError("I'm broken!") +class VirtualCursor: + """Mock a cursor to virtual database like pgbouncer.""" + + def __init__(self): + self.protocol_error = False + self.protocol_message = "" + self.description = None + self.status = None + self.statusmessage = "Error" + + def execute(self, *args, **kwargs): + self.protocol_error = True + self.protocol_message = "Command not supported" + + @dbtest def test_exit_without_active_connection(executor): quit_handler = MagicMock() @@ -540,3 +557,12 @@ def test_exit_without_active_connection(executor): # an exception should be raised when running a query without active connection with pytest.raises(psycopg2.InterfaceError): run(executor, "select 1", pgspecial=pgspecial) + + +@dbtest +def test_virtual_database(executor): + virtual_connection = MagicMock() + virtual_connection.cursor.return_value = VirtualCursor() + with patch.object(executor, "conn", virtual_connection): + result = run(executor, "select 1") + assert "Command not supported" in result diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py index eaeaf12..cd99e32 100644 --- a/tests/test_pgspecial.py +++ b/tests/test_pgspecial.py @@ -13,12 +13,12 @@ from pgcli.packages.sqlcompletion import ( def test_slash_suggests_special(): suggestions = suggest_type("\\", "\\") - assert set(suggestions) == set([Special()]) + assert set(suggestions) == {Special()} def test_slash_d_suggests_special(): suggestions = suggest_type("\\d", "\\d") - assert set(suggestions) == set([Special()]) + assert set(suggestions) == {Special()} def test_dn_suggests_schemata(): @@ -30,24 +30,24 @@ def test_dn_suggests_schemata(): def test_d_suggests_tables_views_and_schemas(): - suggestions = suggest_type("\d ", "\d ") - assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)]) + suggestions = suggest_type(r"\d ", r"\d ") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} - suggestions = suggest_type("\d xxx", "\d xxx") - assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)]) + suggestions = suggest_type(r"\d xxx", r"\d xxx") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} def test_d_dot_suggests_schema_qualified_tables_or_views(): - suggestions = suggest_type("\d myschema.", "\d myschema.") - assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")]) + suggestions = suggest_type(r"\d myschema.", r"\d myschema.") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} - suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx") - assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")]) + suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} def test_df_suggests_schema_or_function(): suggestions = suggest_type("\\df xxx", "\\df xxx") - assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()]) + assert set(suggestions) == {Function(schema=None, usage="special"), Schema()} suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx") assert suggestions == (Function(schema="myschema", usage="special"),) @@ -63,7 +63,7 @@ def test_leading_whitespace_ok(): def test_dT_suggests_schema_or_datatypes(): text = "\\dT " suggestions = suggest_type(text, text) - assert set(suggestions) == set([Schema(), Datatype(schema=None)]) + assert set(suggestions) == {Schema(), Datatype(schema=None)} def test_schema_qualified_dT_suggests_datatypes(): diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py index c1f8a16..a8a3a1e 100644 --- a/tests/test_prompt_utils.py +++ b/tests/test_prompt_utils.py @@ -7,4 +7,4 @@ def test_confirm_destructive_query_notty(): stdin = click.get_text_stream("stdin") if not stdin.isatty(): sql = "drop database foo;" - assert confirm_destructive_query(sql) is None + assert confirm_destructive_query(sql, "all") is None diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py index e76ea04..947fc80 100644 --- a/tests/test_rowlimit.py +++ b/tests/test_rowlimit.py @@ -1,5 +1,5 @@ import pytest -from mock import Mock +from unittest.mock import Mock from pgcli.main import PGCli diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py index 805b727..5c9c9af 100644 --- a/tests/test_smart_completion_multiple_schemata.py +++ b/tests/test_smart_completion_multiple_schemata.py @@ -193,7 +193,7 @@ def test_suggested_joins(completer, query, tbl): result = get_result(completer, query.format(tbl)) assert completions_to_set(result) == completions_to_set( testdata.schemas_and_from_clause_items() - + [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))] + + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] ) @@ -350,6 +350,36 @@ def test_schema_qualified_function_name(completer): ) +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_schema_qualified_function_name_after_from(completer): + text = "SELECT * FROM custom.set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_not_returned(completer): + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_in_search_path(completer): + completer.search_path = ["public", "custom"] + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + @parametrize("completer", completers(filtr=True, casing=False)) @parametrize( "text", diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py index b935709..db1fe0a 100644 --- a/tests/test_smart_completion_public_schema_only.py +++ b/tests/test_smart_completion_public_schema_only.py @@ -53,7 +53,7 @@ metadata = { ], } -metadata = dict((k, {"public": v}) for k, v in metadata.items()) +metadata = {k: {"public": v} for k, v in metadata.items()} testdata = MetaData(metadata) @@ -296,7 +296,7 @@ def test_suggested_cased_always_qualified_column_names(completer): def test_suggested_column_names_in_function(completer): result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) assert completions_to_set(result) == completions_to_set( - (testdata.columns_functions_and_keywords("users")) + testdata.columns_functions_and_keywords("users") ) @@ -316,7 +316,7 @@ def test_suggested_column_names_with_alias(completer): def test_suggested_multiple_column_names(completer): result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) assert completions_to_set(result) == completions_to_set( - (testdata.columns_functions_and_keywords("users")) + testdata.columns_functions_and_keywords("users") ) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py index 3cbad0a..744fadb 100644 --- a/tests/test_sqlcompletion.py +++ b/tests/test_sqlcompletion.py @@ -23,16 +23,14 @@ def cols_etc( ): """Returns the expected select-clause suggestions for a single-table select.""" - return set( - [ - Column( - table_refs=(TableReference(schema, table, alias, is_function),), - qualifiable=True, - ), - Function(schema=parent), - Keyword(last_keyword), - ] - ) + return { + Column( + table_refs=(TableReference(schema, table, alias, is_function),), + qualifiable=True, + ), + Function(schema=parent), + Keyword(last_keyword), + } def test_select_suggests_cols_with_visible_table_scope(): @@ -103,24 +101,20 @@ def test_where_equals_any_suggests_columns_or_keywords(): def test_lparen_suggests_cols_and_funcs(): suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") - assert set(suggestion) == set( - [ - Column(table_refs=((None, "tbl", None, False),), qualifiable=True), - Function(schema=None), - Keyword("("), - ] - ) + assert set(suggestion) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("("), + } def test_select_suggests_cols_and_funcs(): suggestions = suggest_type("SELECT ", "SELECT ") - assert set(suggestions) == set( - [ - Column(table_refs=(), qualifiable=True), - Function(schema=None), - Keyword("SELECT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } @pytest.mark.parametrize( @@ -128,13 +122,13 @@ def test_select_suggests_cols_and_funcs(): ) def test_suggests_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()]) + assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()} @pytest.mark.parametrize("expression", ["SELECT * FROM "]) def test_suggest_tables_views_schemas_and_functions(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} @pytest.mark.parametrize( @@ -147,9 +141,11 @@ def test_suggest_tables_views_schemas_and_functions(expression): def test_suggest_after_join_with_two_tables(expression): suggestions = suggest_type(expression, expression) tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) - assert set(suggestions) == set( - [FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()] - ) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(tables, None), + Schema(), + } @pytest.mark.parametrize( @@ -158,13 +154,11 @@ def test_suggest_after_join_with_two_tables(expression): def test_suggest_after_join_with_one_table(expression): suggestions = suggest_type(expression, expression) tables = ((None, "foo", None, False),) - assert set(suggestions) == set( - [ - FromClauseItem(schema=None, table_refs=tables), - Join(((None, "foo", None, False),), None), - Schema(), - ] - ) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(((None, "foo", None, False),), None), + Schema(), + } @pytest.mark.parametrize( @@ -172,13 +166,13 @@ def test_suggest_after_join_with_one_table(expression): ) def test_suggest_qualified_tables_and_views(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")]) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} @pytest.mark.parametrize("expression", ["UPDATE sch."]) def test_suggest_qualified_aliasable_tables_and_views(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")]) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} @pytest.mark.parametrize( @@ -193,26 +187,27 @@ def test_suggest_qualified_aliasable_tables_and_views(expression): ) def test_suggest_qualified_tables_views_and_functions(expression): suggestions = suggest_type(expression, expression) - assert set(suggestions) == set([FromClauseItem(schema="sch")]) + assert set(suggestions) == {FromClauseItem(schema="sch")} @pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) def test_suggest_qualified_tables_views_functions_and_joins(expression): suggestions = suggest_type(expression, expression) tbls = tuple([(None, "foo", None, False)]) - assert set(suggestions) == set( - [FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")] - ) + assert set(suggestions) == { + FromClauseItem(schema="sch", table_refs=tbls), + Join(tbls, "sch"), + } def test_truncate_suggests_tables_and_schemas(): suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") - assert set(suggestions) == set([Table(schema=None), Schema()]) + assert set(suggestions) == {Table(schema=None), Schema()} def test_truncate_suggests_qualified_tables(): suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") - assert set(suggestions) == set([Table(schema="sch")]) + assert set(suggestions) == {Table(schema="sch")} @pytest.mark.parametrize( @@ -220,13 +215,11 @@ def test_truncate_suggests_qualified_tables(): ) def test_distinct_suggests_cols(text): suggestions = suggest_type(text, text) - assert set(suggestions) == set( - [ - Column(table_refs=(), local_tables=(), qualifiable=True), - Function(schema=None), - Keyword("DISTINCT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=(), local_tables=(), qualifiable=True), + Function(schema=None), + Keyword("DISTINCT"), + } @pytest.mark.parametrize( @@ -244,20 +237,18 @@ def test_distinct_and_order_by_suggestions_with_aliases( text, text_before, last_keyword ): suggestions = suggest_type(text, text_before) - assert set(suggestions) == set( - [ - Column( - table_refs=( - TableReference(None, "tbl", "x", False), - TableReference(None, "tbl1", "y", False), - ), - local_tables=(), - qualifiable=True, + assert set(suggestions) == { + Column( + table_refs=( + TableReference(None, "tbl", "x", False), + TableReference(None, "tbl1", "y", False), ), - Function(schema=None), - Keyword(last_keyword), - ] - ) + local_tables=(), + qualifiable=True, + ), + Function(schema=None), + Keyword(last_keyword), + } @pytest.mark.parametrize( @@ -272,56 +263,50 @@ def test_distinct_and_order_by_suggestions_with_aliases( ) def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before): suggestions = suggest_type(text, text_before) - assert set(suggestions) == set( - [ - Column( - table_refs=(TableReference(None, "tbl", "x", False),), - local_tables=(), - qualifiable=False, - ), - Table(schema="x"), - View(schema="x"), - Function(schema="x"), - ] - ) + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } def test_function_arguments_with_alias_given(): suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.") - assert set(suggestions) == set( - [ - Column( - table_refs=(TableReference(None, "tbl", "x", False),), - local_tables=(), - qualifiable=False, - ), - Table(schema="x"), - View(schema="x"), - Function(schema="x"), - ] - ) + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } def test_col_comma_suggests_cols(): suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") - assert set(suggestions) == set( - [ - Column(table_refs=((None, "tbl", None, False),), qualifiable=True), - Function(schema=None), - Keyword("SELECT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } def test_table_comma_suggests_tables_and_schemas(): suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} def test_into_suggests_tables_and_schemas(): suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") - assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()]) + assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()} @pytest.mark.parametrize( @@ -357,14 +342,12 @@ def test_partially_typed_col_name_suggests_col_names(): def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") - assert set(suggestions) == set( - [ - Column(table_refs=((None, "tabl", None, False),)), - Table(schema="tabl"), - View(schema="tabl"), - Function(schema="tabl"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl", None, False),)), + Table(schema="tabl"), + View(schema="tabl"), + Function(schema="tabl"), + } @pytest.mark.parametrize( @@ -378,14 +361,12 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): ) def test_dot_suggests_cols_of_an_alias(sql): suggestions = suggest_type(sql, "SELECT t1.") - assert set(suggestions) == set( - [ - Table(schema="t1"), - View(schema="t1"), - Column(table_refs=((None, "tabl1", "t1", False),)), - Function(schema="t1"), - ] - ) + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } @pytest.mark.parametrize( @@ -399,28 +380,24 @@ def test_dot_suggests_cols_of_an_alias(sql): ) def test_dot_suggests_cols_of_an_alias_where(sql): suggestions = suggest_type(sql, sql) - assert set(suggestions) == set( - [ - Table(schema="t1"), - View(schema="t1"), - Column(table_refs=((None, "tabl1", "t1", False),)), - Function(schema="t1"), - ] - ) + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): suggestions = suggest_type( "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." ) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "tabl2", "t2", False),)), - Table(schema="t2"), - View(schema="t2"), - Function(schema="t2"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl2", "t2", False),)), + Table(schema="t2"), + View(schema="t2"), + Function(schema="t2"), + } @pytest.mark.parametrize( @@ -452,20 +429,18 @@ def test_sub_select_partial_text_suggests_keyword(expression): def test_outer_table_reference_in_exists_subquery_suggests_columns(): q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." suggestions = suggest_type(q, q) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "foo", "f", False),)), - Table(schema="f"), - View(schema="f"), - Function(schema="f"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "foo", "f", False),)), + Table(schema="f"), + View(schema="f"), + Function(schema="f"), + } @pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "]) def test_sub_select_table_name_completion(expression): suggestion = suggest_type(expression, expression) - assert set(suggestion) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestion) == {FromClauseItem(schema=None), Schema()} @pytest.mark.parametrize( @@ -478,22 +453,18 @@ def test_sub_select_table_name_completion(expression): def test_sub_select_table_name_completion_with_outer_table(expression): suggestion = suggest_type(expression, expression) tbls = tuple([(None, "foo", None, False)]) - assert set(suggestion) == set( - [FromClauseItem(schema=None, table_refs=tbls), Schema()] - ) + assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} def test_sub_select_col_name_completion(): suggestions = suggest_type( "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " ) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "abc", None, False),), qualifiable=True), - Function(schema=None), - Keyword("SELECT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "abc", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } @pytest.mark.xfail @@ -508,25 +479,25 @@ def test_sub_select_dot_col_name_completion(): suggestions = suggest_type( "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." ) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "tabl", "t", False),)), - Table(schema="t"), - View(schema="t"), - Function(schema="t"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl", "t", False),)), + Table(schema="t"), + View(schema="t"), + Function(schema="t"), + } @pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER")) @pytest.mark.parametrize("tbl_alias", ("", "foo")) def test_join_suggests_tables_and_schemas(tbl_alias, join_type): - text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type) + text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " suggestion = suggest_type(text, text) tbls = tuple([(None, "abc", tbl_alias or None, False)]) - assert set(suggestion) == set( - [FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)] - ) + assert set(suggestion) == { + FromClauseItem(schema=None, table_refs=tbls), + Schema(), + Join(tbls, None), + } def test_left_join_with_comma(): @@ -535,9 +506,7 @@ def test_left_join_with_comma(): # tbls should also include (None, 'bar', 'b', False) # but there's a bug with commas tbls = tuple([(None, "foo", "f", False)]) - assert set(suggestions) == set( - [FromClauseItem(schema=None, table_refs=tbls), Schema()] - ) + assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} @pytest.mark.parametrize( @@ -550,15 +519,13 @@ def test_left_join_with_comma(): def test_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) tables = ((None, "abc", "a", False), (None, "def", "d", False)) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "abc", "a", False),)), - Table(schema="a"), - View(schema="a"), - Function(schema="a"), - JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "abc", "a", False),)), + Table(schema="a"), + View(schema="a"), + Function(schema="a"), + JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), + } @pytest.mark.parametrize( @@ -570,14 +537,12 @@ def test_join_alias_dot_suggests_cols1(sql): ) def test_join_alias_dot_suggests_cols2(sql): suggestion = suggest_type(sql, sql) - assert set(suggestion) == set( - [ - Column(table_refs=((None, "def", "d", False),)), - Table(schema="d"), - View(schema="d"), - Function(schema="d"), - ] - ) + assert set(suggestion) == { + Column(table_refs=((None, "def", "d", False),)), + Table(schema="d"), + View(schema="d"), + Function(schema="d"), + } @pytest.mark.parametrize( @@ -598,9 +563,10 @@ on """, def test_on_suggests_aliases_and_join_conditions(sql): suggestions = suggest_type(sql, sql) tables = ((None, "abc", "a", False), (None, "bcd", "b", False)) - assert set(suggestions) == set( - (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b"))) - ) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("a", "b")), + } @pytest.mark.parametrize( @@ -613,9 +579,10 @@ def test_on_suggests_aliases_and_join_conditions(sql): def test_on_suggests_tables_and_join_conditions(sql): suggestions = suggest_type(sql, sql) tables = ((None, "abc", None, False), (None, "bcd", None, False)) - assert set(suggestions) == set( - (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd"))) - ) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } @pytest.mark.parametrize( @@ -640,9 +607,10 @@ def test_on_suggests_aliases_right_side(sql): def test_on_suggests_tables_and_join_conditions_right_side(sql): suggestions = suggest_type(sql, sql) tables = ((None, "abc", None, False), (None, "bcd", None, False)) - assert set(suggestions) == set( - (JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd"))) - ) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } @pytest.mark.parametrize( @@ -659,9 +627,9 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql): ) def test_join_using_suggests_common_columns(text): tables = ((None, "abc", None, False), (None, "def", None, False)) - assert set(suggest_type(text, text)) == set( - [Column(table_refs=tables, require_last_table=True)] - ) + assert set(suggest_type(text, text)) == { + Column(table_refs=tables, require_last_table=True) + } def test_suggest_columns_after_multiple_joins(): @@ -678,29 +646,27 @@ def test_2_statements_2nd_current(): suggestions = suggest_type( "select * from a; select * from ", "select * from a; select * from " ) - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} suggestions = suggest_type( "select * from a; select from b", "select * from a; select " ) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "b", None, False),), qualifiable=True), - Function(schema=None), - Keyword("SELECT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "b", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } # Should work even if first statement is invalid suggestions = suggest_type( "select * from; select * from ", "select * from; select * from " ) - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} def test_2_statements_1st_current(): suggestions = suggest_type("select * from ; select * from b", "select * from ") - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} suggestions = suggest_type("select from a; select * from b", "select ") assert set(suggestions) == cols_etc("a", last_keyword="SELECT") @@ -711,7 +677,7 @@ def test_3_statements_2nd_current(): "select * from a; select * from ; select * from c", "select * from a; select * from ", ) - assert set(suggestions) == set([FromClauseItem(schema=None), Schema()]) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} suggestions = suggest_type( "select * from a; select from b; select * from c", "select * from a; select " @@ -768,13 +734,11 @@ SELECT * FROM qux; ) def test_statements_in_function_body(text): suggestions = suggest_type(text, text[: text.find(" ") + 1]) - assert set(suggestions) == set( - [ - Column(table_refs=((None, "foo", None, False),), qualifiable=True), - Function(schema=None), - Keyword("SELECT"), - ] - ) + assert set(suggestions) == { + Column(table_refs=((None, "foo", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } functions = [ @@ -799,13 +763,13 @@ SELECT 1 FROM foo; @pytest.mark.parametrize("text", functions) def test_statements_with_cursor_after_function_body(text): suggestions = suggest_type(text, text[: text.find("; ") + 1]) - assert set(suggestions) == set([Keyword(), Special()]) + assert set(suggestions) == {Keyword(), Special()} @pytest.mark.parametrize("text", functions) def test_statements_with_cursor_before_function_body(text): suggestions = suggest_type(text, "") - assert set(suggestions) == set([Keyword(), Special()]) + assert set(suggestions) == {Keyword(), Special()} def test_create_db_with_template(): @@ -813,14 +777,14 @@ def test_create_db_with_template(): "create database foo with template ", "create database foo with template " ) - assert set(suggestions) == set((Database(),)) + assert set(suggestions) == {Database()} @pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n")) def test_specials_included_for_initial_completion(initial_text): suggestions = suggest_type(initial_text, initial_text) - assert set(suggestions) == set([Keyword(), Special()]) + assert set(suggestions) == {Keyword(), Special()} def test_drop_schema_qualified_table_suggests_only_tables(): @@ -843,25 +807,30 @@ def test_drop_schema_suggests_schemas(): @pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"]) def test_cast_operator_suggests_types(text): - assert set(suggest_type(text, text)) == set( - [Datatype(schema=None), Table(schema=None), Schema()] - ) + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } @pytest.mark.parametrize( "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] ) def test_cast_operator_suggests_schema_qualified_types(text): - assert set(suggest_type(text, text)) == set( - [Datatype(schema="bar"), Table(schema="bar")] - ) + assert set(suggest_type(text, text)) == { + Datatype(schema="bar"), + Table(schema="bar"), + } def test_alter_column_type_suggests_types(): q = "ALTER TABLE foo ALTER COLUMN bar TYPE " - assert set(suggest_type(q, q)) == set( - [Datatype(schema=None), Table(schema=None), Schema()] - ) + assert set(suggest_type(q, q)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } @pytest.mark.parametrize( @@ -880,9 +849,11 @@ def test_alter_column_type_suggests_types(): ], ) def test_identifier_suggests_types_in_parentheses(text): - assert set(suggest_type(text, text)) == set( - [Datatype(schema=None), Table(schema=None), Schema()] - ) + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } @pytest.mark.parametrize( @@ -977,7 +948,7 @@ def test_ignore_leading_double_quotes(sql): ) def test_column_keyword_suggests_columns(sql): suggestions = suggest_type(sql, sql) - assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))]) + assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))} def test_handle_unrecognized_kw_generously(): diff --git a/tests/utils.py b/tests/utils.py index 2427c30..460ea46 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ from os import getenv POSTGRES_USER = getenv("PGUSER", "postgres") POSTGRES_HOST = getenv("PGHOST", "localhost") POSTGRES_PORT = getenv("PGPORT", 5432) -POSTGRES_PASSWORD = getenv("PGPASSWORD", "") +POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres") def db_connection(dbname=None): @@ -73,7 +73,7 @@ def drop_tables(conn): def run( executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None ): - " Return string output for the sql to be run " + "Return string output for the sql to be run" results = executor.run(sql, pgspecial, exception_formatter) formatted = [] @@ -89,7 +89,7 @@ def run( def completions_to_set(completions): - return set( + return { (completion.display_text, completion.display_meta_text) for completion in completions - ) + } |