diff options
Diffstat (limited to 'pgcli/main.py')
-rw-r--r-- | pgcli/main.py | 224 |
1 files changed, 213 insertions, 11 deletions
diff --git a/pgcli/main.py b/pgcli/main.py index f95c800..056a940 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -11,9 +11,9 @@ import logging import threading import shutil import functools -import pendulum import datetime as dt import itertools +import pathlib import platform from time import time, sleep from typing import Optional @@ -74,8 +74,9 @@ from urllib.parse import urlparse from getpass import getuser -from psycopg import OperationalError, InterfaceError +from psycopg import OperationalError, InterfaceError, Notify from psycopg.conninfo import make_conninfo, conninfo_to_dict +from psycopg.errors import Diagnostic from collections import namedtuple @@ -129,6 +130,15 @@ class PgCliQuitError(Exception): pass +def notify_callback(notify: Notify): + click.secho( + 'Notification received on channel "{}" (PID {}):\n{}'.format( + notify.channel, notify.pid, notify.payload + ), + fg="green", + ) + + class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -165,6 +175,7 @@ class PGCli: pgexecute=None, pgclirc_file=None, row_limit=None, + application_name="pgcli", single_connection=False, less_chatty=None, prompt=None, @@ -172,6 +183,7 @@ class PGCli: auto_vertical_output=False, warn=None, ssh_tunnel_url: Optional[str] = None, + log_file: Optional[str] = None, ): self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt @@ -210,6 +222,8 @@ class PGCli: else: self.row_limit = c["main"].as_int("row_limit") + self.application_name = application_name + # if not specified, set to DEFAULT_MAX_FIELD_WIDTH # if specified but empty, set to None to disable truncation # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 @@ -237,6 +251,9 @@ class PGCli: ) self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool( + "verbose_errors" + ) self.null_string = c["main"].get("null_string", "<null>") self.prompt_format = ( prompt @@ -295,6 +312,11 @@ class PGCli: self.ssh_tunnel_url = ssh_tunnel_url self.ssh_tunnel = None + if log_file: + with open(log_file, "a+"): + pass # ensure writeable + self.log_file = log_file + # formatter setup self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) register_new_formatter(self.formatter) @@ -355,6 +377,12 @@ class PGCli: "Send all query results to file.", ) self.pgspecial.register( + self.write_to_logfile, + "\\log-file", + "\\log-file [filename]", + "Log all query results to a logfile, in addition to the normal output destination.", + ) + self.pgspecial.register( self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" ) self.pgspecial.register( @@ -378,6 +406,26 @@ class PGCli: "Echo a string to the query output channel.", ) + self.pgspecial.register( + self.toggle_verbose_errors, + "\\v", + "\\v [on|off]", + "Toggle verbose errors.", + ) + + def toggle_verbose_errors(self, pattern, **_): + flag = pattern.strip() + + if flag == "on": + self.verbose_errors = True + elif flag == "off": + self.verbose_errors = False + else: + self.verbose_errors = not self.verbose_errors + + message = "Verbose errors " + "on." if self.verbose_errors else "off." + return [(None, None, None, message)] + def echo(self, pattern, **_): return [(None, None, None, pattern)] @@ -473,6 +521,26 @@ class PGCli: explain_mode=self.explain_mode, ) + def write_to_logfile(self, pattern, **_): + if not pattern: + self.log_file = None + message = "Logfile capture disabled" + return [(None, None, None, message, "", True, True)] + + log_file = pathlib.Path(pattern).expanduser().absolute() + + try: + with open(log_file, "a+"): + pass # ensure writeable + except OSError as e: + self.log_file = None + message = str(e) + "\nLogfile capture disabled" + return [(None, None, None, message, "", False, True)] + + self.log_file = str(log_file) + message = 'Writing to file "%s"' % self.log_file + return [(None, None, None, message, "", True, True)] + def write_to_file(self, pattern, **_): if not pattern: self.output_file = None @@ -568,7 +636,7 @@ class PGCli: if not database: database = user - kwargs.setdefault("application_name", "pgcli") + kwargs.setdefault("application_name", self.application_name) # If password prompt is not forced but no password is provided, try # getting it from environment variable. @@ -658,7 +726,16 @@ class PGCli: # prompt for a password (no -w flag), prompt for a passwd and try again. try: try: - pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + pgexecute = PGExecute( + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, + ) except (OperationalError, InterfaceError) as e: if should_ask_for_password(e): passwd = click.prompt( @@ -668,7 +745,14 @@ class PGCli: type=str, ) pgexecute = PGExecute( - database, user, passwd, host, port, dsn, **kwargs + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, ) else: raise e @@ -775,7 +859,7 @@ class PGCli: else: try: if self.output_file and not text.startswith( - ("\\o ", "\\? ", "\\echo ") + ("\\o ", "\\log-file", "\\? ", "\\echo ") ): try: with open(self.output_file, "a", encoding="utf-8") as f: @@ -787,6 +871,23 @@ class PGCli: else: if output: self.echo_via_pager("\n".join(output)) + + # Log to file in addition to normal output + if ( + self.log_file + and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")) + and not text.strip() == "" + ): + try: + with open(self.log_file, "a", encoding="utf-8") as f: + click.echo( + dt.datetime.now().isoformat(), file=f + ) # timestamp log + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") except KeyboardInterrupt: pass @@ -797,9 +898,9 @@ class PGCli: "Time: %0.03fs (%s), executed in: %0.03fs (%s)" % ( query.total_time, - pendulum.Duration(seconds=query.total_time).in_words(), + duration_in_words(query.total_time), query.execution_time, - pendulum.Duration(seconds=query.execution_time).in_words(), + duration_in_words(query.execution_time), ) ) else: @@ -1053,7 +1154,7 @@ class PGCli: res = self.pgexecute.run( text, self.pgspecial, - exception_formatter, + lambda x: exception_formatter(x, self.verbose_errors), on_error_resume, explain_mode=self.explain_mode, ) @@ -1338,6 +1439,12 @@ class PGCli: help="Set threshold for row limit prompt. Use 0 to disable prompt.", ) @click.option( + "--application-name", + default="pgcli", + envvar="PGAPPNAME", + help="Application name for the connection.", +) +@click.option( "--less-chatty", "less_chatty", is_flag=True, @@ -1371,6 +1478,11 @@ class PGCli: default=None, help="Open an SSH tunnel to the given address and connect to the database from it.", ) +@click.option( + "--log-file", + default=None, + help="Write all queries & output into a file, in addition to the normal output destination.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1387,6 +1499,7 @@ def cli( pgclirc, dsn, row_limit, + application_name, less_chatty, prompt, prompt_dsn, @@ -1395,6 +1508,7 @@ def cli( list_dsn, warn, ssh_tunnel: str, + log_file: str, ): if version: print("Version:", __version__) @@ -1445,6 +1559,7 @@ def cli( never_prompt, pgclirc_file=pgclirc, row_limit=row_limit, + application_name=application_name, single_connection=single_connection, less_chatty=less_chatty, prompt=prompt, @@ -1452,6 +1567,7 @@ def cli( auto_vertical_output=auto_vertical_output, warn=warn, ssh_tunnel_url=ssh_tunnel, + log_file=log_file, ) # Choose which ever one has a valid value. @@ -1583,8 +1699,71 @@ def is_select(status): return status.split(None, 1)[0].lower() == "select" -def exception_formatter(e): - return click.style(str(e), fg="red") +def diagnostic_output(diagnostic: Diagnostic) -> str: + fields = [] + + if diagnostic.severity is not None: + fields.append("Severity: " + diagnostic.severity) + + if diagnostic.severity_nonlocalized is not None: + fields.append("Severity (non-localized): " + diagnostic.severity_nonlocalized) + + if diagnostic.sqlstate is not None: + fields.append("SQLSTATE code: " + diagnostic.sqlstate) + + if diagnostic.message_primary is not None: + fields.append("Message: " + diagnostic.message_primary) + + if diagnostic.message_detail is not None: + fields.append("Detail: " + diagnostic.message_detail) + + if diagnostic.message_hint is not None: + fields.append("Hint: " + diagnostic.message_hint) + + if diagnostic.statement_position is not None: + fields.append("Position: " + diagnostic.statement_position) + + if diagnostic.internal_position is not None: + fields.append("Internal position: " + diagnostic.internal_position) + + if diagnostic.internal_query is not None: + fields.append("Internal query: " + diagnostic.internal_query) + + if diagnostic.context is not None: + fields.append("Where: " + diagnostic.context) + + if diagnostic.schema_name is not None: + fields.append("Schema name: " + diagnostic.schema_name) + + if diagnostic.table_name is not None: + fields.append("Table name: " + diagnostic.table_name) + + if diagnostic.column_name is not None: + fields.append("Column name: " + diagnostic.column_name) + + if diagnostic.datatype_name is not None: + fields.append("Data type name: " + diagnostic.datatype_name) + + if diagnostic.constraint_name is not None: + fields.append("Constraint name: " + diagnostic.constraint_name) + + if diagnostic.source_file is not None: + fields.append("File: " + diagnostic.source_file) + + if diagnostic.source_line is not None: + fields.append("Line: " + diagnostic.source_line) + + if diagnostic.source_function is not None: + fields.append("Routine: " + diagnostic.source_function) + + return "\n".join(fields) + + +def exception_formatter(e, verbose_errors: bool = False): + s = str(e) + if verbose_errors: + s += "\n" + diagnostic_output(e.diag) + return click.style(s, fg="red") def format_output(title, cur, headers, status, settings, explain_mode=False): @@ -1724,5 +1903,28 @@ def parse_service_info(service): return service_conf, service_file +def duration_in_words(duration_in_seconds: float) -> str: + if not duration_in_seconds: + return "0 seconds" + components = [] + hours, remainder = divmod(duration_in_seconds, 3600) + if hours > 1: + components.append(f"{hours} hours") + elif hours == 1: + components.append("1 hour") + minutes, seconds = divmod(remainder, 60) + if minutes > 1: + components.append(f"{minutes} minutes") + elif minutes == 1: + components.append("1 minute") + if seconds >= 2: + components.append(f"{int(seconds)} seconds") + elif seconds >= 1: + components.append("1 second") + elif seconds: + components.append(f"{round(seconds, 3)} second") + return " ".join(components) + + if __name__ == "__main__": cli() |