diff options
Diffstat (limited to 'pgcli/main.py')
-rw-r--r-- | pgcli/main.py | 93 |
1 files changed, 57 insertions, 36 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index b146898..5135f6f 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -2,8 +2,9 @@ import platform import warnings from os.path import expanduser -from configobj import ConfigObj +from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries +from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") @@ -20,12 +21,12 @@ import datetime as dt import itertools import platform from time import time, sleep -from codecs import open keyring = None # keyring will be loaded later from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +from cli_helpers.utils import strip_ansi import click try: @@ -62,6 +63,7 @@ from .config import ( config_location, ensure_dir_exists, get_config, + get_config_filename, ) from .key_bindings import pgcli_bindings from .packages.prompt_utils import confirm_destructive_query @@ -122,7 +124,7 @@ class PgCliQuitError(Exception): pass -class PGCli(object): +class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -175,7 +177,11 @@ class PGCli(object): # Load config. c = self.config = get_config(pgclirc_file) - NamedQueries.instance = NamedQueries.from_config(self.config) + # at this point, config should be written to pgclirc_file if it did not exist. Read it. + self.config_writer = load_config(get_config_filename(pgclirc_file)) + + # make sure to use self.config_writer, not self.config + NamedQueries.instance = NamedQueries.from_config(self.config_writer) self.logger = logging.getLogger(__name__) self.initialize_logging() @@ -201,8 +207,11 @@ class PGCli(object): self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - c_dest_warning = c["main"].as_bool("destructive_warning") - self.destructive_warning = c_dest_warning if warn is None else warn + self.destructive_warning = warn or c["main"]["destructive_warning"] + # also handle boolean format of destructive warning + self.destructive_warning = {"true": "all", "false": "off"}.get( + self.destructive_warning.lower(), self.destructive_warning + ) self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.null_string = c["main"].get("null_string", "<null>") self.prompt_format = ( @@ -325,11 +334,11 @@ class PGCli(object): if pattern not in TabularOutputFormatter().supported_formats: raise ValueError() self.table_format = pattern - yield (None, None, None, "Changed table format to {}".format(pattern)) + yield (None, None, None, f"Changed table format to {pattern}") except ValueError: - msg = "Table format {} not recognized. Allowed formats:".format(pattern) + msg = f"Table format {pattern} not recognized. Allowed formats:" for table_type in TabularOutputFormatter().supported_formats: - msg += "\n\t{}".format(table_type) + msg += f"\n\t{table_type}" msg += "\nCurrently set to: %s" % self.table_format yield (None, None, None, msg) @@ -386,10 +395,13 @@ class PGCli(object): try: with open(os.path.expanduser(pattern), encoding="utf-8") as f: query = f.read() - except IOError as e: + except OSError as e: return [(None, None, None, str(e), "", False, True)] - if self.destructive_warning and confirm_destructive_query(query) is False: + if ( + self.destructive_warning != "off" + and confirm_destructive_query(query, self.destructive_warning) is False + ): message = "Wise choice. Command execution stopped." return [(None, None, None, message)] @@ -407,7 +419,7 @@ class PGCli(object): if not os.path.isfile(filename): try: open(filename, "w").close() - except IOError as e: + except OSError as e: self.output_file = None message = str(e) + "\nFile output disabled" return [(None, None, None, message, "", False, True)] @@ -479,7 +491,7 @@ class PGCli(object): service_config, file = parse_service_info(service) if service_config is None: click.secho( - "service '%s' was not found in %s" % (service, file), err=True, fg="red" + f"service '{service}' was not found in {file}", err=True, fg="red" ) exit(1) self.connect( @@ -515,7 +527,7 @@ class PGCli(object): passwd = os.environ.get("PGPASSWORD", "") # Find password from store - key = "%s@%s" % (user, host) + key = f"{user}@{host}" keyring_error_message = dedent( """\ {} @@ -644,8 +656,10 @@ class PGCli(object): query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning: - destroy = confirm = confirm_destructive_query(text) + if self.destructive_warning != "off": + destroy = confirm = confirm_destructive_query( + text, self.destructive_warning + ) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt @@ -677,7 +691,7 @@ class PGCli(object): click.echo(text, file=f) click.echo("\n".join(output), file=f) click.echo("", file=f) # extra newline - except IOError as e: + except OSError as e: click.secho(str(e), err=True, fg="red") else: if output: @@ -729,7 +743,6 @@ class PGCli(object): if not self.less_chatty: print("Server: PostgreSQL", self.pgexecute.server_version) print("Version:", __version__) - print("Chat: https://gitter.im/dbcli/pgcli") print("Home: http://pgcli.com") try: @@ -753,11 +766,7 @@ class PGCli(object): while self.watch_command: try: query = self.execute_command(self.watch_command) - click.echo( - "Waiting for {0} seconds before repeating".format( - timing - ) - ) + click.echo(f"Waiting for {timing} seconds before repeating") sleep(timing) except KeyboardInterrupt: self.watch_command = None @@ -979,16 +988,13 @@ class PGCli(object): callback = functools.partial( self._on_completions_refreshed, persist_priorities=persist_priorities ) - self.completion_refresher.refresh( + return self.completion_refresher.refresh( self.pgexecute, self.pgspecial, callback, history=history, settings=self.settings, ) - return [ - (None, None, None, "Auto-completion refresh started in the background.") - ] def _on_completions_refreshed(self, new_completer, persist_priorities): self._swap_completer_objects(new_completer, persist_priorities) @@ -1049,7 +1055,7 @@ class PGCli(object): str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", ) string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") - string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\#", "#" if self.pgexecute.superuser else ">") string = string.replace("\\n", "\n") return string @@ -1075,9 +1081,10 @@ class PGCli(object): def echo_via_pager(self, text, color=None): if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: click.echo(text, color=color) - elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv": - click.echo_via_pager(text, color) - elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: + elif ( + self.pgspecial.pager_config == PAGER_LONG_OUTPUT + and self.table_format != "csv" + ): lines = text.split("\n") # The last 4 lines are reserved for the pgcli menu and padding @@ -1192,7 +1199,10 @@ class PGCli(object): help="Automatically switch to vertical output mode if the result is wider than the terminal width.", ) @click.option( - "--warn/--no-warn", default=None, help="Warn before running a destructive query." + "--warn", + default=None, + type=click.Choice(["all", "moderate", "off"]), + help="Warn before running a destructive query.", ) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) @@ -1384,7 +1394,7 @@ def is_mutating(status): if not status: return False - mutating = set(["insert", "update", "delete"]) + mutating = {"insert", "update", "delete"} return status.split(None, 1)[0].lower() in mutating @@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings): formatted = iter(formatted.splitlines()) first_line = next(formatted) formatted = itertools.chain([first_line], formatted) - if not expanded and max_width and len(first_line) > max_width and headers: + if ( + not expanded + and max_width + and len(strip_ansi(first_line)) > max_width + and headers + ): formatted = formatter.format_output( cur, headers, format_name="vertical", column_types=None, **output_kwargs ) @@ -1502,10 +1517,16 @@ def parse_service_info(service): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: service_file = expanduser("~/.pg_service.conf") - if not service: + if not service or not os.path.exists(service_file): # nothing to do return None, service_file - service_file_config = ConfigObj(service_file) + with open(service_file, newline="") as f: + skipped_lines = skip_initial_comment(f) + try: + service_file_config = ConfigObj(f) + except ParseError as err: + err.line_number += skipped_lines + raise err if service not in service_file_config: return None, service_file service_conf = service_file_config.get(service) |