diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-11-01 04:38:03 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-11-01 04:38:03 +0000 |
commit | fa11d0da51045077b543d42a1ab661c4a20b5127 (patch) | |
tree | aac9e87c59cb8bc7e3cd429e9200c3ca017cb591 /pgcli/main.py | |
parent | Adding upstream version 3.5.0. (diff) | |
download | pgcli-81cc31ae0825b1c86f44d8b1f45abf62ff57100b.tar.xz pgcli-81cc31ae0825b1c86f44d8b1f45abf62ff57100b.zip |
Adding upstream version 4.0.1.upstream/4.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli/main.py')
-rw-r--r-- | pgcli/main.py | 162 |
1 files changed, 130 insertions, 32 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index 0fa264f..f95c800 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -63,15 +63,14 @@ from .config import ( ) from .key_bindings import pgcli_bindings from .packages.formatter.sqlformatter import register_new_formatter -from .packages.prompt_utils import confirm_destructive_query +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.parseutils import is_destructive +from .packages.parseutils import parse_destructive_warning from .__init__ import __version__ click.disable_unicode_literals_warning = True -try: - from urlparse import urlparse, unquote, parse_qs -except ImportError: - from urllib.parse import urlparse, unquote, parse_qs +from urllib.parse import urlparse from getpass import getuser @@ -201,6 +200,9 @@ class PGCli: self.multiline_mode = c["main"].get("multi_line_mode", "psql") self.vi_mode = c["main"].as_bool("vi") self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.auto_retry_closed_connection = c["main"].as_bool( + "auto_retry_closed_connection" + ) self.expanded_output = c["main"].as_bool("expand") self.pgspecial.timing_enabled = c["main"].as_bool("timing") if row_limit is not None: @@ -224,11 +226,16 @@ class PGCli: self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - self.destructive_warning = warn or c["main"]["destructive_warning"] - # also handle boolean format of destructive warning - self.destructive_warning = {"true": "all", "false": "off"}.get( - self.destructive_warning.lower(), self.destructive_warning + self.destructive_warning = parse_destructive_warning( + warn or c["main"].as_list("destructive_warning") + ) + self.destructive_warning_restarts_connection = c["main"].as_bool( + "destructive_warning_restarts_connection" + ) + self.destructive_statements_require_transaction = c["main"].as_bool( + "destructive_statements_require_transaction" ) + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.null_string = c["main"].get("null_string", "<null>") self.prompt_format = ( @@ -258,6 +265,9 @@ class PGCli: # Initialize completer smart_completion = c["main"].as_bool("smart_completion") keyword_casing = c["main"]["keyword_casing"] + single_connection = single_connection or c["main"].as_bool( + "always_use_single_connection" + ) self.settings = { "casing_file": get_casing_file(c), "generate_casing_file": c["main"].as_bool("generate_casing_file"), @@ -269,6 +279,7 @@ class PGCli: "single_connection": single_connection, "less_chatty": less_chatty, "keyword_casing": keyword_casing, + "alias_map_file": c["main"]["alias_map_file"] or None, } completer = PGCompleter( @@ -292,7 +303,6 @@ class PGCli: raise PgCliQuitError def register_special_commands(self): - self.pgspecial.register( self.change_db, "\\c", @@ -354,6 +364,23 @@ class PGCli: "Change the table format used to output results", ) + self.pgspecial.register( + self.echo, + "\\echo", + "\\echo [string]", + "Echo a string to stdout", + ) + + self.pgspecial.register( + self.echo, + "\\qecho", + "\\qecho [string]", + "Echo a string to the query output channel.", + ) + + def echo(self, pattern, **_): + return [(None, None, None, pattern)] + def change_table_format(self, pattern, **_): try: if pattern not in TabularOutputFormatter().supported_formats: @@ -423,12 +450,20 @@ class PGCli: except OSError as e: return [(None, None, None, str(e), "", False, True)] - if ( - self.destructive_warning != "off" - and confirm_destructive_query(query, self.destructive_warning) is False - ): - message = "Wise choice. Command execution stopped." - return [(None, None, None, message)] + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(query, self.destructive_warning) + ): + message = "Destructive statements must be run within a transaction. Command execution stopped." + return [(None, None, None, message)] + destroy = confirm_destructive_query( + query, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] on_error_resume = self.on_error == "RESUME" return self.pgexecute.run( @@ -456,7 +491,6 @@ class PGCli: return [(None, None, None, message, "", True, True)] def initialize_logging(self): - log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" @@ -687,34 +721,52 @@ class PGCli: editor_command = special.editor_command(text) return text - def execute_command(self, text): + def execute_command(self, text, handle_closed_connection=True): logger = self.logger query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning != "off": - destroy = confirm = confirm_destructive_query( - text, self.destructive_warning + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(text, self.destructive_warning) + ): + click.secho( + "Destructive statements must be run within a transaction." + ) + raise KeyboardInterrupt + destroy = confirm_destructive_query( + text, self.destructive_warning, self.dsn_alias ) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt elif destroy: click.secho("Your call!") + output, query = self._evaluate_command(text) except KeyboardInterrupt: - # Restart connection to the database - self.pgexecute.connect() - logger.debug("cancelled query, sql: %r", text) - click.secho("cancelled query", err=True, fg="red") + if self.destructive_warning_restarts_connection: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query and restarted connection, sql: %r", text) + click.secho( + "cancelled query and restarted connection", err=True, fg="red" + ) + else: + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") except NotImplementedError: click.secho("Not Yet Implemented.", fg="yellow") except OperationalError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self._handle_server_closed_connection(text) - except (PgCliQuitError, EOFError) as e: + click.secho(str(e), err=True, fg="red") + if handle_closed_connection: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError): raise except Exception as e: logger.error("sql: %r, error: %r", text, e) @@ -722,7 +774,9 @@ class PGCli: click.secho(str(e), err=True, fg="red") else: try: - if self.output_file and not text.startswith(("\\o ", "\\? ")): + if self.output_file and not text.startswith( + ("\\o ", "\\? ", "\\echo ") + ): try: with open(self.output_file, "a", encoding="utf-8") as f: click.echo(text, file=f) @@ -766,6 +820,34 @@ class PGCli: logger.debug("Search path: %r", self.completer.search_path) return query + def _check_ongoing_transaction_and_allow_quitting(self): + """Return whether we can really quit, possibly by asking the + user to confirm so if there is an ongoing transaction. + """ + if not self.pgexecute.valid_transaction(): + return True + while 1: + try: + choice = click.prompt( + "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.", + default="a", + ) + except click.Abort: + # Print newline if user aborts with `^C`, otherwise + # pgcli's prompt will be printed on the same line + # (just after the confirmation prompt). + click.echo(None, err=False) + choice = "a" + choice = choice.lower() + if choice == "a": + return False # do not quit + if choice == "c": + query = self.execute_command("commit") + return query.successful # quit only if query is successful + if choice == "r": + query = self.execute_command("rollback") + return query.successful # quit only if query is successful + def run_cli(self): logger = self.logger @@ -788,6 +870,10 @@ class PGCli: text = self.prompt_app.prompt() except KeyboardInterrupt: continue + except EOFError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise try: text = self.handle_editor_command(text) @@ -797,7 +883,12 @@ class PGCli: click.secho(str(e), err=True, fg="red") continue - self.handle_watch_command(text) + try: + self.handle_watch_command(text) + except PgCliQuitError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise self.now = dt.datetime.today() @@ -1036,10 +1127,17 @@ class PGCli: click.secho("Reconnecting...", fg="green") self.pgexecute.connect() click.secho("Reconnected!", fg="green") - self.execute_command(text) except OperationalError as e: click.secho("Reconnect Failed", fg="red") click.secho(str(e), err=True, fg="red") + else: + retry = self.auto_retry_closed_connection or confirm( + "Run the query from before reconnecting?" + ) + if retry: + click.secho("Running query...", fg="green") + # Don't get stuck in a retry loop + self.execute_command(text, handle_closed_connection=False) def refresh_completions(self, history=None, persist_priorities="all"): """Refresh outdated completions @@ -1266,7 +1364,6 @@ class PGCli: @click.option( "--warn", default=None, - type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) @click.option( @@ -1575,7 +1672,8 @@ def format_output(title, cur, headers, status, settings, explain_mode=False): first_line = next(formatted) formatted = itertools.chain([first_line], formatted) if ( - not expanded + not explain_mode + and not expanded and max_width and len(strip_ansi(first_line)) > max_width and headers |