From 76d27bc43d56d7ef3ca0090fb199777888adf7c3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 6 Sep 2021 06:17:09 +0200 Subject: Adding upstream version 3.2.0. Signed-off-by: Daniel Baumann --- pgcli/__init__.py | 2 +- pgcli/completion_refresher.py | 9 +- pgcli/config.py | 45 +++++++- pgcli/magic.py | 8 +- pgcli/main.py | 93 ++++++++++------ pgcli/packages/parseutils/__init__.py | 32 ++++-- pgcli/packages/parseutils/meta.py | 2 +- pgcli/packages/parseutils/tables.py | 3 +- pgcli/packages/pgliterals/pgliterals.json | 1 + pgcli/packages/prioritization.py | 4 +- pgcli/packages/prompt_utils.py | 4 +- pgcli/packages/sqlcompletion.py | 2 +- pgcli/pgclirc | 19 +++- pgcli/pgcompleter.py | 40 +++---- pgcli/pgexecute.py | 178 ++++++++++++++++++++---------- pgcli/pgtoolbar.py | 22 ++-- 16 files changed, 311 insertions(+), 153 deletions(-) (limited to 'pgcli') diff --git a/pgcli/__init__.py b/pgcli/__init__.py index f5f41e5..1173108 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "3.2.0" diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index cf0879f..1039d51 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -3,10 +3,9 @@ import os from collections import OrderedDict from .pgcompleter import PGCompleter -from .pgexecute import PGExecute -class CompletionRefresher(object): +class CompletionRefresher: refreshers = OrderedDict() @@ -27,6 +26,10 @@ class CompletionRefresher(object): has completed the refresh. The newly created completion object will be passed in as an argument to each callback. """ + if executor.is_virtual_database(): + # do nothing + return [(None, None, None, "Auto-completion refresh can't be started.")] + if self.is_refreshing(): self._restart_refresh.set() return [(None, None, None, "Auto-completion refresh restarted.")] @@ -141,7 +144,7 @@ def refresh_casing(completer, executor): with open(casing_file, "w") as f: f.write(casing_prefs) if os.path.isfile(casing_file): - with open(casing_file, "r") as f: + with open(casing_file) as f: completer.extend_casing([line.strip() for line in f]) diff --git a/pgcli/config.py b/pgcli/config.py index 0fc42dd..22f08dc 100644 --- a/pgcli/config.py +++ b/pgcli/config.py @@ -3,6 +3,8 @@ import shutil import os import platform from os.path import expanduser, exists, dirname +import re +from typing import TextIO from configobj import ConfigObj @@ -16,11 +18,15 @@ def config_location(): def load_config(usr_cfg, def_cfg=None): - cfg = ConfigObj() - cfg.merge(ConfigObj(def_cfg, interpolation=False)) - cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + # avoid config merges when possible. For writing, we need an umerged config instance. 
+ # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171 + if def_cfg: + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + else: + cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8") cfg.filename = expanduser(usr_cfg) - return cfg @@ -44,12 +50,16 @@ def upgrade_config(config, def_config): cfg.write() +def get_config_filename(pgclirc_file=None): + return pgclirc_file or "%sconfig" % config_location() + + def get_config(pgclirc_file=None): from pgcli import __file__ as package_root package_root = os.path.dirname(package_root) - pgclirc_file = pgclirc_file or "%sconfig" % config_location() + pgclirc_file = get_config_filename(pgclirc_file) default_config = os.path.join(package_root, "pgclirc") write_default_config(default_config, pgclirc_file) @@ -62,3 +72,28 @@ def get_casing_file(config): if casing_file == "default": casing_file = config_location() + "casing" return casing_file + + +def skip_initial_comment(f_stream: TextIO) -> int: + """ + Initial comment in ~/.pg_service.conf is not always marked with '#' + which crashes the parser. This function takes a file object and + "rewinds" it to the beginning of the first section, + from where on it can be parsed safely + + :return: number of skipped lines + """ + section_regex = r"\s*\[" + pos = f_stream.tell() + lines_skipped = 0 + while True: + line = f_stream.readline() + if line == "": + break + if re.match(section_regex, line) is not None: + f_stream.seek(pos) + break + else: + pos += len(line) + lines_skipped += 1 + return lines_skipped diff --git a/pgcli/magic.py b/pgcli/magic.py index f58f415..6e58f28 100644 --- a/pgcli/magic.py +++ b/pgcli/magic.py @@ -25,7 +25,11 @@ def pgcli_line_magic(line): if hasattr(sql.connection.Connection, "get"): conn = sql.connection.Connection.get(parsed["connection"]) else: - conn = sql.connection.Connection.set(parsed["connection"]) + try: + conn = sql.connection.Connection.set(parsed["connection"]) + # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql + except TypeError: + conn = sql.connection.Connection.set(parsed["connection"], False) try: # A corresponding pgcli object already exists @@ -43,7 +47,7 @@ def pgcli_line_magic(line): conn._pgcli = pgcli # For convenience, print the connection alias - print("Connected: {}".format(conn.name)) + print(f"Connected: {conn.name}") try: pgcli.run_cli() diff --git a/pgcli/main.py b/pgcli/main.py index b146898..5135f6f 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -2,8 +2,9 @@ import platform import warnings from os.path import expanduser -from configobj import ConfigObj +from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries +from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") @@ -20,12 +21,12 @@ import datetime as dt import itertools import platform from time import time, sleep -from codecs import open keyring = None # keyring will be loaded later from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.utils import strip_ansi import click try: @@ -62,6 +63,7 @@ from .config import ( config_location, ensure_dir_exists, get_config, + get_config_filename, ) from .key_bindings import pgcli_bindings from .packages.prompt_utils import 
confirm_destructive_query @@ -122,7 +124,7 @@ class PgCliQuitError(Exception): pass -class PGCli(object): +class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -175,7 +177,11 @@ class PGCli(object): # Load config. c = self.config = get_config(pgclirc_file) - NamedQueries.instance = NamedQueries.from_config(self.config) + # at this point, config should be written to pgclirc_file if it did not exist. Read it. + self.config_writer = load_config(get_config_filename(pgclirc_file)) + + # make sure to use self.config_writer, not self.config + NamedQueries.instance = NamedQueries.from_config(self.config_writer) self.logger = logging.getLogger(__name__) self.initialize_logging() @@ -201,8 +207,11 @@ class PGCli(object): self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - c_dest_warning = c["main"].as_bool("destructive_warning") - self.destructive_warning = c_dest_warning if warn is None else warn + self.destructive_warning = warn or c["main"]["destructive_warning"] + # also handle boolean format of destructive warning + self.destructive_warning = {"true": "all", "false": "off"}.get( + self.destructive_warning.lower(), self.destructive_warning + ) self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.null_string = c["main"].get("null_string", "") self.prompt_format = ( @@ -325,11 +334,11 @@ class PGCli(object): if pattern not in TabularOutputFormatter().supported_formats: raise ValueError() self.table_format = pattern - yield (None, None, None, "Changed table format to {}".format(pattern)) + yield (None, None, None, f"Changed table format to {pattern}") except ValueError: - msg = "Table format {} not recognized. Allowed formats:".format(pattern) + msg = f"Table format {pattern} not recognized. Allowed formats:" for table_type in TabularOutputFormatter().supported_formats: - msg += "\n\t{}".format(table_type) + msg += f"\n\t{table_type}" msg += "\nCurrently set to: %s" % self.table_format yield (None, None, None, msg) @@ -386,10 +395,13 @@ class PGCli(object): try: with open(os.path.expanduser(pattern), encoding="utf-8") as f: query = f.read() - except IOError as e: + except OSError as e: return [(None, None, None, str(e), "", False, True)] - if self.destructive_warning and confirm_destructive_query(query) is False: + if ( + self.destructive_warning != "off" + and confirm_destructive_query(query, self.destructive_warning) is False + ): message = "Wise choice. Command execution stopped." 
return [(None, None, None, message)] @@ -407,7 +419,7 @@ class PGCli(object): if not os.path.isfile(filename): try: open(filename, "w").close() - except IOError as e: + except OSError as e: self.output_file = None message = str(e) + "\nFile output disabled" return [(None, None, None, message, "", False, True)] @@ -479,7 +491,7 @@ class PGCli(object): service_config, file = parse_service_info(service) if service_config is None: click.secho( - "service '%s' was not found in %s" % (service, file), err=True, fg="red" + f"service '{service}' was not found in {file}", err=True, fg="red" ) exit(1) self.connect( @@ -515,7 +527,7 @@ class PGCli(object): passwd = os.environ.get("PGPASSWORD", "") # Find password from store - key = "%s@%s" % (user, host) + key = f"{user}@{host}" keyring_error_message = dedent( """\ {} @@ -644,8 +656,10 @@ class PGCli(object): query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning: - destroy = confirm = confirm_destructive_query(text) + if self.destructive_warning != "off": + destroy = confirm = confirm_destructive_query( + text, self.destructive_warning + ) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt @@ -677,7 +691,7 @@ class PGCli(object): click.echo(text, file=f) click.echo("\n".join(output), file=f) click.echo("", file=f) # extra newline - except IOError as e: + except OSError as e: click.secho(str(e), err=True, fg="red") else: if output: @@ -729,7 +743,6 @@ class PGCli(object): if not self.less_chatty: print("Server: PostgreSQL", self.pgexecute.server_version) print("Version:", __version__) - print("Chat: https://gitter.im/dbcli/pgcli") print("Home: http://pgcli.com") try: @@ -753,11 +766,7 @@ class PGCli(object): while self.watch_command: try: query = self.execute_command(self.watch_command) - click.echo( - "Waiting for {0} seconds before repeating".format( - timing - ) - ) + click.echo(f"Waiting for {timing} seconds before repeating") sleep(timing) except KeyboardInterrupt: self.watch_command = None @@ -979,16 +988,13 @@ class PGCli(object): callback = functools.partial( self._on_completions_refreshed, persist_priorities=persist_priorities ) - self.completion_refresher.refresh( + return self.completion_refresher.refresh( self.pgexecute, self.pgspecial, callback, history=history, settings=self.settings, ) - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] def _on_completions_refreshed(self, new_completer, persist_priorities): self._swap_completer_objects(new_completer, persist_priorities) @@ -1049,7 +1055,7 @@ class PGCli(object): str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", ) string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") - string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") string = string.replace("\\n", "\n") return string @@ -1075,9 +1081,10 @@ class PGCli(object): def echo_via_pager(self, text, color=None): if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: click.echo(text, color=color) - elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv": - click.echo_via_pager(text, color) - elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: + elif ( + self.pgspecial.pager_config == PAGER_LONG_OUTPUT + and self.table_format != "csv" + ): lines = text.split("\n") # The last 4 lines are reserved for the pgcli menu and padding @@ -1192,7 +1199,10 @@ class PGCli(object): help="Automatically 
switch to vertical output mode if the result is wider than the terminal width.", ) @click.option( - "--warn/--no-warn", default=None, help="Warn before running a destructive query." + "--warn", + default=None, + type=click.Choice(["all", "moderate", "off"]), + help="Warn before running a destructive query.", ) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) @@ -1384,7 +1394,7 @@ def is_mutating(status): if not status: return False - mutating = set(["insert", "update", "delete"]) + mutating = {"insert", "update", "delete"} return status.split(None, 1)[0].lower() in mutating @@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings): formatted = iter(formatted.splitlines()) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) - if not expanded and max_width and len(first_line) > max_width and headers: + if ( + not expanded + and max_width + and len(strip_ansi(first_line)) > max_width + and headers + ): formatted = formatter.format_output( cur, headers, format_name="vertical", column_types=None, **output_kwargs ) @@ -1502,10 +1517,16 @@ def parse_service_info(service): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: service_file = expanduser("~/.pg_service.conf") - if not service: + if not service or not os.path.exists(service_file): # nothing to do return None, service_file - service_file_config = ConfigObj(service_file) + with open(service_file, newline="") as f: + skipped_lines = skip_initial_comment(f) + try: + service_file_config = ConfigObj(f) + except ParseError as err: + err.line_number += skipped_lines + raise err if service not in service_file_config: return None, service_file service_conf = service_file_config.get(service) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index a11e7bf..1acc008 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,22 +1,34 @@ import sqlparse -def query_starts_with(query, prefixes): +def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + +def query_is_simple_update(formatted_sql): + """Check if the query starts with UPDATE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" -def is_destructive(queries): + +def is_destructive(queries, warning_level="all"): """Returns if any of the queries in *queries* is destructive.""" keywords = ("drop", "shutdown", "delete", "truncate", "alter") - return queries_start_with(queries, keywords) + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if query_starts_with(formatted_sql, keywords): + return True + if 
query_is_unconditional_update(formatted_sql): + return True + if warning_level == "all" and query_is_simple_update(formatted_sql): + return True + return False diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 108c01a..333cab5 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -50,7 +50,7 @@ def parse_defaults(defaults_string): yield current -class FunctionMetadata(object): +class FunctionMetadata: def __init__( self, schema_name, diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index 0ec3e69..aaa676c 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): for item in parsed.tokens: if tbl_prefix_seen: if is_subselect(item): - for x in extract_from_part(item, stop_at_punctuation): - yield x + yield from extract_from_part(item, stop_at_punctuation) elif stop_at_punctuation and item.ttype is Punctuation: return # An incomplete nested select won't be recognized correctly as a diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json index c7b74b5..df00817 100644 --- a/pgcli/packages/pgliterals/pgliterals.json +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -392,6 +392,7 @@ "QUOTE_NULLABLE", "RADIANS", "RADIUS", + "RANDOM", "RANK", "REGEXP_MATCH", "REGEXP_MATCHES", diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py index e92dcbb..f5a9cb5 100644 --- a/pgcli/packages/prioritization.py +++ b/pgcli/packages/prioritization.py @@ -16,10 +16,10 @@ def _compile_regex(keyword): keywords = get_literals("keywords") -keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) +keyword_regexs = {kw: _compile_regex(kw) for kw in keywords} -class PrevalenceCounter(object): +class PrevalenceCounter: def __init__(self): self.keyword_counts = defaultdict(int) self.name_counts = defaultdict(int) diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py index 3c58490..e8589de 100644 --- a/pgcli/packages/prompt_utils.py +++ b/pgcli/packages/prompt_utils.py @@ -3,7 +3,7 @@ import click from .parseutils import is_destructive -def confirm_destructive_query(queries): +def confirm_destructive_query(queries, warning_level): """Check if the query is destructive and prompts the user to confirm. Returns: @@ -15,7 +15,7 @@ def confirm_destructive_query(queries): prompt_text = ( "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" ) - if is_destructive(queries) and sys.stdin.isatty(): + if is_destructive(queries, warning_level) and sys.stdin.isatty(): return prompt(prompt_text, type=bool) diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index 6ef8859..6305301 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"]) Path = namedtuple("Path", []) -class SqlStatement(object): +class SqlStatement: def __init__(self, full_text, text_before_cursor): self.identifier = None self.word_before_cursor = word_before_cursor = last_word( diff --git a/pgcli/pgclirc b/pgcli/pgclirc index e97afda..15c10f5 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -23,9 +23,13 @@ multi_line = False multi_line_mode = psql # Destructive warning mode will alert you before executing a sql statement -# that may cause harm to the database such as "drop table", "drop database" -# or "shutdown". 
-destructive_warning = True +# that may cause harm to the database such as "drop table", "drop database", +# "shutdown", "delete", or "update". +# Possible values: +# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE +# "moderate" - skip warning on UPDATE statements, except for unconditional updates +# "off" - skip all warnings +destructive_warning = all # Enables expand mode, which is similar to `\x` in psql. expand = False @@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold' arg-toolbar.text = 'nobold' bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' -literal.string = '#ba2121' -literal.number = '#666666' -keyword = 'bold #008000' +# These three values can be used to further refine the syntax highlighting. +# They are commented out by default, since they have priority over the theme set +# with the `syntax_style` setting and overriding its behavior can be confusing. +# literal.string = '#ba2121' +# literal.number = '#666666' +# keyword = 'bold #008000' # style classes for colored table output output.header = "#00ff5f bold" diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 9c95a01..227e25c 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -83,7 +83,7 @@ class PGCompleter(Completer): reserved_words = set(get_literals("reserved")) def __init__(self, smart_completion=True, pgspecial=None, settings=None): - super(PGCompleter, self).__init__() + super().__init__() self.smart_completion = smart_completion self.pgspecial = pgspecial self.prioritizer = PrevalenceCounter() @@ -140,7 +140,7 @@ class PGCompleter(Completer): return "'{}'".format(self.unescape_name(name)) def unescape_name(self, name): - """ Unquote a string.""" + """Unquote a string.""" if name and name[0] == '"' and name[-1] == '"': name = name[1:-1] @@ -177,7 +177,7 @@ class PGCompleter(Completer): :return: """ # casing should be a dict {lowercasename:PreferredCasingName} - self.casing = dict((word.lower(), word) for word in words) + self.casing = {word.lower(): word for word in words} def extend_relations(self, data, kind): """extend metadata for tables or views. @@ -279,8 +279,8 @@ class PGCompleter(Completer): fk = ForeignKey( parentschema, parenttable, parcol, childschema, childtable, childcol ) - childcolmeta.foreignkeys.append((fk)) - parcolmeta.foreignkeys.append((fk)) + childcolmeta.foreignkeys.append(fk) + parcolmeta.foreignkeys.append(fk) def extend_datatypes(self, type_data): @@ -424,7 +424,7 @@ class PGCompleter(Completer): # the same priority as unquoted names. lexical_priority = ( tuple( - 0 if c in (" _") else -ord(c) + 0 if c in " _" else -ord(c) for c in self.unescape_name(item.lower()) ) + (1,) @@ -517,9 +517,9 @@ class PGCompleter(Completer): # require_last_table is used for 'tb11 JOIN tbl2 USING (...' 
which should # suggest only columns that appear in the last table and one more ltbl = tables[-1].ref - other_tbl_cols = set( + other_tbl_cols = { c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs - ) + } scoped_cols = { t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() @@ -574,7 +574,7 @@ class PGCompleter(Completer): tbls - TableReference iterable of tables already in query """ tbl = self.case(tbl) - tbls = set(normalize_ref(t.ref) for t in tbls) + tbls = {normalize_ref(t.ref) for t in tbls} if self.generate_aliases: tbl = generate_alias(self.unescape_name(tbl)) if normalize_ref(tbl) not in tbls: @@ -589,10 +589,10 @@ class PGCompleter(Completer): tbls = suggestion.table_refs cols = self.populate_scoped_cols(tbls) # Set up some data structures for efficient access - qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) - ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) - refs = set(normalize_ref(t.ref) for t in tbls) - other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) + qualified = {normalize_ref(t.ref): t.schema for t in tbls} + ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)} + refs = {normalize_ref(t.ref) for t in tbls} + other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]} joins = [] # Iterate over FKs in existing tables to find potential joins fks = ( @@ -667,7 +667,7 @@ class PGCompleter(Completer): return d # Tables that are closer to the cursor get higher prio - ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) + ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)} # Map (schema, table, col) to tables coldict = list_dict( ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref @@ -703,7 +703,11 @@ class PGCompleter(Completer): not f.is_aggregate and not f.is_window and not f.is_extension - and (f.is_public or f.schema_name == suggestion.schema) + and ( + f.is_public + or f.schema_name in self.search_path + or f.schema_name == suggestion.schema + ) ) else: @@ -721,9 +725,7 @@ class PGCompleter(Completer): # Function overloading means we way have multiple functions of the same # name at this point, so keep unique names only all_functions = self.populate_functions(suggestion.schema, filt) - funcs = set( - self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions - ) + funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions} matches = self.find_matches(word_before_cursor, funcs, meta="function") @@ -953,7 +955,7 @@ class PGCompleter(Completer): :return: {TableReference:{colname:ColumnMetaData}} """ - ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) + ctes = {normalize_ref(t.name): t.columns for t in local_tbls} columns = OrderedDict() meta = self.dbmetadata diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index d34bf26..a013b55 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,13 +1,15 @@ -import traceback import logging +import select +import traceback + +import pgspecial as special import psycopg2 -import psycopg2.extras import psycopg2.errorcodes import psycopg2.extensions as ext +import psycopg2.extras import sqlparse -import pgspecial as special -import select from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn + from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) 
# TODO: Get default timeout from pgclirc? _WAIT_SELECT_TIMEOUT = 1 +_wait_callback_is_set = False def _wait_select(conn): @@ -34,31 +37,41 @@ def _wait_select(conn): copy-pasted from psycopg2.extras.wait_select the default implementation doesn't define a timeout in the select calls """ - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except select.error as e: - errno = e.args[0] - if errno != 4: - raise + try: + while 1: + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + except OSError as e: + errno = e.args[0] + if errno != 4: + raise + except psycopg2.OperationalError: + pass -# When running a query, make pressing CTRL+C raise a KeyboardInterrupt -# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ -# See also https://github.com/psycopg/psycopg2/issues/468 -ext.set_wait_callback(_wait_select) +def _set_wait_callback(is_virtual_database): + global _wait_callback_is_set + if _wait_callback_is_set: + return + _wait_callback_is_set = True + if is_virtual_database: + return + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt + # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ + # See also https://github.com/psycopg/psycopg2/issues/468 + ext.set_wait_callback(_wait_select) def register_date_typecasters(connection): @@ -72,6 +85,8 @@ def register_date_typecasters(connection): cursor = connection.cursor() cursor.execute("SELECT NULL::date") + if cursor.description is None: + return date_oid = cursor.description[0][1] cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] @@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn): try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) - except psycopg2.ProgrammingError: + except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): pass return available @@ -127,7 +142,39 @@ def register_hstore_typecaster(conn): pass -class PGExecute(object): +class ProtocolSafeCursor(psycopg2.extensions.cursor): + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + def execute(self, sql, args=None): + try: + psycopg2.extensions.cursor.execute(self, sql, args) + self.protocol_error = False + self.protocol_message = "" + except psycopg2.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = ex.pgerror + _logger.debug("%s: 
%s" % (ex.__class__.__name__, ex)) + + +class PGExecute: # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog @@ -190,8 +237,6 @@ class PGExecute(object): SELECT pg_catalog.pg_get_functiondef(f.f_oid) FROM f""" - version_query = "SELECT version();" - def __init__( self, database=None, @@ -203,6 +248,7 @@ class PGExecute(object): **kwargs, ): self._conn_params = {} + self._is_virtual_database = None self.conn = None self.dbname = None self.user = None @@ -214,6 +260,11 @@ class PGExecute(object): self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + def copy(self): """Returns a clone of the current executor.""" return self.__class__(**self._conn_params) @@ -250,9 +301,9 @@ class PGExecute(object): ) conn_params.update({k: v for k, v in new_params.items() if v}) + conn_params["cursor_factory"] = ProtocolSafeCursor conn = psycopg2.connect(**conn_params) - cursor = conn.cursor() conn.set_client_encoding("utf8") self._conn_params = conn_params @@ -293,16 +344,22 @@ class PGExecute(object): self.extra_args = kwargs if not self.host: - self.host = self.get_socket_directory() + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) - pid = self._select_one(cursor, "select pg_backend_pid()")[0] - self.pid = pid + self.pid = conn.get_backend_pid() self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") + self.server_version = conn.get_parameter_status("server_version") or "" + + _set_wait_callback(self.is_virtual_database()) - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + if not self.is_virtual_database(): + register_date_typecasters(conn) + register_json_typecasters(self.conn, self._json_typecaster) + register_hstore_typecaster(self.conn) @property def short_host(self): @@ -395,7 +452,13 @@ class PGExecute(object): # See https://github.com/dbcli/pgcli/issues/1014. cur = None try: - for result in pgspecial.execute(cur, sql): + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: # e.g. 
execute_from_file already appends these if len(result) < 7: yield result + (sql, True, True) @@ -453,6 +516,9 @@ class PGExecute(object): if cur.description: headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message else: _logger.debug("No rows in result.") return title, None, None, cur.statusmessage @@ -485,7 +551,7 @@ class PGExecute(object): try: cur.execute(sql, (spec,)) except psycopg2.ProgrammingError: - raise RuntimeError("View {} does not exist.".format(spec)) + raise RuntimeError(f"View {spec} does not exist.") result = cur.fetchone() view_type = "MATERIALIZED" if result[2] == "m" else "" return template.format(*result + (view_type,)) @@ -501,7 +567,7 @@ class PGExecute(object): result = cur.fetchone() return result[0] except psycopg2.ProgrammingError: - raise RuntimeError("Function {} does not exist.".format(spec)) + raise RuntimeError(f"Function {spec} does not exist.") def schemata(self): """Returns a list of schema names in the database""" @@ -527,21 +593,18 @@ class PGExecute(object): sql = cur.mogrify(self.tables_query, [kinds]) _logger.debug("Tables Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def tables(self): """Yields (schema_name, table_name) tuples""" - for row in self._relations(kinds=["r", "p", "f"]): - yield row + yield from self._relations(kinds=["r", "p", "f"]) def views(self): """Yields (schema_name, view_name) tuples. Includes both views and and materialized views """ - for row in self._relations(kinds=["v", "m"]): - yield row + yield from self._relations(kinds=["v", "m"]) def _columns(self, kinds=("r", "p", "f", "v", "m")): """Get column metadata for tables and views @@ -599,16 +662,13 @@ class PGExecute(object): sql = cur.mogrify(columns_query, [kinds]) _logger.debug("Columns Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def table_columns(self): - for row in self._columns(kinds=["r", "p", "f"]): - yield row + yield from self._columns(kinds=["r", "p", "f"]) def view_columns(self): - for row in self._columns(kinds=["v", "m"]): - yield row + yield from self._columns(kinds=["v", "m"]) def databases(self): with self.conn.cursor() as cur: @@ -623,6 +683,13 @@ class PGExecute(object): headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + def get_socket_directory(self): with self.conn.cursor() as cur: _logger.debug( @@ -804,8 +871,7 @@ class PGExecute(object): """ _logger.debug("Datatypes Query. 
sql: %r", query) cur.execute(query) - for row in cur: - yield row + yield from cur def casing(self): """Yields the most common casing for names used in db functions""" diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py index f4289a1..41f903d 100644 --- a/pgcli/pgtoolbar.py +++ b/pgcli/pgtoolbar.py @@ -1,15 +1,23 @@ +from pkg_resources import packaging + +import prompt_toolkit from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.application import get_app +parse_version = packaging.version.parse + +vi_modes = { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", +} +if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"): + vi_modes[InputMode.REPLACE_SINGLE] = "R" + def _get_vi_mode(): - return { - InputMode.INSERT: "I", - InputMode.NAVIGATION: "N", - InputMode.REPLACE: "R", - InputMode.REPLACE_SINGLE: "R", - InputMode.INSERT_MULTIPLE: "M", - }[get_app().vi_state.input_mode] + return vi_modes[get_app().vi_state.input_mode] def create_toolbar_tokens_func(pgcli): -- cgit v1.2.3