summaryrefslogtreecommitdiffstats
path: root/pgcli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/main.py')
-rw-r--r--pgcli/main.py1630
1 files changed, 1630 insertions, 0 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
new file mode 100644
index 0000000..0fa264f
--- /dev/null
+++ b/pgcli/main.py
@@ -0,0 +1,1630 @@
+from configobj import ConfigObj, ParseError
+from pgspecial.namedqueries import NamedQueries
+from .config import skip_initial_comment
+
+import atexit
+import os
+import re
+import sys
+import traceback
+import logging
+import threading
+import shutil
+import functools
+import pendulum
+import datetime as dt
+import itertools
+import platform
+from time import time, sleep
+from typing import Optional
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+from cli_helpers.utils import strip_ansi
+from .explain_output_formatter import ExplainOutputFormatter
+import click
+
+try:
+ import setproctitle
+except ImportError:
+ setproctitle = None
+from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.layout.processors import (
+ ConditionalProcessor,
+ HighlightMatchingBracketProcessor,
+ TabsProcessor,
+)
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from pygments.lexers.sql import PostgresLexer
+
+from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
+import pgspecial as special
+
+from . import auth
+from .pgcompleter import PGCompleter
+from .pgtoolbar import create_toolbar_tokens_func
+from .pgstyle import style_factory, style_factory_output
+from .pgexecute import PGExecute
+from .completion_refresher import CompletionRefresher
+from .config import (
+ get_casing_file,
+ load_config,
+ config_location,
+ ensure_dir_exists,
+ get_config,
+ get_config_filename,
+)
+from .key_bindings import pgcli_bindings
+from .packages.formatter.sqlformatter import register_new_formatter
+from .packages.prompt_utils import confirm_destructive_query
+from .__init__ import __version__
+
+click.disable_unicode_literals_warning = True
+
+try:
+ from urlparse import urlparse, unquote, parse_qs
+except ImportError:
+ from urllib.parse import urlparse, unquote, parse_qs
+
+from getpass import getuser
+
+from psycopg import OperationalError, InterfaceError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict
+
+from collections import namedtuple
+
+try:
+ import sshtunnel
+
+ SSH_TUNNEL_SUPPORT = True
+except ImportError:
+ SSH_TUNNEL_SUPPORT = False
+
+
+# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
+COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
+DEFAULT_MAX_FIELD_WIDTH = 500
+
+# Query tuples are used for maintaining history
+MetaQuery = namedtuple(
+ "Query",
+ [
+ "query", # The entire text of the command
+ "successful", # True If all subqueries were successful
+ "total_time", # Time elapsed executing the query and formatting results
+ "execution_time", # Time elapsed executing the query
+ "meta_changed", # True if any subquery executed create/alter/drop
+ "db_changed", # True if any subquery changed the database
+ "path_changed", # True if any subquery changed the search path
+ "mutated", # True if any subquery executed insert/update/delete
+ "is_special", # True if the query is a special command
+ ],
+)
+MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)
+
+OutputSettings = namedtuple(
+ "OutputSettings",
+ "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output max_field_width",
+)
+OutputSettings.__new__.__defaults__ = (
+ None,
+ None,
+ None,
+ "<null>",
+ False,
+ None,
+ lambda x: x,
+ None,
+ DEFAULT_MAX_FIELD_WIDTH,
+)
+
+
+class PgCliQuitError(Exception):
+ pass
+
+
+class PGCli:
+ default_prompt = "\\u@\\h:\\d> "
+ max_len_prompt = 30
+
+ def set_default_pager(self, config):
+ configured_pager = config["main"].get("pager")
+ os_environ_pager = os.environ.get("PAGER")
+
+ if configured_pager:
+ self.logger.info(
+ 'Default pager found in config file: "%s"', configured_pager
+ )
+ os.environ["PAGER"] = configured_pager
+ elif os_environ_pager:
+ self.logger.info(
+ 'Default pager found in PAGER environment variable: "%s"',
+ os_environ_pager,
+ )
+ os.environ["PAGER"] = os_environ_pager
+ else:
+ self.logger.info(
+ "No default pager found in environment. Using os default pager"
+ )
+
+ # Set default set of less recommended options, if they are not already set.
+ # They are ignored if pager is different than less.
+ if not os.environ.get("LESS"):
+ os.environ["LESS"] = "-SRXF"
+
+ def __init__(
+ self,
+ force_passwd_prompt=False,
+ never_passwd_prompt=False,
+ pgexecute=None,
+ pgclirc_file=None,
+ row_limit=None,
+ single_connection=False,
+ less_chatty=None,
+ prompt=None,
+ prompt_dsn=None,
+ auto_vertical_output=False,
+ warn=None,
+ ssh_tunnel_url: Optional[str] = None,
+ ):
+ self.force_passwd_prompt = force_passwd_prompt
+ self.never_passwd_prompt = never_passwd_prompt
+ self.pgexecute = pgexecute
+ self.dsn_alias = None
+ self.watch_command = None
+
+ # Load config.
+ c = self.config = get_config(pgclirc_file)
+
+ # at this point, config should be written to pgclirc_file if it did not exist. Read it.
+ self.config_writer = load_config(get_config_filename(pgclirc_file))
+
+ # make sure to use self.config_writer, not self.config
+ NamedQueries.instance = NamedQueries.from_config(self.config_writer)
+
+ self.logger = logging.getLogger(__name__)
+ self.initialize_logging()
+
+ self.set_default_pager(c)
+ self.output_file = None
+ self.pgspecial = PGSpecial()
+
+ self.explain_mode = False
+ self.multi_line = c["main"].as_bool("multi_line")
+ self.multiline_mode = c["main"].get("multi_line_mode", "psql")
+ self.vi_mode = c["main"].as_bool("vi")
+ self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
+ self.expanded_output = c["main"].as_bool("expand")
+ self.pgspecial.timing_enabled = c["main"].as_bool("timing")
+ if row_limit is not None:
+ self.row_limit = row_limit
+ else:
+ self.row_limit = c["main"].as_int("row_limit")
+
+ # if not specified, set to DEFAULT_MAX_FIELD_WIDTH
+ # if specified but empty, set to None to disable truncation
+ # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0
+ max_field_width = c["main"].get("max_field_width", DEFAULT_MAX_FIELD_WIDTH)
+ if max_field_width and max_field_width.lower() != "none":
+ max_field_width = max(3, abs(int(max_field_width)))
+ else:
+ max_field_width = None
+ self.max_field_width = max_field_width
+
+ self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
+ self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
+ self.table_format = c["main"]["table_format"]
+ self.syntax_style = c["main"]["syntax_style"]
+ self.cli_style = c["colors"]
+ self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
+ self.destructive_warning = warn or c["main"]["destructive_warning"]
+ # also handle boolean format of destructive warning
+ self.destructive_warning = {"true": "all", "false": "off"}.get(
+ self.destructive_warning.lower(), self.destructive_warning
+ )
+ self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
+ self.null_string = c["main"].get("null_string", "<null>")
+ self.prompt_format = (
+ prompt
+ if prompt is not None
+ else c["main"].get("prompt", self.default_prompt)
+ )
+ self.prompt_dsn_format = prompt_dsn
+ self.on_error = c["main"]["on_error"].upper()
+ self.decimal_format = c["data_formats"]["decimal"]
+ self.float_format = c["data_formats"]["float"]
+ auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
+ self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
+
+ self.pgspecial.pset_pager(
+ self.config["main"].as_bool("enable_pager") and "on" or "off"
+ )
+
+ self.style_output = style_factory_output(self.syntax_style, c["colors"])
+
+ self.now = dt.datetime.today()
+
+ self.completion_refresher = CompletionRefresher()
+
+ self.query_history = []
+
+ # Initialize completer
+ smart_completion = c["main"].as_bool("smart_completion")
+ keyword_casing = c["main"]["keyword_casing"]
+ self.settings = {
+ "casing_file": get_casing_file(c),
+ "generate_casing_file": c["main"].as_bool("generate_casing_file"),
+ "generate_aliases": c["main"].as_bool("generate_aliases"),
+ "asterisk_column_order": c["main"]["asterisk_column_order"],
+ "qualify_columns": c["main"]["qualify_columns"],
+ "case_column_headers": c["main"].as_bool("case_column_headers"),
+ "search_path_filter": c["main"].as_bool("search_path_filter"),
+ "single_connection": single_connection,
+ "less_chatty": less_chatty,
+ "keyword_casing": keyword_casing,
+ }
+
+ completer = PGCompleter(
+ smart_completion, pgspecial=self.pgspecial, settings=self.settings
+ )
+ self.completer = completer
+ self._completer_lock = threading.Lock()
+ self.register_special_commands()
+
+ self.prompt_app = None
+
+ self.ssh_tunnel_config = c.get("ssh tunnels")
+ self.ssh_tunnel_url = ssh_tunnel_url
+ self.ssh_tunnel = None
+
+ # formatter setup
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ register_new_formatter(self.formatter)
+
+ def quit(self):
+ raise PgCliQuitError
+
+ def register_special_commands(self):
+
+ self.pgspecial.register(
+ self.change_db,
+ "\\c",
+ "\\c[onnect] database_name",
+ "Change to a new database.",
+ aliases=("use", "\\connect", "USE"),
+ )
+
+ refresh_callback = lambda: self.refresh_completions(persist_priorities="all")
+
+ self.pgspecial.register(
+ self.quit,
+ "\\q",
+ "\\q",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+ aliases=(":q",),
+ )
+ self.pgspecial.register(
+ self.quit,
+ "quit",
+ "quit",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=False,
+ aliases=("exit",),
+ )
+ self.pgspecial.register(
+ refresh_callback,
+ "\\#",
+ "\\#",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ )
+ self.pgspecial.register(
+ refresh_callback,
+ "\\refresh",
+ "\\refresh",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ )
+ self.pgspecial.register(
+ self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
+ )
+ self.pgspecial.register(
+ self.write_to_file,
+ "\\o",
+ "\\o [filename]",
+ "Send all query results to file.",
+ )
+ self.pgspecial.register(
+ self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
+ )
+ self.pgspecial.register(
+ self.change_table_format,
+ "\\T",
+ "\\T [format]",
+ "Change the table format used to output results",
+ )
+
+ def change_table_format(self, pattern, **_):
+ try:
+ if pattern not in TabularOutputFormatter().supported_formats:
+ raise ValueError()
+ self.table_format = pattern
+ yield (None, None, None, f"Changed table format to {pattern}")
+ except ValueError:
+ msg = f"Table format {pattern} not recognized. Allowed formats:"
+ for table_type in TabularOutputFormatter().supported_formats:
+ msg += f"\n\t{table_type}"
+ msg += "\nCurrently set to: %s" % self.table_format
+ yield (None, None, None, msg)
+
+ def info_connection(self, **_):
+ if self.pgexecute.host.startswith("/"):
+ host = 'socket "%s"' % self.pgexecute.host
+ else:
+ host = 'host "%s"' % self.pgexecute.host
+
+ yield (
+ None,
+ None,
+ None,
+ 'You are connected to database "%s" as user '
+ '"%s" on %s at port "%s".'
+ % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
+ )
+
+ def change_db(self, pattern, **_):
+ if pattern:
+ # Get all the parameters in pattern, handling double quotes if any.
+ infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
+ # Now removing quotes.
+ list(map(lambda s: s.strip('"'), infos))
+
+ infos.extend([None] * (4 - len(infos)))
+ db, user, host, port = infos
+ try:
+ self.pgexecute.connect(
+ database=db,
+ user=user,
+ host=host,
+ port=port,
+ **self.pgexecute.extra_args,
+ )
+ except OperationalError as e:
+ click.secho(str(e), err=True, fg="red")
+ click.echo("Previous connection kept")
+ else:
+ self.pgexecute.connect()
+
+ yield (
+ None,
+ None,
+ None,
+ 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
+ )
+
+ def execute_from_file(self, pattern, **_):
+ if not pattern:
+ message = "\\i: missing required argument"
+ return [(None, None, None, message, "", False, True)]
+ try:
+ with open(os.path.expanduser(pattern), encoding="utf-8") as f:
+ query = f.read()
+ except OSError as e:
+ return [(None, None, None, str(e), "", False, True)]
+
+ if (
+ self.destructive_warning != "off"
+ and confirm_destructive_query(query, self.destructive_warning) is False
+ ):
+ message = "Wise choice. Command execution stopped."
+ return [(None, None, None, message)]
+
+ on_error_resume = self.on_error == "RESUME"
+ return self.pgexecute.run(
+ query,
+ self.pgspecial,
+ on_error_resume=on_error_resume,
+ explain_mode=self.explain_mode,
+ )
+
+ def write_to_file(self, pattern, **_):
+ if not pattern:
+ self.output_file = None
+ message = "File output disabled"
+ return [(None, None, None, message, "", True, True)]
+ filename = os.path.abspath(os.path.expanduser(pattern))
+ if not os.path.isfile(filename):
+ try:
+ open(filename, "w").close()
+ except OSError as e:
+ self.output_file = None
+ message = str(e) + "\nFile output disabled"
+ return [(None, None, None, message, "", False, True)]
+ self.output_file = filename
+ message = 'Writing to file "%s"' % self.output_file
+ return [(None, None, None, message, "", True, True)]
+
+ def initialize_logging(self):
+
+ log_file = self.config["main"]["log_file"]
+ if log_file == "default":
+ log_file = config_location() + "log"
+ ensure_dir_exists(log_file)
+ log_level = self.config["main"]["log_level"]
+
+ # Disable logging if value is NONE by switching to a no-op handler.
+ # Set log level to a high value so it doesn't even waste cycles getting called.
+ if log_level.upper() == "NONE":
+ handler = logging.NullHandler()
+ else:
+ handler = logging.FileHandler(os.path.expanduser(log_file))
+
+ level_map = {
+ "CRITICAL": logging.CRITICAL,
+ "ERROR": logging.ERROR,
+ "WARNING": logging.WARNING,
+ "INFO": logging.INFO,
+ "DEBUG": logging.DEBUG,
+ "NONE": logging.CRITICAL,
+ }
+
+ log_level = level_map[log_level.upper()]
+
+ formatter = logging.Formatter(
+ "%(asctime)s (%(process)d/%(threadName)s) "
+ "%(name)s %(levelname)s - %(message)s"
+ )
+
+ handler.setFormatter(formatter)
+
+ root_logger = logging.getLogger("pgcli")
+ root_logger.addHandler(handler)
+ root_logger.setLevel(log_level)
+
+ root_logger.debug("Initializing pgcli logging.")
+ root_logger.debug("Log file %r.", log_file)
+
+ pgspecial_logger = logging.getLogger("pgspecial")
+ pgspecial_logger.addHandler(handler)
+ pgspecial_logger.setLevel(log_level)
+
+ def connect_dsn(self, dsn, **kwargs):
+ self.connect(dsn=dsn, **kwargs)
+
+ def connect_service(self, service, user):
+ service_config, file = parse_service_info(service)
+ if service_config is None:
+ click.secho(
+ f"service '{service}' was not found in {file}", err=True, fg="red"
+ )
+ exit(1)
+ self.connect(
+ database=service_config.get("dbname"),
+ host=service_config.get("host"),
+ user=user or service_config.get("user"),
+ port=service_config.get("port"),
+ passwd=service_config.get("password"),
+ )
+
+ def connect_uri(self, uri):
+ kwargs = conninfo_to_dict(uri)
+ remap = {"dbname": "database", "password": "passwd"}
+ kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
+ self.connect(**kwargs)
+
+ def connect(
+ self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
+ ):
+ # Connect to the database.
+
+ if not user:
+ user = getuser()
+
+ if not database:
+ database = user
+
+ kwargs.setdefault("application_name", "pgcli")
+
+ # If password prompt is not forced but no password is provided, try
+ # getting it from environment variable.
+ if not self.force_passwd_prompt and not passwd:
+ passwd = os.environ.get("PGPASSWORD", "")
+
+ # Prompt for a password immediately if requested via the -W flag. This
+ # avoids wasting time trying to connect to the database and catching a
+ # no-password exception.
+ # If we successfully parsed a password from a URI, there's no need to
+ # prompt for it, even with the -W flag
+ if self.force_passwd_prompt and not passwd:
+ passwd = click.prompt(
+ "Password for %s" % user, hide_input=True, show_default=False, type=str
+ )
+
+ key = f"{user}@{host}"
+
+ if not passwd and auth.keyring:
+ passwd = auth.keyring_get_password(key)
+
+ def should_ask_for_password(exc):
+ # Prompt for a password after 1st attempt to connect
+ # fails. Don't prompt if the -w flag is supplied
+ if self.never_passwd_prompt:
+ return False
+ error_msg = exc.args[0]
+ if "no password supplied" in error_msg:
+ return True
+ if "password authentication failed" in error_msg:
+ return True
+ return False
+
+ if dsn:
+ parsed_dsn = conninfo_to_dict(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
+ if self.ssh_tunnel_url:
+ # We add the protocol as urlparse doesn't find it by itself
+ if "://" not in self.ssh_tunnel_url:
+ self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
+
+ tunnel_info = urlparse(self.ssh_tunnel_url)
+ params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (host, int(port or 5432)),
+ "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
+ "logger": self.logger,
+ }
+ if tunnel_info.username:
+ params["ssh_username"] = tunnel_info.username
+ if tunnel_info.password:
+ params["ssh_password"] = tunnel_info.password
+
+ # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
+ # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
+ logger_handlers = self.logger.handlers.copy()
+ try:
+ self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
+ self.ssh_tunnel.start()
+ except Exception as e:
+ self.logger.handlers = logger_handlers
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+ self.logger.handlers = logger_handlers
+
+ atexit.register(self.ssh_tunnel.stop)
+ host = "127.0.0.1"
+ port = self.ssh_tunnel.local_bind_ports[0]
+
+ if dsn:
+ dsn = make_conninfo(dsn, host=host, port=port)
+
+ # Attempt to connect to the database.
+ # Note that passwd may be empty on the first attempt. If connection
+ # fails because of a missing or incorrect password, but we're allowed to
+ # prompt for a password (no -w flag), prompt for a passwd and try again.
+ try:
+ try:
+ pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
+ except (OperationalError, InterfaceError) as e:
+ if should_ask_for_password(e):
+ passwd = click.prompt(
+ "Password for %s" % user,
+ hide_input=True,
+ show_default=False,
+ type=str,
+ )
+ pgexecute = PGExecute(
+ database, user, passwd, host, port, dsn, **kwargs
+ )
+ else:
+ raise e
+ if passwd and auth.keyring:
+ auth.keyring_set_password(key, passwd)
+
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug("Database connection failed: %r.", e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+ self.pgexecute = pgexecute
+
+ def handle_editor_command(self, text):
+ r"""
+ Editor command is any query that is prefixed or suffixed
+ by a '\e'. The reason for a while loop is because a user
+ might edit a query multiple times.
+ For eg:
+ "select * from \e"<enter> to edit it in vim, then come
+ back to the prompt with the edited query "select * from
+ blah where q = 'abc'\e" to edit it again.
+ :param text: Document
+ :return: Document
+ """
+ editor_command = special.editor_command(text)
+ while editor_command:
+ if editor_command == "\\e":
+ filename = special.get_filename(text)
+ query = special.get_editor_query(text) or self.get_last_query()
+ else: # \ev or \ef
+ filename = None
+ spec = text.split()[1]
+ if editor_command == "\\ev":
+ query = self.pgexecute.view_definition(spec)
+ elif editor_command == "\\ef":
+ query = self.pgexecute.function_definition(spec)
+ sql, message = special.open_external_editor(filename, sql=query)
+ if message:
+ # Something went wrong. Raise an exception and bail.
+ raise RuntimeError(message)
+ while True:
+ try:
+ text = self.prompt_app.prompt(default=sql)
+ break
+ except KeyboardInterrupt:
+ sql = ""
+
+ editor_command = special.editor_command(text)
+ return text
+
+ def execute_command(self, text):
+ logger = self.logger
+
+ query = MetaQuery(query=text, successful=False)
+
+ try:
+ if self.destructive_warning != "off":
+ destroy = confirm = confirm_destructive_query(
+ text, self.destructive_warning
+ )
+ if destroy is False:
+ click.secho("Wise choice!")
+ raise KeyboardInterrupt
+ elif destroy:
+ click.secho("Your call!")
+ output, query = self._evaluate_command(text)
+ except KeyboardInterrupt:
+ # Restart connection to the database
+ self.pgexecute.connect()
+ logger.debug("cancelled query, sql: %r", text)
+ click.secho("cancelled query", err=True, fg="red")
+ except NotImplementedError:
+ click.secho("Not Yet Implemented.", fg="yellow")
+ except OperationalError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self._handle_server_closed_connection(text)
+ except (PgCliQuitError, EOFError) as e:
+ raise
+ except Exception as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ else:
+ try:
+ if self.output_file and not text.startswith(("\\o ", "\\? ")):
+ try:
+ with open(self.output_file, "a", encoding="utf-8") as f:
+ click.echo(text, file=f)
+ click.echo("\n".join(output), file=f)
+ click.echo("", file=f) # extra newline
+ except OSError as e:
+ click.secho(str(e), err=True, fg="red")
+ else:
+ if output:
+ self.echo_via_pager("\n".join(output))
+ except KeyboardInterrupt:
+ pass
+
+ if self.pgspecial.timing_enabled:
+ # Only add humanized time display if > 1 second
+ if query.total_time > 1:
+ print(
+ "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
+ % (
+ query.total_time,
+ pendulum.Duration(seconds=query.total_time).in_words(),
+ query.execution_time,
+ pendulum.Duration(seconds=query.execution_time).in_words(),
+ )
+ )
+ else:
+ print("Time: %0.03fs" % query.total_time)
+
+ # Check if we need to update completions, in order of most
+ # to least drastic changes
+ if query.db_changed:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.refresh_completions(persist_priorities="keywords")
+ elif query.meta_changed:
+ self.refresh_completions(persist_priorities="all")
+ elif query.path_changed:
+ logger.debug("Refreshing search path")
+ with self._completer_lock:
+ self.completer.set_search_path(self.pgexecute.search_path())
+ logger.debug("Search path: %r", self.completer.search_path)
+ return query
+
+ def run_cli(self):
+ logger = self.logger
+
+ history_file = self.config["main"]["history_file"]
+ if history_file == "default":
+ history_file = config_location() + "history"
+ history = FileHistory(os.path.expanduser(history_file))
+ self.refresh_completions(history=history, persist_priorities="none")
+
+ self.prompt_app = self._build_cli(history)
+
+ if not self.less_chatty:
+ print("Server: PostgreSQL", self.pgexecute.server_version)
+ print("Version:", __version__)
+ print("Home: http://pgcli.com")
+
+ try:
+ while True:
+ try:
+ text = self.prompt_app.prompt()
+ except KeyboardInterrupt:
+ continue
+
+ try:
+ text = self.handle_editor_command(text)
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ continue
+
+ self.handle_watch_command(text)
+
+ self.now = dt.datetime.today()
+
+ # Allow PGCompleter to learn user's preferred keywords, etc.
+ with self._completer_lock:
+ self.completer.extend_query_history(text)
+
+ except (PgCliQuitError, EOFError):
+ if not self.less_chatty:
+ print("Goodbye!")
+
+ def handle_watch_command(self, text):
+ # Initialize default metaquery in case execution fails
+ self.watch_command, timing = special.get_watch_command(text)
+
+ # If we run \watch without a command, apply it to the last query run.
+ if self.watch_command is not None and not self.watch_command.strip():
+ try:
+ self.watch_command = self.query_history[-1].query
+ except IndexError:
+ click.secho(
+ "\\watch cannot be used with an empty query", err=True, fg="red"
+ )
+ self.watch_command = None
+
+ # If there's a command to \watch, run it in a loop.
+ if self.watch_command:
+ while self.watch_command:
+ try:
+ query = self.execute_command(self.watch_command)
+ click.echo(f"Waiting for {timing} seconds before repeating")
+ sleep(timing)
+ except KeyboardInterrupt:
+ self.watch_command = None
+
+ # Otherwise, execute it as a regular command.
+ else:
+ query = self.execute_command(text)
+
+ self.query_history.append(query)
+
+ def _build_cli(self, history):
+ key_bindings = pgcli_bindings(self)
+
+ def get_message():
+ if self.dsn_alias and self.prompt_dsn_format is not None:
+ prompt_format = self.prompt_dsn_format
+ else:
+ prompt_format = self.prompt_format
+
+ prompt = self.get_prompt(prompt_format)
+
+ if (
+ prompt_format == self.default_prompt
+ and len(prompt) > self.max_len_prompt
+ ):
+ prompt = self.get_prompt("\\d> ")
+
+ prompt = prompt.replace("\\x1b", "\x1b")
+ return ANSI(prompt)
+
+ def get_continuation(width, line_number, is_soft_wrap):
+ continuation = self.multiline_continuation_char * (width - 1) + " "
+ return [("class:continuation", continuation)]
+
+ get_toolbar_tokens = create_toolbar_tokens_func(self)
+
+ if self.wider_completion_menu:
+ complete_style = CompleteStyle.MULTI_COLUMN
+ else:
+ complete_style = CompleteStyle.COLUMN
+
+ with self._completer_lock:
+ prompt_app = PromptSession(
+ lexer=PygmentsLexer(PostgresLexer),
+ reserve_space_for_menu=self.min_num_menu_lines,
+ message=get_message,
+ prompt_continuation=get_continuation,
+ bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
+ complete_style=complete_style,
+ input_processors=[
+ # Highlight matching brackets while editing.
+ ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(chars="[](){}"),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
+ ),
+ # Render \t as 4 spaces instead of "^I"
+ TabsProcessor(char1=" ", char2=" "),
+ ],
+ auto_suggest=AutoSuggestFromHistory(),
+ tempfile_suffix=".sql",
+ # N.b. pgcli's multi-line mode controls submit-on-Enter (which
+ # overrides the default behaviour of prompt_toolkit) and is
+ # distinct from prompt_toolkit's multiline mode here, which
+ # controls layout/display of the prompt/buffer
+ multiline=True,
+ history=history,
+ completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
+ complete_while_typing=True,
+ style=style_factory(self.syntax_style, self.cli_style),
+ include_default_pygments_style=False,
+ key_bindings=key_bindings,
+ enable_open_in_editor=True,
+ enable_system_prompt=True,
+ enable_suspend=True,
+ editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
+ search_ignore_case=True,
+ )
+
+ return prompt_app
+
+ def _should_limit_output(self, sql, cur):
+ """returns True if the output should be truncated, False otherwise."""
+ if self.explain_mode:
+ return False
+ if not is_select(sql):
+ return False
+
+ return (
+ not self._has_limit(sql)
+ and self.row_limit != 0
+ and cur
+ and cur.rowcount > self.row_limit
+ )
+
+ def _has_limit(self, sql):
+ if not sql:
+ return False
+ return "limit " in sql.lower()
+
+ def _limit_output(self, cur):
+ limit = min(self.row_limit, cur.rowcount)
+ new_cur = itertools.islice(cur, limit)
+ new_status = "SELECT " + str(limit)
+ click.secho("The result was limited to %s rows" % limit, fg="red")
+
+ return new_cur, new_status
+
+ def _evaluate_command(self, text):
+ """Used to run a command entered by the user during CLI operation
+ (Puts the E in REPL)
+
+ returns (results, MetaQuery)
+ """
+ logger = self.logger
+ logger.debug("sql: %r", text)
+
+ # set query to formatter in order to parse table name
+ self.formatter.query = text
+ all_success = True
+ meta_changed = False # CREATE, ALTER, DROP, etc
+ mutated = False # INSERT, DELETE, etc
+ db_changed = False
+ path_changed = False
+ output = []
+ total = 0
+ execution = 0
+
+ # Run the query.
+ start = time()
+ on_error_resume = self.on_error == "RESUME"
+ res = self.pgexecute.run(
+ text,
+ self.pgspecial,
+ exception_formatter,
+ on_error_resume,
+ explain_mode=self.explain_mode,
+ )
+
+ is_special = None
+
+ for title, cur, headers, status, sql, success, is_special in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+
+ if self._should_limit_output(sql, cur):
+ cur, status = self._limit_output(cur)
+
+ if self.pgspecial.auto_expand or self.auto_expand:
+ max_width = self.prompt_app.output.get_size().columns
+ else:
+ max_width = None
+
+ expanded = self.pgspecial.expanded_output or self.expanded_output
+ settings = OutputSettings(
+ table_format=self.table_format,
+ dcmlfmt=self.decimal_format,
+ floatfmt=self.float_format,
+ missingval=self.null_string,
+ expanded=expanded,
+ max_width=max_width,
+ case_function=(
+ self.completer.case
+ if self.settings["case_column_headers"]
+ else lambda x: x
+ ),
+ style_output=self.style_output,
+ max_field_width=self.max_field_width,
+ )
+ execution = time() - start
+ formatted = format_output(
+ title, cur, headers, status, settings, self.explain_mode
+ )
+
+ output.extend(formatted)
+ total = time() - start
+
+ # Keep track of whether any of the queries are mutating or changing
+ # the database
+ if success:
+ mutated = mutated or is_mutating(status)
+ db_changed = db_changed or has_change_db_cmd(sql)
+ meta_changed = meta_changed or has_meta_cmd(sql)
+ path_changed = path_changed or has_change_path_cmd(sql)
+ else:
+ all_success = False
+
+ meta_query = MetaQuery(
+ text,
+ all_success,
+ total,
+ execution,
+ meta_changed,
+ db_changed,
+ path_changed,
+ mutated,
+ is_special,
+ )
+
+ return output, meta_query
+
+ def _handle_server_closed_connection(self, text):
+ """Used during CLI execution."""
+ try:
+ click.secho("Reconnecting...", fg="green")
+ self.pgexecute.connect()
+ click.secho("Reconnected!", fg="green")
+ self.execute_command(text)
+ except OperationalError as e:
+ click.secho("Reconnect Failed", fg="red")
+ click.secho(str(e), err=True, fg="red")
+
+ def refresh_completions(self, history=None, persist_priorities="all"):
+ """Refresh outdated completions
+
+ :param history: A prompt_toolkit.history.FileHistory object. Used to
+ load keyword and identifier preferences
+
+ :param persist_priorities: 'all' or 'keywords'
+ """
+
+ callback = functools.partial(
+ self._on_completions_refreshed, persist_priorities=persist_priorities
+ )
+ return self.completion_refresher.refresh(
+ self.pgexecute,
+ self.pgspecial,
+ callback,
+ history=history,
+ settings=self.settings,
+ )
+
+ def _on_completions_refreshed(self, new_completer, persist_priorities):
+ self._swap_completer_objects(new_completer, persist_priorities)
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def _swap_completer_objects(self, new_completer, persist_priorities):
+ """Swap the completer object with the newly created completer.
+
+ persist_priorities is a string specifying how the old completer's
+ learned prioritizer should be transferred to the new completer.
+
+ 'none' - The new prioritizer is left in a new/clean state
+
+ 'all' - The new prioritizer is updated to exactly reflect
+ the old one
+
+ 'keywords' - The new prioritizer is updated with old keyword
+ priorities, but not any other.
+
+ """
+ with self._completer_lock:
+ old_completer = self.completer
+ self.completer = new_completer
+
+ if persist_priorities == "all":
+ # Just swap over the entire prioritizer
+ new_completer.prioritizer = old_completer.prioritizer
+ elif persist_priorities == "keywords":
+ # Swap over the entire prioritizer, but clear name priorities,
+ # leaving learned keyword priorities alone
+ new_completer.prioritizer = old_completer.prioritizer
+ new_completer.prioritizer.clear_names()
+ elif persist_priorities == "none":
+ # Leave the new prioritizer as is
+ pass
+ self.completer = new_completer
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None
+ )
+
+ def get_prompt(self, string):
+ # should be before replacing \\d
+ string = string.replace("\\dsn_alias", self.dsn_alias or "")
+ string = string.replace("\\t", self.now.strftime("%x %X"))
+ string = string.replace("\\u", self.pgexecute.user or "(none)")
+ string = string.replace("\\H", self.pgexecute.host or "(none)")
+ string = string.replace("\\h", self.pgexecute.short_host or "(none)")
+ string = string.replace("\\d", self.pgexecute.dbname or "(none)")
+ string = string.replace(
+ "\\p",
+ str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
+ )
+ string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
+ string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
+ string = string.replace("\\n", "\n")
+ return string
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+ def is_too_wide(self, line):
+ """Will this line be too wide to fit into terminal?"""
+ if not self.prompt_app:
+ return False
+ return (
+ len(COLOR_CODE_REGEX.sub("", line))
+ > self.prompt_app.output.get_size().columns
+ )
+
+ def is_too_tall(self, lines):
+ """Are there too many lines to fit into terminal?"""
+ if not self.prompt_app:
+ return False
+ return len(lines) >= (self.prompt_app.output.get_size().rows - 4)
+
+ def echo_via_pager(self, text, color=None):
+ if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
+ click.echo(text, color=color)
+ elif (
+ self.pgspecial.pager_config == PAGER_LONG_OUTPUT
+ and self.table_format != "csv"
+ ):
+ lines = text.split("\n")
+
+ # The last 4 lines are reserved for the pgcli menu and padding
+ if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines):
+ click.echo_via_pager(text, color=color)
+ else:
+ click.echo(text, color=color)
+ else:
+ click.echo_via_pager(text, color)
+
+
+@click.command()
+# Default host is '' so psycopg can default to either localhost or unix socket
+@click.option(
+ "-h",
+ "--host",
+ default="",
+ envvar="PGHOST",
+ help="Host address of the postgres database.",
+)
+@click.option(
+ "-p",
+ "--port",
+ default=5432,
+ help="Port number at which the " "postgres instance is listening.",
+ envvar="PGPORT",
+ type=click.INT,
+)
+@click.option(
+ "-U",
+ "--username",
+ "username_opt",
+ help="Username to connect to the postgres database.",
+)
+@click.option(
+ "-u", "--user", "username_opt", help="Username to connect to the postgres database."
+)
+@click.option(
+ "-W",
+ "--password",
+ "prompt_passwd",
+ is_flag=True,
+ default=False,
+ help="Force password prompt.",
+)
+@click.option(
+ "-w",
+ "--no-password",
+ "never_prompt",
+ is_flag=True,
+ default=False,
+ help="Never prompt for password.",
+)
+@click.option(
+ "--single-connection",
+ "single_connection",
+ is_flag=True,
+ default=False,
+ help="Do not use a separate connection for completions.",
+)
+@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
+@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
+@click.option(
+ "--pgclirc",
+ default=config_location() + "config",
+ envvar="PGCLIRC",
+ help="Location of pgclirc file.",
+ type=click.Path(dir_okay=False),
+)
+@click.option(
+ "-D",
+ "--dsn",
+ default="",
+ envvar="DSN",
+ help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
+)
+@click.option(
+ "--list-dsn",
+ "list_dsn",
+ is_flag=True,
+ help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
+)
+@click.option(
+ "--row-limit",
+ default=None,
+ envvar="PGROWLIMIT",
+ type=click.INT,
+ help="Set threshold for row limit prompt. Use 0 to disable prompt.",
+)
+@click.option(
+ "--less-chatty",
+ "less_chatty",
+ is_flag=True,
+ default=False,
+ help="Skip intro on startup and goodbye on exit.",
+)
+@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
+@click.option(
+ "--prompt-dsn",
+ help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
+)
+@click.option(
+ "-l",
+ "--list",
+ "list_databases",
+ is_flag=True,
+ help="list available databases, then exit.",
+)
+@click.option(
+ "--auto-vertical-output",
+ is_flag=True,
+ help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
+)
+@click.option(
+ "--warn",
+ default=None,
+ type=click.Choice(["all", "moderate", "off"]),
+ help="Warn before running a destructive query.",
+)
+@click.option(
+ "--ssh-tunnel",
+ default=None,
+ help="Open an SSH tunnel to the given address and connect to the database from it.",
+)
+@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
+@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
+def cli(
+ dbname,
+ username_opt,
+ host,
+ port,
+ prompt_passwd,
+ never_prompt,
+ single_connection,
+ dbname_opt,
+ username,
+ version,
+ pgclirc,
+ dsn,
+ row_limit,
+ less_chatty,
+ prompt,
+ prompt_dsn,
+ list_databases,
+ auto_vertical_output,
+ list_dsn,
+ warn,
+ ssh_tunnel: str,
+):
+ if version:
+ print("Version:", __version__)
+ sys.exit(0)
+
+ config_dir = os.path.dirname(config_location())
+ if not os.path.exists(config_dir):
+ os.makedirs(config_dir)
+
+ # Migrate the config file from old location.
+ config_full_path = config_location() + "config"
+ if os.path.exists(os.path.expanduser("~/.pgclirc")):
+ if not os.path.exists(config_full_path):
+ shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
+ print("Config file (~/.pgclirc) moved to new location", config_full_path)
+ else:
+ print("Config file is now located at", config_full_path)
+ print(
+ "Please move the existing config file ~/.pgclirc to",
+ config_full_path,
+ )
+ if list_dsn:
+ try:
+ cfg = load_config(pgclirc, config_full_path)
+ for alias in cfg["alias_dsn"]:
+ click.secho(alias + " : " + cfg["alias_dsn"][alias])
+ sys.exit(0)
+ except Exception as err:
+ click.secho(
+ "Invalid DSNs found in the config file. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
+ if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
+ click.secho(
+ 'Cannot open SSH tunnel, "sshtunnel" package was not found. '
+ "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
+ pgcli = PGCli(
+ prompt_passwd,
+ never_prompt,
+ pgclirc_file=pgclirc,
+ row_limit=row_limit,
+ single_connection=single_connection,
+ less_chatty=less_chatty,
+ prompt=prompt,
+ prompt_dsn=prompt_dsn,
+ auto_vertical_output=auto_vertical_output,
+ warn=warn,
+ ssh_tunnel_url=ssh_tunnel,
+ )
+
+ # Choose which ever one has a valid value.
+ if dbname_opt and dbname:
+ # work as psql: when database is given as option and argument use the argument as user
+ username = dbname
+ database = dbname_opt or dbname or ""
+ user = username_opt or username
+ service = None
+ if database.startswith("service="):
+ service = database[8:]
+ elif os.getenv("PGSERVICE") is not None:
+ service = os.getenv("PGSERVICE")
+ # because option --list or -l are not supposed to have a db name
+ if list_databases:
+ database = "postgres"
+
+ if dsn != "":
+ try:
+ cfg = load_config(pgclirc, config_full_path)
+ dsn_config = cfg["alias_dsn"][dsn]
+ except KeyError:
+ click.secho(
+ f"Could not find a DSN with alias {dsn}. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+ except Exception:
+ click.secho(
+ "Invalid DSNs found in the config file. "
+ 'Please check the "[alias_dsn]" section in pgclirc.',
+ err=True,
+ fg="red",
+ )
+ exit(1)
+ pgcli.connect_uri(dsn_config)
+ pgcli.dsn_alias = dsn
+ elif "://" in database:
+ pgcli.connect_uri(database)
+ elif "=" in database and service is None:
+ pgcli.connect_dsn(database, user=user)
+ elif service is not None:
+ pgcli.connect_service(service, user)
+ else:
+ pgcli.connect(database, host, user, port)
+
+ if list_databases:
+ cur, headers, status = pgcli.pgexecute.full_databases()
+
+ title = "List of databases"
+ settings = OutputSettings(table_format="ascii", missingval="<null>")
+ formatted = format_output(title, cur, headers, status, settings)
+ pgcli.echo_via_pager("\n".join(formatted))
+
+ sys.exit(0)
+
+ pgcli.logger.debug(
+ "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
+ database,
+ user,
+ host,
+ port,
+ )
+
+ if setproctitle:
+ obfuscate_process_password()
+
+ pgcli.run_cli()
+
+
+def obfuscate_process_password():
+ process_title = setproctitle.getproctitle()
+ if "://" in process_title:
+ process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title)
+ elif "=" in process_title:
+ process_title = re.sub(
+ r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title
+ )
+
+ setproctitle.setproctitle(process_title)
+
+
+def has_meta_cmd(query):
+ """Determines if the completion needs a refresh by checking if the sql
+ statement is an alter, create, drop, commit or rollback."""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"):
+ return True
+ except Exception:
+ return False
+
+ return False
+
+
+def has_change_db_cmd(query):
+ """Determines if the statement is a database switch such as 'use' or '\\c'"""
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("use", "\\c", "\\connect"):
+ return True
+ except Exception:
+ return False
+
+ return False
+
+
+def has_change_path_cmd(sql):
+ """Determines if the search_path should be refreshed by checking if the
+ sql has 'set search_path'."""
+ return "set search_path" in sql.lower()
+
+
+def is_mutating(status):
+ """Determines if the statement is mutating based on the status."""
+ if not status:
+ return False
+
+ mutating = {"insert", "update", "delete"}
+ return status.split(None, 1)[0].lower() in mutating
+
+
+def is_select(status):
+ """Returns true if the first word in status is 'select'."""
+ if not status:
+ return False
+ return status.split(None, 1)[0].lower() == "select"
+
+
+def exception_formatter(e):
+ return click.style(str(e), fg="red")
+
+
+def format_output(title, cur, headers, status, settings, explain_mode=False):
+ output = []
+ expanded = settings.expanded or settings.table_format == "vertical"
+ table_format = "vertical" if settings.expanded else settings.table_format
+ max_width = settings.max_width
+ case_function = settings.case_function
+ if explain_mode:
+ formatter = ExplainOutputFormatter(max_width or 100)
+ else:
+ formatter = TabularOutputFormatter(format_name=table_format)
+
+ def format_array(val):
+ if val is None:
+ return settings.missingval
+ if not isinstance(val, list):
+ return val
+ return "{" + ",".join(str(format_array(e)) for e in val) + "}"
+
+ def format_arrays(data, headers, **_):
+ data = list(data)
+ for row in data:
+ row[:] = [
+ format_array(val) if isinstance(val, list) else val for val in row
+ ]
+
+ return data, headers
+
+ def format_status(cur, status):
+ # redshift does not return rowcount as part of status.
+ # See https://github.com/dbcli/pgcli/issues/1320
+ if cur and hasattr(cur, "rowcount") and cur.rowcount is not None:
+ if status and not status.endswith(str(cur.rowcount)):
+ status += " %s" % cur.rowcount
+ return status
+
+ output_kwargs = {
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": settings.missingval,
+ "integer_format": settings.dcmlfmt,
+ "float_format": settings.floatfmt,
+ "preprocessors": (format_numbers, format_arrays),
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "style": settings.style_output,
+ "max_field_width": settings.max_field_width,
+ }
+ if not settings.floatfmt:
+ output_kwargs["preprocessors"] = (align_decimals,)
+
+ if table_format == "csv":
+ # The default CSV dialect is "excel" which is not handling newline values correctly
+ # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
+ # as the line terminator
+ # https://github.com/dbcli/pgcli/issues/1102
+ dialect = "excel" if platform.system() == "Windows" else "unix"
+ output_kwargs["dialect"] = dialect
+
+ if title: # Only print the title if it's not None.
+ output.append(title)
+
+ if cur:
+ headers = [case_function(x) for x in headers]
+ if max_width is not None:
+ cur = list(cur)
+ column_types = None
+ if hasattr(cur, "description"):
+ column_types = []
+ for d in cur.description:
+ col_type = cur.adapters.types.get(d.type_code)
+ type_name = col_type.name if col_type else None
+ if type_name in ("numeric", "float4", "float8"):
+ column_types.append(float)
+ if type_name in ("int2", "int4", "int8"):
+ column_types.append(int)
+ else:
+ column_types.append(str)
+
+ formatted = formatter.format_output(cur, headers, **output_kwargs)
+ if isinstance(formatted, str):
+ formatted = iter(formatted.splitlines())
+ first_line = next(formatted)
+ formatted = itertools.chain([first_line], formatted)
+ if (
+ not expanded
+ and max_width
+ and len(strip_ansi(first_line)) > max_width
+ and headers
+ ):
+ formatted = formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs,
+ )
+ if isinstance(formatted, str):
+ formatted = iter(formatted.splitlines())
+
+ output = itertools.chain(output, formatted)
+
+ # Only print the status if it's not None
+ if status:
+ output = itertools.chain(output, [format_status(cur, status)])
+
+ return output
+
+
+def parse_service_info(service):
+ service = service or os.getenv("PGSERVICE")
+ service_file = os.getenv("PGSERVICEFILE")
+ if not service_file:
+ # try ~/.pg_service.conf (if that exists)
+ if platform.system() == "Windows":
+ service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf"
+ elif os.getenv("PGSYSCONFDIR"):
+ service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
+ else:
+ service_file = os.path.expanduser("~/.pg_service.conf")
+ if not service or not os.path.exists(service_file):
+ # nothing to do
+ return None, service_file
+ with open(service_file, newline="") as f:
+ skipped_lines = skip_initial_comment(f)
+ try:
+ service_file_config = ConfigObj(f)
+ except ParseError as err:
+ err.line_number += skipped_lines
+ raise err
+ if service not in service_file_config:
+ return None, service_file
+ service_conf = service_file_config.get(service)
+ return service_conf, service_file
+
+
+if __name__ == "__main__":
+ cli()