diff options
Diffstat (limited to 'pgcli/main.py')
-rw-r--r-- | pgcli/main.py | 1516 |
1 files changed, 1516 insertions, 0 deletions
diff --git a/pgcli/main.py b/pgcli/main.py new file mode 100644 index 0000000..b146898 --- /dev/null +++ b/pgcli/main.py @@ -0,0 +1,1516 @@ +import platform +import warnings +from os.path import expanduser + +from configobj import ConfigObj +from pgspecial.namedqueries import NamedQueries + +warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") + +import os +import re +import sys +import traceback +import logging +import threading +import shutil +import functools +import pendulum +import datetime as dt +import itertools +import platform +from time import time, sleep +from codecs import open + +keyring = None # keyring will be loaded later + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +import click + +try: + import setproctitle +except ImportError: + setproctitle = None +from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.layout.processors import ( + ConditionalProcessor, + HighlightMatchingBracketProcessor, + TabsProcessor, +) +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from pygments.lexers.sql import PostgresLexer + +from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT +import pgspecial as special + +from .pgcompleter import PGCompleter +from .pgtoolbar import create_toolbar_tokens_func +from .pgstyle import style_factory, style_factory_output +from .pgexecute import PGExecute +from .completion_refresher import CompletionRefresher +from .config import ( + get_casing_file, + load_config, + config_location, + ensure_dir_exists, + get_config, +) +from .key_bindings import pgcli_bindings +from .packages.prompt_utils import confirm_destructive_query +from .__init__ import __version__ + +click.disable_unicode_literals_warning = True + +try: + from urlparse import urlparse, unquote, parse_qs +except ImportError: + from urllib.parse import urlparse, unquote, parse_qs + +from getpass import getuser +from psycopg2 import OperationalError, InterfaceError +import psycopg2 + +from collections import namedtuple + +from textwrap import dedent + +# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output +COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") + +# Query tuples are used for maintaining history +MetaQuery = namedtuple( + "Query", + [ + "query", # The entire text of the command + "successful", # True If all subqueries were successful + "total_time", # Time elapsed executing the query and formatting results + "execution_time", # Time elapsed executing the query + "meta_changed", # True if any subquery executed create/alter/drop + "db_changed", # True if any subquery changed the database + "path_changed", # True if any subquery changed the search path + "mutated", # True if any subquery executed insert/update/delete + "is_special", # True if the query is a special command + ], +) +MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False) + +OutputSettings = namedtuple( + "OutputSettings", + "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output", +) +OutputSettings.__new__.__defaults__ = ( + None, + None, + None, + "<null>", + False, + None, + lambda x: x, + None, +) + + +class PgCliQuitError(Exception): + pass + + +class PGCli(object): + default_prompt = "\\u@\\h:\\d> " + max_len_prompt = 30 + + def set_default_pager(self, config): + configured_pager = config["main"].get("pager") + os_environ_pager = os.environ.get("PAGER") + + if configured_pager: + self.logger.info( + 'Default pager found in config file: "%s"', configured_pager + ) + os.environ["PAGER"] = configured_pager + elif os_environ_pager: + self.logger.info( + 'Default pager found in PAGER environment variable: "%s"', + os_environ_pager, + ) + os.environ["PAGER"] = os_environ_pager + else: + self.logger.info( + "No default pager found in environment. Using os default pager" + ) + + # Set default set of less recommended options, if they are not already set. + # They are ignored if pager is different than less. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-SRXF" + + def __init__( + self, + force_passwd_prompt=False, + never_passwd_prompt=False, + pgexecute=None, + pgclirc_file=None, + row_limit=None, + single_connection=False, + less_chatty=None, + prompt=None, + prompt_dsn=None, + auto_vertical_output=False, + warn=None, + ): + + self.force_passwd_prompt = force_passwd_prompt + self.never_passwd_prompt = never_passwd_prompt + self.pgexecute = pgexecute + self.dsn_alias = None + self.watch_command = None + + # Load config. + c = self.config = get_config(pgclirc_file) + + NamedQueries.instance = NamedQueries.from_config(self.config) + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + self.set_default_pager(c) + self.output_file = None + self.pgspecial = PGSpecial() + + self.multi_line = c["main"].as_bool("multi_line") + self.multiline_mode = c["main"].get("multi_line_mode", "psql") + self.vi_mode = c["main"].as_bool("vi") + self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.expanded_output = c["main"].as_bool("expand") + self.pgspecial.timing_enabled = c["main"].as_bool("timing") + if row_limit is not None: + self.row_limit = row_limit + else: + self.row_limit = c["main"].as_int("row_limit") + + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") + self.multiline_continuation_char = c["main"]["multiline_continuation_char"] + self.table_format = c["main"]["table_format"] + self.syntax_style = c["main"]["syntax_style"] + self.cli_style = c["colors"] + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") + self.destructive_warning = c_dest_warning if warn is None else warn + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.null_string = c["main"].get("null_string", "<null>") + self.prompt_format = ( + prompt + if prompt is not None + else c["main"].get("prompt", self.default_prompt) + ) + self.prompt_dsn_format = prompt_dsn + self.on_error = c["main"]["on_error"].upper() + self.decimal_format = c["data_formats"]["decimal"] + self.float_format = c["data_formats"]["float"] + self.initialize_keyring() + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") + + self.pgspecial.pset_pager( + self.config["main"].as_bool("enable_pager") and "on" or "off" + ) + + self.style_output = style_factory_output(self.syntax_style, c["colors"]) + + self.now = dt.datetime.today() + + self.completion_refresher = CompletionRefresher() + + self.query_history = [] + + # Initialize completer + smart_completion = c["main"].as_bool("smart_completion") + keyword_casing = c["main"]["keyword_casing"] + self.settings = { + "casing_file": get_casing_file(c), + "generate_casing_file": c["main"].as_bool("generate_casing_file"), + "generate_aliases": c["main"].as_bool("generate_aliases"), + "asterisk_column_order": c["main"]["asterisk_column_order"], + "qualify_columns": c["main"]["qualify_columns"], + "case_column_headers": c["main"].as_bool("case_column_headers"), + "search_path_filter": c["main"].as_bool("search_path_filter"), + "single_connection": single_connection, + "less_chatty": less_chatty, + "keyword_casing": keyword_casing, + } + + completer = PGCompleter( + smart_completion, pgspecial=self.pgspecial, settings=self.settings + ) + self.completer = completer + self._completer_lock = threading.Lock() + self.register_special_commands() + + self.prompt_app = None + + def quit(self): + raise PgCliQuitError + + def register_special_commands(self): + + self.pgspecial.register( + self.change_db, + "\\c", + "\\c[onnect] database_name", + "Change to a new database.", + aliases=("use", "\\connect", "USE"), + ) + + refresh_callback = lambda: self.refresh_completions(persist_priorities="all") + + self.pgspecial.register( + self.quit, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + self.pgspecial.register( + self.quit, + "quit", + "quit", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=("exit",), + ) + self.pgspecial.register( + refresh_callback, + "\\#", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + refresh_callback, + "\\refresh", + "\\refresh", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." + ) + self.pgspecial.register( + self.write_to_file, + "\\o", + "\\o [filename]", + "Send all query results to file.", + ) + self.pgspecial.register( + self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + ) + self.pgspecial.register( + self.change_table_format, + "\\T", + "\\T [format]", + "Change the table format used to output results", + ) + + def change_table_format(self, pattern, **_): + try: + if pattern not in TabularOutputFormatter().supported_formats: + raise ValueError() + self.table_format = pattern + yield (None, None, None, "Changed table format to {}".format(pattern)) + except ValueError: + msg = "Table format {} not recognized. Allowed formats:".format(pattern) + for table_type in TabularOutputFormatter().supported_formats: + msg += "\n\t{}".format(table_type) + msg += "\nCurrently set to: %s" % self.table_format + yield (None, None, None, msg) + + def info_connection(self, **_): + if self.pgexecute.host.startswith("/"): + host = 'socket "%s"' % self.pgexecute.host + else: + host = 'host "%s"' % self.pgexecute.host + + yield ( + None, + None, + None, + 'You are connected to database "%s" as user ' + '"%s" on %s at port "%s".' + % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + ) + + def change_db(self, pattern, **_): + if pattern: + # Get all the parameters in pattern, handling double quotes if any. + infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern) + # Now removing quotes. + list(map(lambda s: s.strip('"'), infos)) + + infos.extend([None] * (4 - len(infos))) + db, user, host, port = infos + try: + self.pgexecute.connect( + database=db, + user=user, + host=host, + port=port, + **self.pgexecute.extra_args, + ) + except OperationalError as e: + click.secho(str(e), err=True, fg="red") + click.echo("Previous connection kept") + else: + self.pgexecute.connect() + + yield ( + None, + None, + None, + 'You are now connected to database "%s" as ' + 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user), + ) + + def execute_from_file(self, pattern, **_): + if not pattern: + message = "\\i: missing required argument" + return [(None, None, None, message, "", False, True)] + try: + with open(os.path.expanduser(pattern), encoding="utf-8") as f: + query = f.read() + except IOError as e: + return [(None, None, None, str(e), "", False, True)] + + if self.destructive_warning and confirm_destructive_query(query) is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] + + on_error_resume = self.on_error == "RESUME" + return self.pgexecute.run( + query, self.pgspecial, on_error_resume=on_error_resume + ) + + def write_to_file(self, pattern, **_): + if not pattern: + self.output_file = None + message = "File output disabled" + return [(None, None, None, message, "", True, True)] + filename = os.path.abspath(os.path.expanduser(pattern)) + if not os.path.isfile(filename): + try: + open(filename, "w").close() + except IOError as e: + self.output_file = None + message = str(e) + "\nFile output disabled" + return [(None, None, None, message, "", False, True)] + self.output_file = filename + message = 'Writing to file "%s"' % self.output_file + return [(None, None, None, message, "", True, True)] + + def initialize_logging(self): + + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" + ensure_dir_exists(log_file) + log_level = self.config["main"]["log_level"] + + # Disable logging if value is NONE by switching to a no-op handler. + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + else: + handler = logging.FileHandler(os.path.expanduser(log_file)) + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + "NONE": logging.CRITICAL, + } + + log_level = level_map[log_level.upper()] + + formatter = logging.Formatter( + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger("pgcli") + root_logger.addHandler(handler) + root_logger.setLevel(log_level) + + root_logger.debug("Initializing pgcli logging.") + root_logger.debug("Log file %r.", log_file) + + pgspecial_logger = logging.getLogger("pgspecial") + pgspecial_logger.addHandler(handler) + pgspecial_logger.setLevel(log_level) + + def initialize_keyring(self): + global keyring + + keyring_enabled = self.config["main"].as_bool("keyring") + if keyring_enabled: + # Try best to load keyring (issue #1041). + import importlib + + try: + keyring = importlib.import_module("keyring") + except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + self.logger.warning("import keyring failed: %r.", e) + + def connect_dsn(self, dsn, **kwargs): + self.connect(dsn=dsn, **kwargs) + + def connect_service(self, service, user): + service_config, file = parse_service_info(service) + if service_config is None: + click.secho( + "service '%s' was not found in %s" % (service, file), err=True, fg="red" + ) + exit(1) + self.connect( + database=service_config.get("dbname"), + host=service_config.get("host"), + user=user or service_config.get("user"), + port=service_config.get("port"), + passwd=service_config.get("password"), + ) + + def connect_uri(self, uri): + kwargs = psycopg2.extensions.parse_dsn(uri) + remap = {"dbname": "database", "password": "passwd"} + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) + + def connect( + self, database="", host="", user="", port="", passwd="", dsn="", **kwargs + ): + # Connect to the database. + + if not user: + user = getuser() + + if not database: + database = user + + kwargs.setdefault("application_name", "pgcli") + + # If password prompt is not forced but no password is provided, try + # getting it from environment variable. + if not self.force_passwd_prompt and not passwd: + passwd = os.environ.get("PGPASSWORD", "") + + # Find password from store + key = "%s@%s" % (user, host) + keyring_error_message = dedent( + """\ + {} + {} + To remove this message do one of the following: + - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ + - uninstall keyring: pip uninstall keyring + - disable keyring in our configuration: add keyring = False to [main]""" + ) + if not passwd and keyring: + + try: + passwd = keyring.get_password("pgcli", key) + except (RuntimeError, keyring.errors.InitError) as e: + click.secho( + keyring_error_message.format( + "Load your password from keyring returned:", str(e) + ), + err=True, + fg="red", + ) + + # Prompt for a password immediately if requested via the -W flag. This + # avoids wasting time trying to connect to the database and catching a + # no-password exception. + # If we successfully parsed a password from a URI, there's no need to + # prompt for it, even with the -W flag + if self.force_passwd_prompt and not passwd: + passwd = click.prompt( + "Password for %s" % user, hide_input=True, show_default=False, type=str + ) + + def should_ask_for_password(exc): + # Prompt for a password after 1st attempt to connect + # fails. Don't prompt if the -w flag is supplied + if self.never_passwd_prompt: + return False + error_msg = exc.args[0] + if "no password supplied" in error_msg: + return True + if "password authentication failed" in error_msg: + return True + return False + + # Attempt to connect to the database. + # Note that passwd may be empty on the first attempt. If connection + # fails because of a missing or incorrect password, but we're allowed to + # prompt for a password (no -w flag), prompt for a passwd and try again. + try: + try: + pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + except (OperationalError, InterfaceError) as e: + if should_ask_for_password(e): + passwd = click.prompt( + "Password for %s" % user, + hide_input=True, + show_default=False, + type=str, + ) + pgexecute = PGExecute( + database, user, passwd, host, port, dsn, **kwargs + ) + else: + raise e + if passwd and keyring: + try: + keyring.set_password("pgcli", key, passwd) + except (RuntimeError, keyring.errors.KeyringError) as e: + click.secho( + keyring_error_message.format( + "Set password in keyring returned:", str(e) + ), + err=True, + fg="red", + ) + + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + + self.pgexecute = pgexecute + + def handle_editor_command(self, text): + r""" + Editor command is any query that is prefixed or suffixed + by a '\e'. The reason for a while loop is because a user + might edit a query multiple times. + For eg: + "select * from \e"<enter> to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + """ + editor_command = special.editor_command(text) + while editor_command: + if editor_command == "\\e": + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + else: # \ev or \ef + filename = None + spec = text.split()[1] + if editor_command == "\\ev": + query = self.pgexecute.view_definition(spec) + elif editor_command == "\\ef": + query = self.pgexecute.function_definition(spec) + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + editor_command = special.editor_command(text) + return text + + def execute_command(self, text): + logger = self.logger + + query = MetaQuery(query=text, successful=False) + + try: + if self.destructive_warning: + destroy = confirm = confirm_destructive_query(text) + if destroy is False: + click.secho("Wise choice!") + raise KeyboardInterrupt + elif destroy: + click.secho("Your call!") + output, query = self._evaluate_command(text) + except KeyboardInterrupt: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") + except NotImplementedError: + click.secho("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError) as e: + raise + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + else: + try: + if self.output_file and not text.startswith(("\\o ", "\\? ")): + try: + with open(self.output_file, "a", encoding="utf-8") as f: + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except IOError as e: + click.secho(str(e), err=True, fg="red") + else: + if output: + self.echo_via_pager("\n".join(output)) + except KeyboardInterrupt: + pass + + if self.pgspecial.timing_enabled: + # Only add humanized time display if > 1 second + if query.total_time > 1: + print( + "Time: %0.03fs (%s), executed in: %0.03fs (%s)" + % ( + query.total_time, + pendulum.Duration(seconds=query.total_time).in_words(), + query.execution_time, + pendulum.Duration(seconds=query.execution_time).in_words(), + ) + ) + else: + print("Time: %0.03fs" % query.total_time) + + # Check if we need to update completions, in order of most + # to least drastic changes + if query.db_changed: + with self._completer_lock: + self.completer.reset_completions() + self.refresh_completions(persist_priorities="keywords") + elif query.meta_changed: + self.refresh_completions(persist_priorities="all") + elif query.path_changed: + logger.debug("Refreshing search path") + with self._completer_lock: + self.completer.set_search_path(self.pgexecute.search_path()) + logger.debug("Search path: %r", self.completer.search_path) + return query + + def run_cli(self): + logger = self.logger + + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" + history = FileHistory(os.path.expanduser(history_file)) + self.refresh_completions(history=history, persist_priorities="none") + + self.prompt_app = self._build_cli(history) + + if not self.less_chatty: + print("Server: PostgreSQL", self.pgexecute.server_version) + print("Version:", __version__) + print("Chat: https://gitter.im/dbcli/pgcli") + print("Home: http://pgcli.com") + + try: + while True: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + continue + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + continue + + # Initialize default metaquery in case execution fails + self.watch_command, timing = special.get_watch_command(text) + if self.watch_command: + while self.watch_command: + try: + query = self.execute_command(self.watch_command) + click.echo( + "Waiting for {0} seconds before repeating".format( + timing + ) + ) + sleep(timing) + except KeyboardInterrupt: + self.watch_command = None + else: + query = self.execute_command(text) + + self.now = dt.datetime.today() + + # Allow PGCompleter to learn user's preferred keywords, etc. + with self._completer_lock: + self.completer.extend_query_history(text) + + self.query_history.append(query) + + except (PgCliQuitError, EOFError): + if not self.less_chatty: + print("Goodbye!") + + def _build_cli(self, history): + key_bindings = pgcli_bindings(self) + + def get_message(): + if self.dsn_alias and self.prompt_dsn_format is not None: + prompt_format = self.prompt_dsn_format + else: + prompt_format = self.prompt_format + + prompt = self.get_prompt(prompt_format) + + if ( + prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, line_number, is_soft_wrap): + continuation = self.multiline_continuation_char * (width - 1) + " " + return [("class:continuation", continuation)] + + get_toolbar_tokens = create_toolbar_tokens_func(self) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + prompt_app = PromptSession( + lexer=PygmentsLexer(PostgresLexer), + reserve_space_for_menu=self.min_num_menu_lines, + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, + complete_style=complete_style, + input_processors=[ + # Highlight matching brackets while editing. + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ), + # Render \t as 4 spaces instead of "^I" + TabsProcessor(char1=" ", char2=" "), + ], + auto_suggest=AutoSuggestFromHistory(), + tempfile_suffix=".sql", + # N.b. pgcli's multi-line mode controls submit-on-Enter (which + # overrides the default behaviour of prompt_toolkit) and is + # distinct from prompt_toolkit's multiline mode here, which + # controls layout/display of the prompt/buffer + multiline=True, + history=history, + completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)), + complete_while_typing=True, + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, + search_ignore_case=True, + ) + + return prompt_app + + def _should_limit_output(self, sql, cur): + """returns True if the output should be truncated, False otherwise.""" + if not is_select(sql): + return False + + return ( + not self._has_limit(sql) + and self.row_limit != 0 + and cur + and cur.rowcount > self.row_limit + ) + + def _has_limit(self, sql): + if not sql: + return False + return "limit " in sql.lower() + + def _limit_output(self, cur): + limit = min(self.row_limit, cur.rowcount) + new_cur = itertools.islice(cur, limit) + new_status = "SELECT " + str(limit) + click.secho("The result was limited to %s rows" % limit, fg="red") + + return new_cur, new_status + + def _evaluate_command(self, text): + """Used to run a command entered by the user during CLI operation + (Puts the E in REPL) + + returns (results, MetaQuery) + """ + logger = self.logger + logger.debug("sql: %r", text) + + all_success = True + meta_changed = False # CREATE, ALTER, DROP, etc + mutated = False # INSERT, DELETE, etc + db_changed = False + path_changed = False + output = [] + total = 0 + execution = 0 + + # Run the query. + start = time() + on_error_resume = self.on_error == "RESUME" + res = self.pgexecute.run( + text, self.pgspecial, exception_formatter, on_error_resume + ) + + is_special = None + + for title, cur, headers, status, sql, success, is_special in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + + if self._should_limit_output(sql, cur): + cur, status = self._limit_output(cur) + + if self.pgspecial.auto_expand or self.auto_expand: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + expanded = self.pgspecial.expanded_output or self.expanded_output + settings = OutputSettings( + table_format=self.table_format, + dcmlfmt=self.decimal_format, + floatfmt=self.float_format, + missingval=self.null_string, + expanded=expanded, + max_width=max_width, + case_function=( + self.completer.case + if self.settings["case_column_headers"] + else lambda x: x + ), + style_output=self.style_output, + ) + execution = time() - start + formatted = format_output(title, cur, headers, status, settings) + + output.extend(formatted) + total = time() - start + + # Keep track of whether any of the queries are mutating or changing + # the database + if success: + mutated = mutated or is_mutating(status) + db_changed = db_changed or has_change_db_cmd(sql) + meta_changed = meta_changed or has_meta_cmd(sql) + path_changed = path_changed or has_change_path_cmd(sql) + else: + all_success = False + + meta_query = MetaQuery( + text, + all_success, + total, + execution, + meta_changed, + db_changed, + path_changed, + mutated, + is_special, + ) + + return output, meta_query + + def _handle_server_closed_connection(self, text): + """Used during CLI execution.""" + try: + click.secho("Reconnecting...", fg="green") + self.pgexecute.connect() + click.secho("Reconnected!", fg="green") + self.execute_command(text) + except OperationalError as e: + click.secho("Reconnect Failed", fg="red") + click.secho(str(e), err=True, fg="red") + + def refresh_completions(self, history=None, persist_priorities="all"): + """Refresh outdated completions + + :param history: A prompt_toolkit.history.FileHistory object. Used to + load keyword and identifier preferences + + :param persist_priorities: 'all' or 'keywords' + """ + + callback = functools.partial( + self._on_completions_refreshed, persist_priorities=persist_priorities + ) + self.completion_refresher.refresh( + self.pgexecute, + self.pgspecial, + callback, + history=history, + settings=self.settings, + ) + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def _on_completions_refreshed(self, new_completer, persist_priorities): + self._swap_completer_objects(new_completer, persist_priorities) + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def _swap_completer_objects(self, new_completer, persist_priorities): + """Swap the completer object with the newly created completer. + + persist_priorities is a string specifying how the old completer's + learned prioritizer should be transferred to the new completer. + + 'none' - The new prioritizer is left in a new/clean state + + 'all' - The new prioritizer is updated to exactly reflect + the old one + + 'keywords' - The new prioritizer is updated with old keyword + priorities, but not any other. + + """ + with self._completer_lock: + old_completer = self.completer + self.completer = new_completer + + if persist_priorities == "all": + # Just swap over the entire prioritizer + new_completer.prioritizer = old_completer.prioritizer + elif persist_priorities == "keywords": + # Swap over the entire prioritizer, but clear name priorities, + # leaving learned keyword priorities alone + new_completer.prioritizer = old_completer.prioritizer + new_completer.prioritizer.clear_names() + elif persist_priorities == "none": + # Leave the new prioritizer as is + pass + self.completer = new_completer + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + # should be before replacing \\d + string = string.replace("\\dsn_alias", self.dsn_alias or "") + string = string.replace("\\t", self.now.strftime("%x %X")) + string = string.replace("\\u", self.pgexecute.user or "(none)") + string = string.replace("\\H", self.pgexecute.host or "(none)") + string = string.replace("\\h", self.pgexecute.short_host or "(none)") + string = string.replace("\\d", self.pgexecute.dbname or "(none)") + string = string.replace( + "\\p", + str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", + ) + string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") + string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\n", "\n") + return string + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + def is_too_wide(self, line): + """Will this line be too wide to fit into terminal?""" + if not self.prompt_app: + return False + return ( + len(COLOR_CODE_REGEX.sub("", line)) + > self.prompt_app.output.get_size().columns + ) + + def is_too_tall(self, lines): + """Are there too many lines to fit into terminal?""" + if not self.prompt_app: + return False + return len(lines) >= (self.prompt_app.output.get_size().rows - 4) + + def echo_via_pager(self, text, color=None): + if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: + click.echo(text, color=color) + elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv": + click.echo_via_pager(text, color) + elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: + lines = text.split("\n") + + # The last 4 lines are reserved for the pgcli menu and padding + if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines): + click.echo_via_pager(text, color=color) + else: + click.echo(text, color=color) + else: + click.echo_via_pager(text, color) + + +@click.command() +# Default host is '' so psycopg2 can default to either localhost or unix socket +@click.option( + "-h", + "--host", + default="", + envvar="PGHOST", + help="Host address of the postgres database.", +) +@click.option( + "-p", + "--port", + default=5432, + help="Port number at which the " "postgres instance is listening.", + envvar="PGPORT", + type=click.INT, +) +@click.option( + "-U", + "--username", + "username_opt", + help="Username to connect to the postgres database.", +) +@click.option( + "-u", "--user", "username_opt", help="Username to connect to the postgres database." +) +@click.option( + "-W", + "--password", + "prompt_passwd", + is_flag=True, + default=False, + help="Force password prompt.", +) +@click.option( + "-w", + "--no-password", + "never_prompt", + is_flag=True, + default=False, + help="Never prompt for password.", +) +@click.option( + "--single-connection", + "single_connection", + is_flag=True, + default=False, + help="Do not use a separate connection for completions.", +) +@click.option("-v", "--version", is_flag=True, help="Version of pgcli.") +@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.") +@click.option( + "--pgclirc", + default=config_location() + "config", + envvar="PGCLIRC", + help="Location of pgclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "-D", + "--dsn", + default="", + envvar="DSN", + help="Use DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--list-dsn", + "list_dsn", + is_flag=True, + help="list of DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--row-limit", + default=None, + envvar="PGROWLIMIT", + type=click.INT, + help="Set threshold for row limit prompt. Use 0 to disable prompt.", +) +@click.option( + "--less-chatty", + "less_chatty", + is_flag=True, + default=False, + help="Skip intro on startup and goodbye on exit.", +) +@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') +@click.option( + "--prompt-dsn", + help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").', +) +@click.option( + "-l", + "--list", + "list_databases", + is_flag=True, + help="list " "available databases, then exit.", +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "--warn/--no-warn", default=None, help="Warn before running a destructive query." +) +@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) +@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) +def cli( + dbname, + username_opt, + host, + port, + prompt_passwd, + never_prompt, + single_connection, + dbname_opt, + username, + version, + pgclirc, + dsn, + row_limit, + less_chatty, + prompt, + prompt_dsn, + list_databases, + auto_vertical_output, + list_dsn, + warn, +): + if version: + print("Version:", __version__) + sys.exit(0) + + config_dir = os.path.dirname(config_location()) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + + # Migrate the config file from old location. + config_full_path = config_location() + "config" + if os.path.exists(os.path.expanduser("~/.pgclirc")): + if not os.path.exists(config_full_path): + shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path) + print("Config file (~/.pgclirc) moved to new location", config_full_path) + else: + print("Config file is now located at", config_full_path) + print( + "Please move the existing config file ~/.pgclirc to", + config_full_path, + ) + if list_dsn: + try: + cfg = load_config(pgclirc, config_full_path) + for alias in cfg["alias_dsn"]: + click.secho(alias + " : " + cfg["alias_dsn"][alias]) + sys.exit(0) + except Exception as err: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + + pgcli = PGCli( + prompt_passwd, + never_prompt, + pgclirc_file=pgclirc, + row_limit=row_limit, + single_connection=single_connection, + less_chatty=less_chatty, + prompt=prompt, + prompt_dsn=prompt_dsn, + auto_vertical_output=auto_vertical_output, + warn=warn, + ) + + # Choose which ever one has a valid value. + if dbname_opt and dbname: + # work as psql: when database is given as option and argument use the argument as user + username = dbname + database = dbname_opt or dbname or "" + user = username_opt or username + service = None + if database.startswith("service="): + service = database[8:] + elif os.getenv("PGSERVICE") is not None: + service = os.getenv("PGSERVICE") + # because option --list or -l are not supposed to have a db name + if list_databases: + database = "postgres" + + if dsn != "": + try: + cfg = load_config(pgclirc, config_full_path) + dsn_config = cfg["alias_dsn"][dsn] + except KeyError: + click.secho( + f"Could not find a DSN with alias {dsn}. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + except Exception: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + pgcli.connect_uri(dsn_config) + pgcli.dsn_alias = dsn + elif "://" in database: + pgcli.connect_uri(database) + elif "=" in database and service is None: + pgcli.connect_dsn(database, user=user) + elif service is not None: + pgcli.connect_service(service, user) + else: + pgcli.connect(database, host, user, port) + + if list_databases: + cur, headers, status = pgcli.pgexecute.full_databases() + + title = "List of databases" + settings = OutputSettings(table_format="ascii", missingval="<null>") + formatted = format_output(title, cur, headers, status, settings) + pgcli.echo_via_pager("\n".join(formatted)) + + sys.exit(0) + + pgcli.logger.debug( + "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + database, + user, + host, + port, + ) + + if setproctitle: + obfuscate_process_password() + + pgcli.run_cli() + + +def obfuscate_process_password(): + process_title = setproctitle.getproctitle() + if "://" in process_title: + process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) + elif "=" in process_title: + process_title = re.sub( + r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title + ) + + setproctitle.setproctitle(process_title) + + +def has_meta_cmd(query): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop, commit or rollback.""" + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"): + return True + except Exception: + return False + + return False + + +def has_change_db_cmd(query): + """Determines if the statement is a database switch such as 'use' or '\\c'""" + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\c", "\\connect"): + return True + except Exception: + return False + + return False + + +def has_change_path_cmd(sql): + """Determines if the search_path should be refreshed by checking if the + sql has 'set search_path'.""" + return "set search_path" in sql.lower() + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = set(["insert", "update", "delete"]) + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +def exception_formatter(e): + return click.style(str(e), fg="red") + + +def format_output(title, cur, headers, status, settings): + output = [] + expanded = settings.expanded or settings.table_format == "vertical" + table_format = "vertical" if settings.expanded else settings.table_format + max_width = settings.max_width + case_function = settings.case_function + formatter = TabularOutputFormatter(format_name=table_format) + + def format_array(val): + if val is None: + return settings.missingval + if not isinstance(val, list): + return val + return "{" + ",".join(str(format_array(e)) for e in val) + "}" + + def format_arrays(data, headers, **_): + data = list(data) + for row in data: + row[:] = [ + format_array(val) if isinstance(val, list) else val for val in row + ] + + return data, headers + + output_kwargs = { + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": settings.missingval, + "integer_format": settings.dcmlfmt, + "float_format": settings.floatfmt, + "preprocessors": (format_numbers, format_arrays), + "disable_numparse": True, + "preserve_whitespace": True, + "style": settings.style_output, + } + if not settings.floatfmt: + output_kwargs["preprocessors"] = (align_decimals,) + + if table_format == "csv": + # The default CSV dialect is "excel" which is not handling newline values correctly + # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n' + # as the line terminator + # https://github.com/dbcli/pgcli/issues/1102 + dialect = "excel" if platform.system() == "Windows" else "unix" + output_kwargs["dialect"] = dialect + + if title: # Only print the title if it's not None. + output.append(title) + + if cur: + headers = [case_function(x) for x in headers] + if max_width is not None: + cur = list(cur) + column_types = None + if hasattr(cur, "description"): + column_types = [] + for d in cur.description: + if ( + d[1] in psycopg2.extensions.DECIMAL.values + or d[1] in psycopg2.extensions.FLOAT.values + ): + column_types.append(float) + if ( + d[1] == psycopg2.extensions.INTEGER.values + or d[1] in psycopg2.extensions.LONGINTEGER.values + ): + column_types.append(int) + else: + column_types.append(str) + + formatted = formatter.format_output(cur, headers, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + first_line = next(formatted) + formatted = itertools.chain([first_line], formatted) + if not expanded and max_width and len(first_line) > max_width and headers: + formatted = formatter.format_output( + cur, headers, format_name="vertical", column_types=None, **output_kwargs + ) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + + output = itertools.chain(output, formatted) + + # Only print the status if it's not None and we are not producing CSV + if status and table_format != "csv": + output = itertools.chain(output, [status]) + + return output + + +def parse_service_info(service): + service = service or os.getenv("PGSERVICE") + service_file = os.getenv("PGSERVICEFILE") + if not service_file: + # try ~/.pg_service.conf (if that exists) + if platform.system() == "Windows": + service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf" + elif os.getenv("PGSYSCONFDIR"): + service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") + else: + service_file = expanduser("~/.pg_service.conf") + if not service: + # nothing to do + return None, service_file + service_file_config = ConfigObj(service_file) + if service not in service_file_config: + return None, service_file + service_conf = service_file_config.get(service) + return service_conf, service_file + + +if __name__ == "__main__": + cli() |