From 6884720fae8a2622b14e93d9e35ca5fcc2283b40 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Mon, 8 Feb 2021 11:31:05 +0100
Subject: Adding upstream version 3.1.0.

Signed-off-by: Daniel Baumann
---
 pgcli/__init__.py                         |    1 +
 pgcli/__main__.py                         |    9 +
 pgcli/completion_refresher.py             |  150 +++
 pgcli/config.py                           |   64 ++
 pgcli/key_bindings.py                     |  127 +++
 pgcli/magic.py                            |   67 ++
 pgcli/main.py                             | 1516 +++++++++++++++++++++++++++++
 pgcli/packages/__init__.py                |    0
 pgcli/packages/parseutils/__init__.py     |   22 +
 pgcli/packages/parseutils/ctes.py         |  141 +++
 pgcli/packages/parseutils/meta.py         |  170 ++++
 pgcli/packages/parseutils/tables.py       |  170 ++++
 pgcli/packages/parseutils/utils.py        |  140 +++
 pgcli/packages/pgliterals/__init__.py     |    0
 pgcli/packages/pgliterals/main.py         |   15 +
 pgcli/packages/pgliterals/pgliterals.json |  629 ++++++++++++
 pgcli/packages/prioritization.py          |   51 +
 pgcli/packages/prompt_utils.py            |   35 +
 pgcli/packages/sqlcompletion.py           |  608 ++++++++++++
 pgcli/pgbuffer.py                         |   50 +
 pgcli/pgclirc                             |  195 ++++
 pgcli/pgcompleter.py                      | 1046 ++++++++++++++++++++
 pgcli/pgexecute.py                        |  857 ++++++++++++++++
 pgcli/pgstyle.py                          |  116 +++
 pgcli/pgtoolbar.py                        |   62 ++
 25 files changed, 6241 insertions(+)
 create mode 100644 pgcli/__init__.py
 create mode 100644 pgcli/__main__.py
 create mode 100644 pgcli/completion_refresher.py
 create mode 100644 pgcli/config.py
 create mode 100644 pgcli/key_bindings.py
 create mode 100644 pgcli/magic.py
 create mode 100644 pgcli/main.py
 create mode 100644 pgcli/packages/__init__.py
 create mode 100644 pgcli/packages/parseutils/__init__.py
 create mode 100644 pgcli/packages/parseutils/ctes.py
 create mode 100644 pgcli/packages/parseutils/meta.py
 create mode 100644 pgcli/packages/parseutils/tables.py
 create mode 100644 pgcli/packages/parseutils/utils.py
 create mode 100644 pgcli/packages/pgliterals/__init__.py
 create mode 100644 pgcli/packages/pgliterals/main.py
 create mode 100644 pgcli/packages/pgliterals/pgliterals.json
 create mode 100644 pgcli/packages/prioritization.py
 create mode 100644 pgcli/packages/prompt_utils.py
 create mode 100644 pgcli/packages/sqlcompletion.py
 create mode 100644 pgcli/pgbuffer.py
 create mode 100644 pgcli/pgclirc
 create mode 100644 pgcli/pgcompleter.py
 create mode 100644 pgcli/pgexecute.py
 create mode 100644 pgcli/pgstyle.py
 create mode 100644 pgcli/pgtoolbar.py

diff --git a/pgcli/__init__.py b/pgcli/__init__.py
new file mode 100644
index 0000000..f5f41e5
--- /dev/null
+++ b/pgcli/__init__.py
@@ -0,0 +1 @@
+__version__ = "3.1.0"
diff --git a/pgcli/__main__.py b/pgcli/__main__.py
new file mode 100644
index 0000000..ddf1662
--- /dev/null
+++ b/pgcli/__main__.py
@@ -0,0 +1,9 @@
+"""
+pgcli package main entry point
+"""
+
+from .main import cli
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
new file mode 100644
index 0000000..cf0879f
--- /dev/null
+++ b/pgcli/completion_refresher.py
@@ -0,0 +1,150 @@
+import threading
+import os
+from collections import OrderedDict
+
+from .pgcompleter import PGCompleter
+from .pgexecute import PGExecute
+
+
+class CompletionRefresher(object):
+
+    refreshers = OrderedDict()
+
+    def __init__(self):
+        self._completer_thread = None
+        self._restart_refresh = threading.Event()
+
+    def refresh(self, executor, special, callbacks, history=None, settings=None):
+        """
+        Creates a PGCompleter object and populates it with the relevant
+        completion suggestions in a background thread.
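+
+        A minimal usage sketch (the no-op callback here is illustrative;
+        each callback receives the newly built completer):
+
+            refresher = CompletionRefresher()
+            refresher.refresh(executor, special,
+                              callbacks=[lambda completer: None])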
+
+        executor - PGExecute object, used to extract the credentials to connect
+                   to the database.
+        special - PGSpecial object used for creating a new completion object.
+        settings - dict of settings for completer object
+        callbacks - A function or a list of functions to call after the thread
+                    has completed the refresh. The newly created completion
+                    object will be passed in as an argument to each callback.
+        """
+        if self.is_refreshing():
+            self._restart_refresh.set()
+            return [(None, None, None, "Auto-completion refresh restarted.")]
+        else:
+            self._completer_thread = threading.Thread(
+                target=self._bg_refresh,
+                args=(executor, special, callbacks, history, settings),
+                name="completion_refresh",
+            )
+            self._completer_thread.setDaemon(True)
+            self._completer_thread.start()
+            return [
+                (None, None, None, "Auto-completion refresh started in the background.")
+            ]
+
+    def is_refreshing(self):
+        return self._completer_thread and self._completer_thread.is_alive()
+
+    def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
+        settings = settings or {}
+        completer = PGCompleter(
+            smart_completion=True, pgspecial=special, settings=settings
+        )
+
+        if settings.get("single_connection"):
+            executor = pgexecute
+        else:
+            # Create a new pgexecute object to populate the completions.
+            executor = pgexecute.copy()
+        # If callbacks is a single function then push it into a list.
+        if callable(callbacks):
+            callbacks = [callbacks]
+
+        while True:
+            for refresher in self.refreshers.values():
+                refresher(completer, executor)
+                if self._restart_refresh.is_set():
+                    self._restart_refresh.clear()
+                    break
+            else:
+                # Break out of the while loop if the for loop finishes naturally
+                # without hitting the break statement.
+                break
+
+            # Start over the refresh from the beginning if the for loop hit the
+            # break statement.
+            continue
+
+        # Load history into pgcompleter so it can learn user preferences
+        n_recent = 100
+        if history:
+            for recent in history.get_strings()[-n_recent:]:
+                completer.extend_query_history(recent, is_init=True)
+
+        for callback in callbacks:
+            callback(completer)
+
+        if not settings.get("single_connection") and executor.conn:
+            # close connection established with pgexecute.copy()
+            executor.conn.close()
+
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+    """Decorator to populate the dictionary of refreshers with the current
+    function.
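+
+    Usage matches the refreshers defined below in this module, e.g.:
+
+        @refresher("schemata")
+        def refresh_schemata(completer, executor):
+            ...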
+    """
+
+    def wrapper(wrapped):
+        refreshers[name] = wrapped
+        return wrapped
+
+    return wrapper
+
+
+@refresher("schemata")
+def refresh_schemata(completer, executor):
+    completer.set_search_path(executor.search_path())
+    completer.extend_schemata(executor.schemata())
+
+
+@refresher("tables")
+def refresh_tables(completer, executor):
+    completer.extend_relations(executor.tables(), kind="tables")
+    completer.extend_columns(executor.table_columns(), kind="tables")
+    completer.extend_foreignkeys(executor.foreignkeys())
+
+
+@refresher("views")
+def refresh_views(completer, executor):
+    completer.extend_relations(executor.views(), kind="views")
+    completer.extend_columns(executor.view_columns(), kind="views")
+
+
+@refresher("types")
+def refresh_types(completer, executor):
+    completer.extend_datatypes(executor.datatypes())
+
+
+@refresher("databases")
+def refresh_databases(completer, executor):
+    completer.extend_database_names(executor.databases())
+
+
+@refresher("casing")
+def refresh_casing(completer, executor):
+    casing_file = completer.casing_file
+    if not casing_file:
+        return
+    generate_casing_file = completer.generate_casing_file
+    if generate_casing_file and not os.path.isfile(casing_file):
+        casing_prefs = "\n".join(executor.casing())
+        with open(casing_file, "w") as f:
+            f.write(casing_prefs)
+    if os.path.isfile(casing_file):
+        with open(casing_file, "r") as f:
+            completer.extend_casing([line.strip() for line in f])
+
+
+@refresher("functions")
+def refresh_functions(completer, executor):
+    completer.extend_functions(executor.functions())
diff --git a/pgcli/config.py b/pgcli/config.py
new file mode 100644
index 0000000..0fc42dd
--- /dev/null
+++ b/pgcli/config.py
@@ -0,0 +1,64 @@
+import errno
+import shutil
+import os
+import platform
+from os.path import expanduser, exists, dirname
+from configobj import ConfigObj
+
+
+def config_location():
+    if "XDG_CONFIG_HOME" in os.environ:
+        return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
+    elif platform.system() == "Windows":
+        return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\"
+    else:
+        return expanduser("~/.config/pgcli/")
+
+
+def load_config(usr_cfg, def_cfg=None):
+    cfg = ConfigObj()
+    cfg.merge(ConfigObj(def_cfg, interpolation=False))
+    cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+    cfg.filename = expanduser(usr_cfg)
+
+    return cfg
+
+
+def ensure_dir_exists(path):
+    parent_dir = expanduser(dirname(path))
+    os.makedirs(parent_dir, exist_ok=True)
+
+
+def write_default_config(source, destination, overwrite=False):
+    destination = expanduser(destination)
+    if not overwrite and exists(destination):
+        return
+
+    ensure_dir_exists(destination)
+
+    shutil.copyfile(source, destination)
+
+
+def upgrade_config(config, def_config):
+    cfg = load_config(config, def_config)
+    cfg.write()
+
+
+def get_config(pgclirc_file=None):
+    from pgcli import __file__ as package_root
+
+    package_root = os.path.dirname(package_root)
+
+    pgclirc_file = pgclirc_file or "%sconfig" % config_location()
+
+    default_config = os.path.join(package_root, "pgclirc")
+    write_default_config(default_config, pgclirc_file)
+
+    return load_config(pgclirc_file, default_config)
+
+
+def get_casing_file(config):
+    casing_file = config["main"]["casing_file"]
+    if casing_file == "default":
+        casing_file = config_location() + "casing"
+    return casing_file
diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py
new file mode 100644
index 0000000..23174b6
--- /dev/null
+++ b/pgcli/key_bindings.py
@@ -0,0 +1,127 @@
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.key_binding import KeyBindings
+from prompt_toolkit.filters import (
+    completion_is_selected,
+    is_searching,
+    has_completions,
+    has_selection,
+    vi_mode,
+)
+
+from .pgbuffer import buffer_should_be_handled
+
+_logger = logging.getLogger(__name__)
+
+
+def pgcli_bindings(pgcli):
+    """Custom key bindings for pgcli."""
+    kb = KeyBindings()
+
+    tab_insert_text = " " * 4
+
+    @kb.add("f2")
+    def _(event):
+        """Enable/Disable SmartCompletion Mode."""
+        _logger.debug("Detected F2 key.")
+        pgcli.completer.smart_completion = not pgcli.completer.smart_completion
+
+    @kb.add("f3")
+    def _(event):
+        """Enable/Disable Multiline Mode."""
+        _logger.debug("Detected F3 key.")
+        pgcli.multi_line = not pgcli.multi_line
+
+    @kb.add("f4")
+    def _(event):
+        """Toggle between Vi and Emacs mode."""
+        _logger.debug("Detected F4 key.")
+        pgcli.vi_mode = not pgcli.vi_mode
+        event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
+
+    @kb.add("tab")
+    def _(event):
+        """Force autocompletion at cursor on non-empty lines."""
+
+        _logger.debug("Detected <Tab> key.")
+
+        buff = event.app.current_buffer
+        doc = buff.document
+
+        if doc.on_first_line or doc.current_line.strip():
+            if buff.complete_state:
+                buff.complete_next()
+            else:
+                buff.start_completion(select_first=True)
+        else:
+            buff.insert_text(tab_insert_text, fire_event=False)
+
+    @kb.add("escape", filter=has_completions)
+    def _(event):
+        """Force closing of autocompletion."""
+        _logger.debug("Detected <Esc> key.")
+
+        event.current_buffer.complete_state = None
+        event.app.current_buffer.complete_state = None
+
+    @kb.add("c-space")
+    def _(event):
+        """
+        Initialize autocompletion at cursor.
+
+        If the autocompletion menu is not showing, display it with the
+        appropriate completions for the context.
+
+        If the menu is showing, select the next completion.
+        """
+        _logger.debug("Detected <C-Space> key.")
+
+        b = event.app.current_buffer
+        if b.complete_state:
+            b.complete_next()
+        else:
+            b.start_completion(select_first=False)
+
+    @kb.add("enter", filter=completion_is_selected)
+    def _(event):
+        """Makes the enter key work as the tab key only when showing the menu.
+
+        In other words, don't execute query when enter is pressed in
+        the completion dropdown menu, instead close the dropdown menu
+        (accept current selection).
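+        (To insert a newline instead, use Alt+Enter, bound further below.)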
+
+        """
+        _logger.debug("Detected enter key during completion selection.")
+
+        event.current_buffer.complete_state = None
+        event.app.current_buffer.complete_state = None
+
+    # When using multi_line input mode the buffer is not handled on Enter (a new line is
+    # inserted instead), so we force the handling if we're not in a completion or
+    # history search, and one of several conditions are True
+    @kb.add(
+        "enter",
+        filter=~(completion_is_selected | is_searching)
+        & buffer_should_be_handled(pgcli),
+    )
+    def _(event):
+        _logger.debug("Detected enter key.")
+        event.current_buffer.validate_and_handle()
+
+    @kb.add("escape", "enter", filter=~vi_mode)
+    def _(event):
+        """Introduces a line break regardless of multi-line mode or not."""
+        _logger.debug("Detected alt-enter key.")
+        event.app.current_buffer.insert_text("\n")
+
+    @kb.add("c-p", filter=~has_selection)
+    def _(event):
+        """Move up in history."""
+        event.current_buffer.history_backward(count=event.arg)
+
+    @kb.add("c-n", filter=~has_selection)
+    def _(event):
+        """Move down in history."""
+        event.current_buffer.history_forward(count=event.arg)
+
+    return kb
diff --git a/pgcli/magic.py b/pgcli/magic.py
new file mode 100644
index 0000000..f58f415
--- /dev/null
+++ b/pgcli/magic.py
@@ -0,0 +1,67 @@
+from .main import PGCli
+import sql.parse
+import sql.connection
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+def load_ipython_extension(ipython):
+    """This is called via the ipython command '%load_ext pgcli.magic'"""
+
+    # first, load the sql magic if it isn't already loaded
+    if not ipython.find_line_magic("sql"):
+        ipython.run_line_magic("load_ext", "sql")
+
+    # register our own magic
+    ipython.register_magic_function(pgcli_line_magic, "line", "pgcli")
+
+
+def pgcli_line_magic(line):
+    _logger.debug("pgcli magic called: %r", line)
+    parsed = sql.parse.parse(line, {})
+    # "get" was renamed to "set" in ipython-sql:
+    # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
+    if hasattr(sql.connection.Connection, "get"):
+        conn = sql.connection.Connection.get(parsed["connection"])
+    else:
+        conn = sql.connection.Connection.set(parsed["connection"])
+
+    try:
+        # A corresponding pgcli object already exists
+        pgcli = conn._pgcli
+        _logger.debug("Reusing existing pgcli")
+    except AttributeError:
+        # I can't figure out how to get the underlying psycopg2 connection
+        # from the sqlalchemy connection, so just grab the url and make a
+        # new connection
+        pgcli = PGCli()
+        u = conn.session.engine.url
+        _logger.debug("New pgcli: %r", str(u))
+
+        pgcli.connect(u.database, u.host, u.username, u.port, u.password)
+        conn._pgcli = pgcli
+
+    # For convenience, print the connection alias
+    print("Connected: {}".format(conn.name))
+
+    try:
+        pgcli.run_cli()
+    except SystemExit:
+        pass
+
+    if not pgcli.query_history:
+        return
+
+    q = pgcli.query_history[-1]
+
+    if not q.successful:
+        _logger.debug("Unsuccessful query - ignoring")
+        return
+
+    if q.meta_changed or q.db_changed or q.path_changed:
+        _logger.debug("Dangerous query detected -- ignoring")
+        return
+
+    ipython = get_ipython()
+    return ipython.run_cell_magic("sql", line, q.query)
diff --git a/pgcli/main.py b/pgcli/main.py
new file mode 100644
index 0000000..b146898
--- /dev/null
+++ b/pgcli/main.py
@@ -0,0 +1,1516 @@
+import platform
+import warnings
+from os.path import expanduser
+
+from configobj import ConfigObj
+from pgspecial.namedqueries import NamedQueries
+
+warnings.filterwarnings("ignore", category=UserWarning,
module="psycopg2") + +import os +import re +import sys +import traceback +import logging +import threading +import shutil +import functools +import pendulum +import datetime as dt +import itertools +import platform +from time import time, sleep +from codecs import open + +keyring = None # keyring will be loaded later + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers +import click + +try: + import setproctitle +except ImportError: + setproctitle = None +from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.layout.processors import ( + ConditionalProcessor, + HighlightMatchingBracketProcessor, + TabsProcessor, +) +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from pygments.lexers.sql import PostgresLexer + +from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT +import pgspecial as special + +from .pgcompleter import PGCompleter +from .pgtoolbar import create_toolbar_tokens_func +from .pgstyle import style_factory, style_factory_output +from .pgexecute import PGExecute +from .completion_refresher import CompletionRefresher +from .config import ( + get_casing_file, + load_config, + config_location, + ensure_dir_exists, + get_config, +) +from .key_bindings import pgcli_bindings +from .packages.prompt_utils import confirm_destructive_query +from .__init__ import __version__ + +click.disable_unicode_literals_warning = True + +try: + from urlparse import urlparse, unquote, parse_qs +except ImportError: + from urllib.parse import urlparse, unquote, parse_qs + +from getpass import getuser +from psycopg2 import OperationalError, InterfaceError +import psycopg2 + +from collections import namedtuple + +from textwrap import dedent + +# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output +COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") + +# Query tuples are used for maintaining history +MetaQuery = namedtuple( + "Query", + [ + "query", # The entire text of the command + "successful", # True If all subqueries were successful + "total_time", # Time elapsed executing the query and formatting results + "execution_time", # Time elapsed executing the query + "meta_changed", # True if any subquery executed create/alter/drop + "db_changed", # True if any subquery changed the database + "path_changed", # True if any subquery changed the search path + "mutated", # True if any subquery executed insert/update/delete + "is_special", # True if the query is a special command + ], +) +MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False) + +OutputSettings = namedtuple( + "OutputSettings", + "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output", +) +OutputSettings.__new__.__defaults__ = ( + None, + None, + None, + "", + False, + None, + lambda x: x, + None, +) + + +class PgCliQuitError(Exception): + pass + + +class PGCli(object): + default_prompt = "\\u@\\h:\\d> " + max_len_prompt = 30 + + def set_default_pager(self, config): + 
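+        """Set the PAGER environment for this session.
+
+        Precedence, per the branches below: the "pager" option from the
+        config file, then an existing $PAGER variable, then the OS default.
+        """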
configured_pager = config["main"].get("pager") + os_environ_pager = os.environ.get("PAGER") + + if configured_pager: + self.logger.info( + 'Default pager found in config file: "%s"', configured_pager + ) + os.environ["PAGER"] = configured_pager + elif os_environ_pager: + self.logger.info( + 'Default pager found in PAGER environment variable: "%s"', + os_environ_pager, + ) + os.environ["PAGER"] = os_environ_pager + else: + self.logger.info( + "No default pager found in environment. Using os default pager" + ) + + # Set default set of less recommended options, if they are not already set. + # They are ignored if pager is different than less. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-SRXF" + + def __init__( + self, + force_passwd_prompt=False, + never_passwd_prompt=False, + pgexecute=None, + pgclirc_file=None, + row_limit=None, + single_connection=False, + less_chatty=None, + prompt=None, + prompt_dsn=None, + auto_vertical_output=False, + warn=None, + ): + + self.force_passwd_prompt = force_passwd_prompt + self.never_passwd_prompt = never_passwd_prompt + self.pgexecute = pgexecute + self.dsn_alias = None + self.watch_command = None + + # Load config. + c = self.config = get_config(pgclirc_file) + + NamedQueries.instance = NamedQueries.from_config(self.config) + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + self.set_default_pager(c) + self.output_file = None + self.pgspecial = PGSpecial() + + self.multi_line = c["main"].as_bool("multi_line") + self.multiline_mode = c["main"].get("multi_line_mode", "psql") + self.vi_mode = c["main"].as_bool("vi") + self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.expanded_output = c["main"].as_bool("expand") + self.pgspecial.timing_enabled = c["main"].as_bool("timing") + if row_limit is not None: + self.row_limit = row_limit + else: + self.row_limit = c["main"].as_int("row_limit") + + self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") + self.multiline_continuation_char = c["main"]["multiline_continuation_char"] + self.table_format = c["main"]["table_format"] + self.syntax_style = c["main"]["syntax_style"] + self.cli_style = c["colors"] + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") + self.destructive_warning = c_dest_warning if warn is None else warn + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.null_string = c["main"].get("null_string", "") + self.prompt_format = ( + prompt + if prompt is not None + else c["main"].get("prompt", self.default_prompt) + ) + self.prompt_dsn_format = prompt_dsn + self.on_error = c["main"]["on_error"].upper() + self.decimal_format = c["data_formats"]["decimal"] + self.float_format = c["data_formats"]["float"] + self.initialize_keyring() + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") + + self.pgspecial.pset_pager( + self.config["main"].as_bool("enable_pager") and "on" or "off" + ) + + self.style_output = style_factory_output(self.syntax_style, c["colors"]) + + self.now = dt.datetime.today() + + self.completion_refresher = CompletionRefresher() + + self.query_history = [] + + # Initialize completer + smart_completion = c["main"].as_bool("smart_completion") + keyword_casing = c["main"]["keyword_casing"] + self.settings = { + "casing_file": get_casing_file(c), + "generate_casing_file": c["main"].as_bool("generate_casing_file"), + "generate_aliases": c["main"].as_bool("generate_aliases"), + 
"asterisk_column_order": c["main"]["asterisk_column_order"], + "qualify_columns": c["main"]["qualify_columns"], + "case_column_headers": c["main"].as_bool("case_column_headers"), + "search_path_filter": c["main"].as_bool("search_path_filter"), + "single_connection": single_connection, + "less_chatty": less_chatty, + "keyword_casing": keyword_casing, + } + + completer = PGCompleter( + smart_completion, pgspecial=self.pgspecial, settings=self.settings + ) + self.completer = completer + self._completer_lock = threading.Lock() + self.register_special_commands() + + self.prompt_app = None + + def quit(self): + raise PgCliQuitError + + def register_special_commands(self): + + self.pgspecial.register( + self.change_db, + "\\c", + "\\c[onnect] database_name", + "Change to a new database.", + aliases=("use", "\\connect", "USE"), + ) + + refresh_callback = lambda: self.refresh_completions(persist_priorities="all") + + self.pgspecial.register( + self.quit, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + self.pgspecial.register( + self.quit, + "quit", + "quit", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=False, + aliases=("exit",), + ) + self.pgspecial.register( + refresh_callback, + "\\#", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + refresh_callback, + "\\refresh", + "\\refresh", + "Refresh auto-completions.", + arg_type=NO_QUERY, + ) + self.pgspecial.register( + self.execute_from_file, "\\i", "\\i filename", "Execute commands from file." + ) + self.pgspecial.register( + self.write_to_file, + "\\o", + "\\o [filename]", + "Send all query results to file.", + ) + self.pgspecial.register( + self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" + ) + self.pgspecial.register( + self.change_table_format, + "\\T", + "\\T [format]", + "Change the table format used to output results", + ) + + def change_table_format(self, pattern, **_): + try: + if pattern not in TabularOutputFormatter().supported_formats: + raise ValueError() + self.table_format = pattern + yield (None, None, None, "Changed table format to {}".format(pattern)) + except ValueError: + msg = "Table format {} not recognized. Allowed formats:".format(pattern) + for table_type in TabularOutputFormatter().supported_formats: + msg += "\n\t{}".format(table_type) + msg += "\nCurrently set to: %s" % self.table_format + yield (None, None, None, msg) + + def info_connection(self, **_): + if self.pgexecute.host.startswith("/"): + host = 'socket "%s"' % self.pgexecute.host + else: + host = 'host "%s"' % self.pgexecute.host + + yield ( + None, + None, + None, + 'You are connected to database "%s" as user ' + '"%s" on %s at port "%s".' + % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port), + ) + + def change_db(self, pattern, **_): + if pattern: + # Get all the parameters in pattern, handling double quotes if any. + infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern) + # Now removing quotes. 
+            infos = [token.strip('"') for token in infos]
+
+            infos.extend([None] * (4 - len(infos)))
+            db, user, host, port = infos
+            try:
+                self.pgexecute.connect(
+                    database=db,
+                    user=user,
+                    host=host,
+                    port=port,
+                    **self.pgexecute.extra_args,
+                )
+            except OperationalError as e:
+                click.secho(str(e), err=True, fg="red")
+                click.echo("Previous connection kept")
+        else:
+            self.pgexecute.connect()
+
+        yield (
+            None,
+            None,
+            None,
+            'You are now connected to database "%s" as '
+            'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
+        )
+
+    def execute_from_file(self, pattern, **_):
+        if not pattern:
+            message = "\\i: missing required argument"
+            return [(None, None, None, message, "", False, True)]
+        try:
+            with open(os.path.expanduser(pattern), encoding="utf-8") as f:
+                query = f.read()
+        except IOError as e:
+            return [(None, None, None, str(e), "", False, True)]
+
+        if self.destructive_warning and confirm_destructive_query(query) is False:
+            message = "Wise choice. Command execution stopped."
+            return [(None, None, None, message)]
+
+        on_error_resume = self.on_error == "RESUME"
+        return self.pgexecute.run(
+            query, self.pgspecial, on_error_resume=on_error_resume
+        )
+
+    def write_to_file(self, pattern, **_):
+        if not pattern:
+            self.output_file = None
+            message = "File output disabled"
+            return [(None, None, None, message, "", True, True)]
+        filename = os.path.abspath(os.path.expanduser(pattern))
+        if not os.path.isfile(filename):
+            try:
+                open(filename, "w").close()
+            except IOError as e:
+                self.output_file = None
+                message = str(e) + "\nFile output disabled"
+                return [(None, None, None, message, "", False, True)]
+        self.output_file = filename
+        message = 'Writing to file "%s"' % self.output_file
+        return [(None, None, None, message, "", True, True)]
+
+    def initialize_logging(self):
+
+        log_file = self.config["main"]["log_file"]
+        if log_file == "default":
+            log_file = config_location() + "log"
+        ensure_dir_exists(log_file)
+        log_level = self.config["main"]["log_level"]
+
+        # Disable logging if value is NONE by switching to a no-op handler.
+        # Set log level to a high value so it doesn't even waste cycles getting called.
+        if log_level.upper() == "NONE":
+            handler = logging.NullHandler()
+        else:
+            handler = logging.FileHandler(os.path.expanduser(log_file))
+
+        level_map = {
+            "CRITICAL": logging.CRITICAL,
+            "ERROR": logging.ERROR,
+            "WARNING": logging.WARNING,
+            "INFO": logging.INFO,
+            "DEBUG": logging.DEBUG,
+            "NONE": logging.CRITICAL,
+        }
+
+        log_level = level_map[log_level.upper()]
+
+        formatter = logging.Formatter(
+            "%(asctime)s (%(process)d/%(threadName)s) "
+            "%(name)s %(levelname)s - %(message)s"
+        )
+
+        handler.setFormatter(formatter)
+
+        root_logger = logging.getLogger("pgcli")
+        root_logger.addHandler(handler)
+        root_logger.setLevel(log_level)
+
+        root_logger.debug("Initializing pgcli logging.")
+        root_logger.debug("Log file %r.", log_file)
+
+        pgspecial_logger = logging.getLogger("pgspecial")
+        pgspecial_logger.addHandler(handler)
+        pgspecial_logger.setLevel(log_level)
+
+    def initialize_keyring(self):
+        global keyring
+
+        keyring_enabled = self.config["main"].as_bool("keyring")
+        if keyring_enabled:
+            # Try our best to load keyring (issue #1041).
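+            # The deferred import below means a broken keyring install only
+            # produces the warning in the except clause instead of crashing
+            # pgcli at startup.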
+ import importlib + + try: + keyring = importlib.import_module("keyring") + except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + self.logger.warning("import keyring failed: %r.", e) + + def connect_dsn(self, dsn, **kwargs): + self.connect(dsn=dsn, **kwargs) + + def connect_service(self, service, user): + service_config, file = parse_service_info(service) + if service_config is None: + click.secho( + "service '%s' was not found in %s" % (service, file), err=True, fg="red" + ) + exit(1) + self.connect( + database=service_config.get("dbname"), + host=service_config.get("host"), + user=user or service_config.get("user"), + port=service_config.get("port"), + passwd=service_config.get("password"), + ) + + def connect_uri(self, uri): + kwargs = psycopg2.extensions.parse_dsn(uri) + remap = {"dbname": "database", "password": "passwd"} + kwargs = {remap.get(k, k): v for k, v in kwargs.items()} + self.connect(**kwargs) + + def connect( + self, database="", host="", user="", port="", passwd="", dsn="", **kwargs + ): + # Connect to the database. + + if not user: + user = getuser() + + if not database: + database = user + + kwargs.setdefault("application_name", "pgcli") + + # If password prompt is not forced but no password is provided, try + # getting it from environment variable. + if not self.force_passwd_prompt and not passwd: + passwd = os.environ.get("PGPASSWORD", "") + + # Find password from store + key = "%s@%s" % (user, host) + keyring_error_message = dedent( + """\ + {} + {} + To remove this message do one of the following: + - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/ + - uninstall keyring: pip uninstall keyring + - disable keyring in our configuration: add keyring = False to [main]""" + ) + if not passwd and keyring: + + try: + passwd = keyring.get_password("pgcli", key) + except (RuntimeError, keyring.errors.InitError) as e: + click.secho( + keyring_error_message.format( + "Load your password from keyring returned:", str(e) + ), + err=True, + fg="red", + ) + + # Prompt for a password immediately if requested via the -W flag. This + # avoids wasting time trying to connect to the database and catching a + # no-password exception. + # If we successfully parsed a password from a URI, there's no need to + # prompt for it, even with the -W flag + if self.force_passwd_prompt and not passwd: + passwd = click.prompt( + "Password for %s" % user, hide_input=True, show_default=False, type=str + ) + + def should_ask_for_password(exc): + # Prompt for a password after 1st attempt to connect + # fails. Don't prompt if the -w flag is supplied + if self.never_passwd_prompt: + return False + error_msg = exc.args[0] + if "no password supplied" in error_msg: + return True + if "password authentication failed" in error_msg: + return True + return False + + # Attempt to connect to the database. + # Note that passwd may be empty on the first attempt. If connection + # fails because of a missing or incorrect password, but we're allowed to + # prompt for a password (no -w flag), prompt for a passwd and try again. 
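+        # The nested try below implements that retry: one connect attempt,
+        # then (when allowed) a single password prompt and a second attempt.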
+ try: + try: + pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + except (OperationalError, InterfaceError) as e: + if should_ask_for_password(e): + passwd = click.prompt( + "Password for %s" % user, + hide_input=True, + show_default=False, + type=str, + ) + pgexecute = PGExecute( + database, user, passwd, host, port, dsn, **kwargs + ) + else: + raise e + if passwd and keyring: + try: + keyring.set_password("pgcli", key, passwd) + except (RuntimeError, keyring.errors.KeyringError) as e: + click.secho( + keyring_error_message.format( + "Set password in keyring returned:", str(e) + ), + err=True, + fg="red", + ) + + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + + self.pgexecute = pgexecute + + def handle_editor_command(self, text): + r""" + Editor command is any query that is prefixed or suffixed + by a '\e'. The reason for a while loop is because a user + might edit a query multiple times. + For eg: + "select * from \e" to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + """ + editor_command = special.editor_command(text) + while editor_command: + if editor_command == "\\e": + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + else: # \ev or \ef + filename = None + spec = text.split()[1] + if editor_command == "\\ev": + query = self.pgexecute.view_definition(spec) + elif editor_command == "\\ef": + query = self.pgexecute.function_definition(spec) + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + editor_command = special.editor_command(text) + return text + + def execute_command(self, text): + logger = self.logger + + query = MetaQuery(query=text, successful=False) + + try: + if self.destructive_warning: + destroy = confirm = confirm_destructive_query(text) + if destroy is False: + click.secho("Wise choice!") + raise KeyboardInterrupt + elif destroy: + click.secho("Your call!") + output, query = self._evaluate_command(text) + except KeyboardInterrupt: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") + except NotImplementedError: + click.secho("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError) as e: + raise + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + else: + try: + if self.output_file and not text.startswith(("\\o ", "\\? 
")): + try: + with open(self.output_file, "a", encoding="utf-8") as f: + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except IOError as e: + click.secho(str(e), err=True, fg="red") + else: + if output: + self.echo_via_pager("\n".join(output)) + except KeyboardInterrupt: + pass + + if self.pgspecial.timing_enabled: + # Only add humanized time display if > 1 second + if query.total_time > 1: + print( + "Time: %0.03fs (%s), executed in: %0.03fs (%s)" + % ( + query.total_time, + pendulum.Duration(seconds=query.total_time).in_words(), + query.execution_time, + pendulum.Duration(seconds=query.execution_time).in_words(), + ) + ) + else: + print("Time: %0.03fs" % query.total_time) + + # Check if we need to update completions, in order of most + # to least drastic changes + if query.db_changed: + with self._completer_lock: + self.completer.reset_completions() + self.refresh_completions(persist_priorities="keywords") + elif query.meta_changed: + self.refresh_completions(persist_priorities="all") + elif query.path_changed: + logger.debug("Refreshing search path") + with self._completer_lock: + self.completer.set_search_path(self.pgexecute.search_path()) + logger.debug("Search path: %r", self.completer.search_path) + return query + + def run_cli(self): + logger = self.logger + + history_file = self.config["main"]["history_file"] + if history_file == "default": + history_file = config_location() + "history" + history = FileHistory(os.path.expanduser(history_file)) + self.refresh_completions(history=history, persist_priorities="none") + + self.prompt_app = self._build_cli(history) + + if not self.less_chatty: + print("Server: PostgreSQL", self.pgexecute.server_version) + print("Version:", __version__) + print("Chat: https://gitter.im/dbcli/pgcli") + print("Home: http://pgcli.com") + + try: + while True: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + continue + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + continue + + # Initialize default metaquery in case execution fails + self.watch_command, timing = special.get_watch_command(text) + if self.watch_command: + while self.watch_command: + try: + query = self.execute_command(self.watch_command) + click.echo( + "Waiting for {0} seconds before repeating".format( + timing + ) + ) + sleep(timing) + except KeyboardInterrupt: + self.watch_command = None + else: + query = self.execute_command(text) + + self.now = dt.datetime.today() + + # Allow PGCompleter to learn user's preferred keywords, etc. 
+ with self._completer_lock: + self.completer.extend_query_history(text) + + self.query_history.append(query) + + except (PgCliQuitError, EOFError): + if not self.less_chatty: + print("Goodbye!") + + def _build_cli(self, history): + key_bindings = pgcli_bindings(self) + + def get_message(): + if self.dsn_alias and self.prompt_dsn_format is not None: + prompt_format = self.prompt_dsn_format + else: + prompt_format = self.prompt_format + + prompt = self.get_prompt(prompt_format) + + if ( + prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, line_number, is_soft_wrap): + continuation = self.multiline_continuation_char * (width - 1) + " " + return [("class:continuation", continuation)] + + get_toolbar_tokens = create_toolbar_tokens_func(self) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + prompt_app = PromptSession( + lexer=PygmentsLexer(PostgresLexer), + reserve_space_for_menu=self.min_num_menu_lines, + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, + complete_style=complete_style, + input_processors=[ + # Highlight matching brackets while editing. + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ), + # Render \t as 4 spaces instead of "^I" + TabsProcessor(char1=" ", char2=" "), + ], + auto_suggest=AutoSuggestFromHistory(), + tempfile_suffix=".sql", + # N.b. pgcli's multi-line mode controls submit-on-Enter (which + # overrides the default behaviour of prompt_toolkit) and is + # distinct from prompt_toolkit's multiline mode here, which + # controls layout/display of the prompt/buffer + multiline=True, + history=history, + completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)), + complete_while_typing=True, + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS, + search_ignore_case=True, + ) + + return prompt_app + + def _should_limit_output(self, sql, cur): + """returns True if the output should be truncated, False otherwise.""" + if not is_select(sql): + return False + + return ( + not self._has_limit(sql) + and self.row_limit != 0 + and cur + and cur.rowcount > self.row_limit + ) + + def _has_limit(self, sql): + if not sql: + return False + return "limit " in sql.lower() + + def _limit_output(self, cur): + limit = min(self.row_limit, cur.rowcount) + new_cur = itertools.islice(cur, limit) + new_status = "SELECT " + str(limit) + click.secho("The result was limited to %s rows" % limit, fg="red") + + return new_cur, new_status + + def _evaluate_command(self, text): + """Used to run a command entered by the user during CLI operation + (Puts the E in REPL) + + returns (results, MetaQuery) + """ + logger = self.logger + logger.debug("sql: %r", text) + + all_success = True + meta_changed = False # CREATE, ALTER, DROP, etc + mutated = False # INSERT, DELETE, etc + db_changed = False + path_changed = False + output = [] + total = 0 + execution = 0 + + # Run the query. 
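+        # pgexecute.run() yields one tuple per executed statement:
+        #   (title, cur, headers, status, sql, success, is_special)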
+ start = time() + on_error_resume = self.on_error == "RESUME" + res = self.pgexecute.run( + text, self.pgspecial, exception_formatter, on_error_resume + ) + + is_special = None + + for title, cur, headers, status, sql, success, is_special in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + + if self._should_limit_output(sql, cur): + cur, status = self._limit_output(cur) + + if self.pgspecial.auto_expand or self.auto_expand: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + expanded = self.pgspecial.expanded_output or self.expanded_output + settings = OutputSettings( + table_format=self.table_format, + dcmlfmt=self.decimal_format, + floatfmt=self.float_format, + missingval=self.null_string, + expanded=expanded, + max_width=max_width, + case_function=( + self.completer.case + if self.settings["case_column_headers"] + else lambda x: x + ), + style_output=self.style_output, + ) + execution = time() - start + formatted = format_output(title, cur, headers, status, settings) + + output.extend(formatted) + total = time() - start + + # Keep track of whether any of the queries are mutating or changing + # the database + if success: + mutated = mutated or is_mutating(status) + db_changed = db_changed or has_change_db_cmd(sql) + meta_changed = meta_changed or has_meta_cmd(sql) + path_changed = path_changed or has_change_path_cmd(sql) + else: + all_success = False + + meta_query = MetaQuery( + text, + all_success, + total, + execution, + meta_changed, + db_changed, + path_changed, + mutated, + is_special, + ) + + return output, meta_query + + def _handle_server_closed_connection(self, text): + """Used during CLI execution.""" + try: + click.secho("Reconnecting...", fg="green") + self.pgexecute.connect() + click.secho("Reconnected!", fg="green") + self.execute_command(text) + except OperationalError as e: + click.secho("Reconnect Failed", fg="red") + click.secho(str(e), err=True, fg="red") + + def refresh_completions(self, history=None, persist_priorities="all"): + """Refresh outdated completions + + :param history: A prompt_toolkit.history.FileHistory object. Used to + load keyword and identifier preferences + + :param persist_priorities: 'all' or 'keywords' + """ + + callback = functools.partial( + self._on_completions_refreshed, persist_priorities=persist_priorities + ) + self.completion_refresher.refresh( + self.pgexecute, + self.pgspecial, + callback, + history=history, + settings=self.settings, + ) + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def _on_completions_refreshed(self, new_completer, persist_priorities): + self._swap_completer_objects(new_completer, persist_priorities) + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def _swap_completer_objects(self, new_completer, persist_priorities): + """Swap the completer object with the newly created completer. + + persist_priorities is a string specifying how the old completer's + learned prioritizer should be transferred to the new completer. + + 'none' - The new prioritizer is left in a new/clean state + + 'all' - The new prioritizer is updated to exactly reflect + the old one + + 'keywords' - The new prioritizer is updated with old keyword + priorities, but not any other. 
+ + """ + with self._completer_lock: + old_completer = self.completer + self.completer = new_completer + + if persist_priorities == "all": + # Just swap over the entire prioritizer + new_completer.prioritizer = old_completer.prioritizer + elif persist_priorities == "keywords": + # Swap over the entire prioritizer, but clear name priorities, + # leaving learned keyword priorities alone + new_completer.prioritizer = old_completer.prioritizer + new_completer.prioritizer.clear_names() + elif persist_priorities == "none": + # Leave the new prioritizer as is + pass + self.completer = new_completer + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + # should be before replacing \\d + string = string.replace("\\dsn_alias", self.dsn_alias or "") + string = string.replace("\\t", self.now.strftime("%x %X")) + string = string.replace("\\u", self.pgexecute.user or "(none)") + string = string.replace("\\H", self.pgexecute.host or "(none)") + string = string.replace("\\h", self.pgexecute.short_host or "(none)") + string = string.replace("\\d", self.pgexecute.dbname or "(none)") + string = string.replace( + "\\p", + str(self.pgexecute.port) if self.pgexecute.port is not None else "5432", + ) + string = string.replace("\\i", str(self.pgexecute.pid) or "(none)") + string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">") + string = string.replace("\\n", "\n") + return string + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + def is_too_wide(self, line): + """Will this line be too wide to fit into terminal?""" + if not self.prompt_app: + return False + return ( + len(COLOR_CODE_REGEX.sub("", line)) + > self.prompt_app.output.get_size().columns + ) + + def is_too_tall(self, lines): + """Are there too many lines to fit into terminal?""" + if not self.prompt_app: + return False + return len(lines) >= (self.prompt_app.output.get_size().rows - 4) + + def echo_via_pager(self, text, color=None): + if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: + click.echo(text, color=color) + elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv": + click.echo_via_pager(text, color) + elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT: + lines = text.split("\n") + + # The last 4 lines are reserved for the pgcli menu and padding + if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines): + click.echo_via_pager(text, color=color) + else: + click.echo(text, color=color) + else: + click.echo_via_pager(text, color) + + +@click.command() +# Default host is '' so psycopg2 can default to either localhost or unix socket +@click.option( + "-h", + "--host", + default="", + envvar="PGHOST", + help="Host address of the postgres database.", +) +@click.option( + "-p", + "--port", + default=5432, + help="Port number at which the " "postgres instance is listening.", + envvar="PGPORT", + type=click.INT, +) +@click.option( + "-U", + "--username", + "username_opt", + help="Username to connect to the postgres database.", +) +@click.option( + "-u", "--user", "username_opt", help="Username to connect to the postgres database." 
+) +@click.option( + "-W", + "--password", + "prompt_passwd", + is_flag=True, + default=False, + help="Force password prompt.", +) +@click.option( + "-w", + "--no-password", + "never_prompt", + is_flag=True, + default=False, + help="Never prompt for password.", +) +@click.option( + "--single-connection", + "single_connection", + is_flag=True, + default=False, + help="Do not use a separate connection for completions.", +) +@click.option("-v", "--version", is_flag=True, help="Version of pgcli.") +@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.") +@click.option( + "--pgclirc", + default=config_location() + "config", + envvar="PGCLIRC", + help="Location of pgclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "-D", + "--dsn", + default="", + envvar="DSN", + help="Use DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--list-dsn", + "list_dsn", + is_flag=True, + help="list of DSN configured into the [alias_dsn] section of pgclirc file.", +) +@click.option( + "--row-limit", + default=None, + envvar="PGROWLIMIT", + type=click.INT, + help="Set threshold for row limit prompt. Use 0 to disable prompt.", +) +@click.option( + "--less-chatty", + "less_chatty", + is_flag=True, + default=False, + help="Skip intro on startup and goodbye on exit.", +) +@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') +@click.option( + "--prompt-dsn", + help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").', +) +@click.option( + "-l", + "--list", + "list_databases", + is_flag=True, + help="list " "available databases, then exit.", +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "--warn/--no-warn", default=None, help="Warn before running a destructive query." +) +@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) +@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) +def cli( + dbname, + username_opt, + host, + port, + prompt_passwd, + never_prompt, + single_connection, + dbname_opt, + username, + version, + pgclirc, + dsn, + row_limit, + less_chatty, + prompt, + prompt_dsn, + list_databases, + auto_vertical_output, + list_dsn, + warn, +): + if version: + print("Version:", __version__) + sys.exit(0) + + config_dir = os.path.dirname(config_location()) + if not os.path.exists(config_dir): + os.makedirs(config_dir) + + # Migrate the config file from old location. + config_full_path = config_location() + "config" + if os.path.exists(os.path.expanduser("~/.pgclirc")): + if not os.path.exists(config_full_path): + shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path) + print("Config file (~/.pgclirc) moved to new location", config_full_path) + else: + print("Config file is now located at", config_full_path) + print( + "Please move the existing config file ~/.pgclirc to", + config_full_path, + ) + if list_dsn: + try: + cfg = load_config(pgclirc, config_full_path) + for alias in cfg["alias_dsn"]: + click.secho(alias + " : " + cfg["alias_dsn"][alias]) + sys.exit(0) + except Exception as err: + click.secho( + "Invalid DSNs found in the config file. 
" + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + + pgcli = PGCli( + prompt_passwd, + never_prompt, + pgclirc_file=pgclirc, + row_limit=row_limit, + single_connection=single_connection, + less_chatty=less_chatty, + prompt=prompt, + prompt_dsn=prompt_dsn, + auto_vertical_output=auto_vertical_output, + warn=warn, + ) + + # Choose which ever one has a valid value. + if dbname_opt and dbname: + # work as psql: when database is given as option and argument use the argument as user + username = dbname + database = dbname_opt or dbname or "" + user = username_opt or username + service = None + if database.startswith("service="): + service = database[8:] + elif os.getenv("PGSERVICE") is not None: + service = os.getenv("PGSERVICE") + # because option --list or -l are not supposed to have a db name + if list_databases: + database = "postgres" + + if dsn != "": + try: + cfg = load_config(pgclirc, config_full_path) + dsn_config = cfg["alias_dsn"][dsn] + except KeyError: + click.secho( + f"Could not find a DSN with alias {dsn}. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + except Exception: + click.secho( + "Invalid DSNs found in the config file. " + 'Please check the "[alias_dsn]" section in pgclirc.', + err=True, + fg="red", + ) + exit(1) + pgcli.connect_uri(dsn_config) + pgcli.dsn_alias = dsn + elif "://" in database: + pgcli.connect_uri(database) + elif "=" in database and service is None: + pgcli.connect_dsn(database, user=user) + elif service is not None: + pgcli.connect_service(service, user) + else: + pgcli.connect(database, host, user, port) + + if list_databases: + cur, headers, status = pgcli.pgexecute.full_databases() + + title = "List of databases" + settings = OutputSettings(table_format="ascii", missingval="") + formatted = format_output(title, cur, headers, status, settings) + pgcli.echo_via_pager("\n".join(formatted)) + + sys.exit(0) + + pgcli.logger.debug( + "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", + database, + user, + host, + port, + ) + + if setproctitle: + obfuscate_process_password() + + pgcli.run_cli() + + +def obfuscate_process_password(): + process_title = setproctitle.getproctitle() + if "://" in process_title: + process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title) + elif "=" in process_title: + process_title = re.sub( + r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title + ) + + setproctitle.setproctitle(process_title) + + +def has_meta_cmd(query): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop, commit or rollback.""" + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "drop", "commit", "rollback"): + return True + except Exception: + return False + + return False + + +def has_change_db_cmd(query): + """Determines if the statement is a database switch such as 'use' or '\\c'""" + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\c", "\\connect"): + return True + except Exception: + return False + + return False + + +def has_change_path_cmd(sql): + """Determines if the search_path should be refreshed by checking if the + sql has 'set search_path'.""" + return "set search_path" in sql.lower() + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = set(["insert", "update", "delete"]) + return 
status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +def exception_formatter(e): + return click.style(str(e), fg="red") + + +def format_output(title, cur, headers, status, settings): + output = [] + expanded = settings.expanded or settings.table_format == "vertical" + table_format = "vertical" if settings.expanded else settings.table_format + max_width = settings.max_width + case_function = settings.case_function + formatter = TabularOutputFormatter(format_name=table_format) + + def format_array(val): + if val is None: + return settings.missingval + if not isinstance(val, list): + return val + return "{" + ",".join(str(format_array(e)) for e in val) + "}" + + def format_arrays(data, headers, **_): + data = list(data) + for row in data: + row[:] = [ + format_array(val) if isinstance(val, list) else val for val in row + ] + + return data, headers + + output_kwargs = { + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": settings.missingval, + "integer_format": settings.dcmlfmt, + "float_format": settings.floatfmt, + "preprocessors": (format_numbers, format_arrays), + "disable_numparse": True, + "preserve_whitespace": True, + "style": settings.style_output, + } + if not settings.floatfmt: + output_kwargs["preprocessors"] = (align_decimals,) + + if table_format == "csv": + # The default CSV dialect is "excel" which is not handling newline values correctly + # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n' + # as the line terminator + # https://github.com/dbcli/pgcli/issues/1102 + dialect = "excel" if platform.system() == "Windows" else "unix" + output_kwargs["dialect"] = dialect + + if title: # Only print the title if it's not None. 
+        output.append(title)
+
+    if cur:
+        headers = [case_function(x) for x in headers]
+        if max_width is not None:
+            cur = list(cur)
+        column_types = None
+        if hasattr(cur, "description"):
+            column_types = []
+            for d in cur.description:
+                if (
+                    d[1] in psycopg2.extensions.DECIMAL.values
+                    or d[1] in psycopg2.extensions.FLOAT.values
+                ):
+                    column_types.append(float)
+                elif (
+                    d[1] in psycopg2.extensions.INTEGER.values
+                    or d[1] in psycopg2.extensions.LONGINTEGER.values
+                ):
+                    column_types.append(int)
+                else:
+                    column_types.append(str)
+
+        formatted = formatter.format_output(cur, headers, **output_kwargs)
+        if isinstance(formatted, str):
+            formatted = iter(formatted.splitlines())
+        first_line = next(formatted)
+        formatted = itertools.chain([first_line], formatted)
+        if not expanded and max_width and len(first_line) > max_width and headers:
+            formatted = formatter.format_output(
+                cur, headers, format_name="vertical", column_types=None, **output_kwargs
+            )
+            if isinstance(formatted, str):
+                formatted = iter(formatted.splitlines())
+
+        output = itertools.chain(output, formatted)
+
+    # Only print the status if it's not None and we are not producing CSV
+    if status and table_format != "csv":
+        output = itertools.chain(output, [status])
+
+    return output
+
+
+def parse_service_info(service):
+    service = service or os.getenv("PGSERVICE")
+    service_file = os.getenv("PGSERVICEFILE")
+    if not service_file:
+        # try ~/.pg_service.conf (if that exists)
+        if platform.system() == "Windows":
+            service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf"
+        elif os.getenv("PGSYSCONFDIR"):
+            service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
+        else:
+            service_file = expanduser("~/.pg_service.conf")
+    if not service:
+        # nothing to do
+        return None, service_file
+    service_file_config = ConfigObj(service_file)
+    if service not in service_file_config:
+        return None, service_file
+    service_conf = service_file_config.get(service)
+    return service_conf, service_file
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 0000000..a11e7bf
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
@@ -0,0 +1,22 @@
+import sqlparse
+
+
+def query_starts_with(query, prefixes):
+    """Check if the query starts with any item from *prefixes*."""
+    prefixes = [prefix.lower() for prefix in prefixes]
+    formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+    return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+    """Check if any queries start with any item from *prefixes*."""
+    for query in sqlparse.split(queries):
+        if query and query_starts_with(query, prefixes) is True:
+            return True
+    return False
+
+
+def is_destructive(queries):
+    """Returns if any of the queries in *queries* is destructive."""
+    keywords = ("drop", "shutdown", "delete", "truncate", "alter")
+    return queries_start_with(queries, keywords)
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 0000000..e1f9088
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,141 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple("TableExpression", "name columns start stop")
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+    """Simplify a query by converting CTEs into table metadata objects"""
+
+    if not full_text or not full_text.strip():
+        return full_text, text_before_cursor, tuple()
+
+    ctes, remainder = extract_ctes(full_text)
+    if not ctes:
+        return full_text, text_before_cursor, ()
+
+    current_position = len(text_before_cursor)
+    meta = []
+
+    for cte in ctes:
+        if cte.start < current_position < cte.stop:
+            # Currently editing a cte - treat its body as the current full_text
+            text_before_cursor = full_text[cte.start : current_position]
+            full_text = full_text[cte.start : cte.stop]
+            return full_text, text_before_cursor, meta
+
+        # Append this cte to the list of available table metadata
+        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+        meta.append(TableMetadata(cte.name, cols))
+
+    # Editing past the last cte (i.e. the main body of the query)
+    full_text = full_text[ctes[-1].stop :]
+    text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]
+
+    return full_text, text_before_cursor, tuple(meta)
+
+
+def extract_ctes(sql):
+    """Extract common table expressions from a query
+
+    Returns tuple (ctes, remainder_sql)
+
+    ctes is a list of TableExpression namedtuples
+    remainder_sql is the text from the original query after the CTEs have
+    been stripped.
+    """
+
+    p = parse(sql)[0]
+
+    # Make sure the first meaningful token is "WITH" which is necessary to
+    # define CTEs
+    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
+    if not (tok and tok.ttype == CTE):
+        return [], sql
+
+    # Get the next (meaningful) token, which should be the first CTE
+    idx, tok = p.token_next(idx)
+    if not tok:
+        return ([], "")
+    start_pos = token_start_pos(p.tokens, idx)
+    ctes = []
+
+    if isinstance(tok, IdentifierList):
+        # Multiple ctes
+        for t in tok.get_identifiers():
+            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
+            cte = get_cte_from_token(t, start_pos + cte_start_offset)
+            if not cte:
+                continue
+            ctes.append(cte)
+    elif isinstance(tok, Identifier):
+        # A single CTE
+        cte = get_cte_from_token(tok, start_pos)
+        if cte:
+            ctes.append(cte)
+
+    idx = p.token_index(tok) + 1
+
+    # Collapse everything after the ctes into a remainder query
+    remainder = "".join(str(tok) for tok in p.tokens[idx:])
+
+    return ctes, remainder
+
+
+def get_cte_from_token(tok, pos0):
+    cte_name = tok.get_real_name()
+    if not cte_name:
+        return None
+
+    # Find the start position of the opening parens enclosing the cte body
+    idx, parens = tok.token_next_by(Parenthesis)
+    if not parens:
+        return None
+
+    start_pos = pos0 + token_start_pos(tok.tokens, idx)
+    cte_len = len(str(parens))  # includes parens
+    stop_pos = start_pos + cte_len
+
+    column_names = extract_column_names(parens)
+
+    return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
+def extract_column_names(parsed):
+    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
+    idx, tok = parsed.token_next_by(t=DML)
+    tok_val = tok and tok.value.lower()
+
+    if tok_val in ("insert", "update", "delete"):
+        # Jump ahead to the RETURNING clause where the list of
column names is + idx, tok = parsed.token_next_by(idx, (Keyword, "returning")) + elif not tok_val == "select": + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py new file mode 100644 index 0000000..108c01a --- /dev/null +++ b/pgcli/packages/parseutils/meta.py @@ -0,0 +1,170 @@ +from collections import namedtuple + +_ColumnMetadata = namedtuple( + "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] +) + + +def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): + return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default) + + +ForeignKey = namedtuple( + "ForeignKey", + [ + "parentschema", + "parenttable", + "parentcolumn", + "childschema", + "childtable", + "childcolumn", + ], +) +TableMetadata = namedtuple("TableMetadata", "name columns") + + +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = "" + in_quote = None + for char in defaults_string: + if current == "" and char == " ": + # Skip space after comma separating default expressions + continue + if char == '"' or char == "'": + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == "," and not in_quote: + # End of expression + yield current + current = "" + continue + current += char + yield current + + +class FunctionMetadata(object): + def __init__( + self, + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + + self.arg_modes = tuple(arg_modes) if arg_modes else None + self.arg_names = tuple(arg_names) if arg_names else None + + # Be flexible in not requiring arg_types -- use None as a placeholder + # for each arg. (Used for compatibility with old versions of postgresql + # where such info is hard to get. 
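+        # Editor's sketch (hypothetical values): with arg_names=("x", "y") and
+        # no arg_types or arg_modes, the fallback below gives
+        #
+        #     fm = FunctionMetadata("public", "f", ("x", "y"), None, None,
+        #                           "int", False, False, False, False, None)
+        #     fm.arg_types  # -> (None, None)
+        #
+        # so zip()-based consumers such as args() keep working.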
+ if arg_types: + self.arg_types = tuple(arg_types) + elif arg_modes: + self.arg_types = tuple([None] * len(arg_modes)) + elif arg_names: + self.arg_types = tuple([None] * len(arg_names)) + else: + self.arg_types = None + + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + self.is_extension = bool(is_extension) + self.is_public = self.schema_name and self.schema_name == "public" + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + def _signature(self): + return ( + self.schema_name, + self.func_name, + self.arg_names, + self.arg_types, + self.arg_modes, + self.return_type, + self.is_aggregate, + self.is_window, + self.is_set_returning, + self.is_extension, + self.arg_defaults, + ) + + def __hash__(self): + return hash(self._signature()) + + def __repr__(self): + return ( + "%s(schema_name=%r, func_name=%r, arg_names=%r, " + "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, " + "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)" + ) % ((self.__class__.__name__,) + self._signature()) + + def has_variadic(self): + return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ["i"] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ("i", "b", "v") # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] + if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + + def fields(self): + """Returns a list of output-field ColumnMetadata namedtuples""" + + if self.return_type.lower() == "void": + return [] + elif not self.arg_modes: + # For functions without output parameters, the function name + # is used as the name of the output column. + # E.g. 'SELECT unnest FROM unnest(...);' + return [ColumnMetadata(self.func_name, self.return_type, [])] + + return [ + ColumnMetadata(name, typ, []) + for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes) + if mode in ("o", "b", "t") + ] # OUT, INOUT, TABLE diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py new file mode 100644 index 0000000..0ec3e69 --- /dev/null +++ b/pgcli/packages/parseutils/tables.py @@ -0,0 +1,170 @@ +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +TableReference = namedtuple( + "TableReference", ["schema", "name", "alias", "is_function"] +) +TableReference.ref = property( + lambda self: self.alias + or ( + self.name + if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"' + ) +) + + +# This code is borrowed from sqlparse example script. 
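+
+# Editor's illustration of the `ref` property above (hypothetical
+# references): the alias wins when present; otherwise the name is used,
+# quoted if it contains upper-case characters:
+#
+#     TableReference(None, "users", "u", False).ref   # -> 'u'
+#     TableReference(None, "users", None, False).ref  # -> 'users'
+#     TableReference(None, "Users", None, False).ref  # -> '"Users"'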
+
+#
+def is_subselect(parsed):
+    if not parsed.is_group:
+        return False
+    for item in parsed.tokens:
+        if item.ttype is DML and item.value.upper() in (
+            "SELECT",
+            "INSERT",
+            "UPDATE",
+            "CREATE",
+            "DELETE",
+        ):
+            return True
+    return False
+
+
+def _identifier_is_function(identifier):
+    return any(isinstance(t, Function) for t in identifier.tokens)
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+    tbl_prefix_seen = False
+    for item in parsed.tokens:
+        if tbl_prefix_seen:
+            if is_subselect(item):
+                for x in extract_from_part(item, stop_at_punctuation):
+                    yield x
+            elif stop_at_punctuation and item.ttype is Punctuation:
+                return
+            # An incomplete nested select won't be recognized correctly as a
+            # sub-select, e.g. 'SELECT * FROM (SELECT id FROM user'. This
+            # causes the second FROM to trigger this elif condition, resulting
+            # in a `return`. So we need to ignore the keyword if the keyword
+            # is FROM.
+            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+            # condition. So we need to ignore the keyword JOIN and its variants
+            # INNER JOIN, FULL OUTER JOIN, etc.
+            elif (
+                item.ttype is Keyword
+                and (not item.value.upper() == "FROM")
+                and (not item.value.upper().endswith("JOIN"))
+            ):
+                tbl_prefix_seen = False
+            else:
+                yield item
+        elif item.ttype is Keyword or item.ttype is Keyword.DML:
+            item_val = item.value.upper()
+            if (
+                item_val
+                in (
+                    "COPY",
+                    "FROM",
+                    "INTO",
+                    "UPDATE",
+                    "TABLE",
+                )
+                or item_val.endswith("JOIN")
+            ):
+                tbl_prefix_seen = True
+        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+        # So this check here is necessary.
+        elif isinstance(item, IdentifierList):
+            for identifier in item.get_identifiers():
+                if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+                    tbl_prefix_seen = True
+                    break
+
+
+def extract_table_identifiers(token_stream, allow_functions=True):
+    """Yields TableReference namedtuples."""
+
+    # We need to do some massaging of the names because postgres is case-
+    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
+    def parse_identifier(item):
+        name = item.get_real_name()
+        schema_name = item.get_parent_name()
+        alias = item.get_alias()
+        if not name:
+            schema_name = None
+            name = item.get_name()
+            alias = alias or name
+        schema_quoted = schema_name and item.value[0] == '"'
+        if schema_name and not schema_quoted:
+            schema_name = schema_name.lower()
+        quote_count = item.value.count('"')
+        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
+        alias_quoted = alias and item.value[-1] == '"'
+        if alias_quoted or name_quoted and not alias and name.islower():
+            alias = '"' + (alias or name) + '"'
+        if name and not name_quoted and not name.islower():
+            if not alias:
+                alias = name
+            name = name.lower()
+        return schema_name, name, alias
+
+    try:
+        for item in token_stream:
+            if isinstance(item, IdentifierList):
+                for identifier in item.get_identifiers():
+                    # Sometimes Keywords (such as FROM) are classified as
+                    # identifiers which don't have the get_real_name() method.
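+                    # (Editor's note, hypothetical fragment: in
+                    # 'SELECT a, FROM x' the bare FROM keyword ends up inside
+                    # the IdentifierList, so the guard below skips tokens
+                    # lacking those methods.)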
+                    try:
+                        schema_name = identifier.get_parent_name()
+                        real_name = identifier.get_real_name()
+                        is_function = allow_functions and _identifier_is_function(
+                            identifier
+                        )
+                    except AttributeError:
+                        continue
+                    if real_name:
+                        yield TableReference(
+                            schema_name, real_name, identifier.get_alias(), is_function
+                        )
+            elif isinstance(item, Identifier):
+                schema_name, real_name, alias = parse_identifier(item)
+                is_function = allow_functions and _identifier_is_function(item)
+
+                yield TableReference(schema_name, real_name, alias, is_function)
+            elif isinstance(item, Function):
+                schema_name, real_name, alias = parse_identifier(item)
+                yield TableReference(None, real_name, alias, allow_functions)
+    except StopIteration:
+        return
+
+
+# extract_tables is inspired by examples in the sqlparse lib.
+def extract_tables(sql):
+    """Extract the table names from an SQL statement.
+
+    Returns a list of TableReference namedtuples
+
+    """
+    parsed = sqlparse.parse(sql)
+    if not parsed:
+        return ()
+
+    # INSERT statements must stop looking for tables at the first sign of
+    # punctuation, e.g. INSERT INTO abc (col1, col2) VALUES (1, 2)
+    # abc is the table name, but if we don't stop at the first lparen, then
+    # we'll identify abc, col1 and col2 as table names.
+    insert_stmt = parsed[0].token_first().value.lower() == "insert"
+    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+
+    # Kludge: sqlparse mistakenly identifies insert statements as
+    # function calls due to the parenthesized column list, e.g. interprets
+    # "insert into foo (bar, baz)" as a function call to foo with arguments
+    # (bar, baz). So don't allow any identifiers in insert statements
+    # to have is_function=True
+    identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
+    # In the case 'sche.', we get an empty TableReference; remove that
+    return tuple(i for i in identifiers if i.name)
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 0000000..034c96e
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,140 @@
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
+cleanup_regex = {
+    # This matches only alphanumerics and underscores.
+    "alphanum_underscore": re.compile(r"(\w+)$"),
+    # This matches everything except spaces, parens, colon, and comma
+    "many_punctuations": re.compile(r"([^():,\s]+)$"),
+    # This matches everything except spaces, parens, colon, comma, and period
+    "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+    # This matches everything except a space.
+    "all_punctuations": re.compile(r"([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+    r"""
+    Find the last word in a sentence.
+ + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +def find_prev_keyword(sql, n_skip=0): + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + flattened = flattened[: len(flattened) - n_skip] + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index throws an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r"^\$[^$]*\$$") + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + # Look for unmatched single quotes, or unmatched dollar sign quotes + return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten()) + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_by(m=(Error, '"'))[1]: + # An unmatched double quote, e.g. 
'"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py new file mode 100644 index 0000000..5c39296 --- /dev/null +++ b/pgcli/packages/pgliterals/main.py @@ -0,0 +1,15 @@ +import os +import json + +root = os.path.dirname(__file__) +literal_file = os.path.join(root, "pgliterals.json") + +with open(literal_file) as f: + literals = json.load(f) + + +def get_literals(literal_type, type_=tuple): + # Where `literal_type` is one of 'keywords', 'functions', 'datatypes', + # returns a tuple of literal values of that type. + + return type_(literals[literal_type]) diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json new file mode 100644 index 0000000..c7b74b5 --- /dev/null +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -0,0 +1,629 @@ +{ + "keywords": { + "ACCESS": [], + "ADD": [], + "ALL": [], + "ALTER": [ + "AGGREGATE", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DEFAULT", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "LARGE OBJECT", + "MATERIALIZED VIEW", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "SYSTEM", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "USER", + "VIEW" + ], + "AND": [], + "ANY": [], + "AS": [], + "ASC": [], + "AUDIT": [], + "BEGIN": [], + "BETWEEN": [], + "BY": [], + "CASE": [], + "CHAR": [], + "CHECK": [], + "CLUSTER": [], + "COLUMN": [], + "COMMENT": [], + "COMMIT": [], + "COMPRESS": [], + "CONCURRENTLY": [], + "CONNECT": [], + "COPY": [], + "CREATE": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN EXTENSION", + "FUNCTION", + "GLOBAL", + "GROUP", + "IF NOT EXISTS", + "INDEX", + "LANGUAGE", + "LOCAL", + "MATERIALIZED VIEW", + "OPERATOR", + "OR REPLACE", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEMPORARY", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "UNIQUE", + "UNLOGGED", + "USER", + "USER MAPPING", + "VIEW" + ], + "CURRENT": [], + "DATABASE": [], + "DATE": [], + "DECIMAL": [], + "DEFAULT": [], + "DELETE FROM": [], + "DELIMITER": [], + "DESC": [], + "DESCRIBE": [], + "DISTINCT": [], + "DROP": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN TABLE", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "MATERIALIZED VIEW", + "OPERATOR", + "OWNED", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRANSFORM", + "TRIGGER", + "TYPE", + "USER", + "USER MAPPING", + "VIEW" + ], + "EXPLAIN": [], + "ELSE": [], + "ENCODING": [], + "ESCAPE": [], + "EXCLUSIVE": [], + "EXISTS": [], + "EXTENSION": [], + "FILE": [], + "FLOAT": [], + "FOR": [], + "FORMAT": [], + "FORCE_QUOTE": [], + "FORCE_NOT_NULL": [], + "FREEZE": [], + "FROM": [], + "FULL": [], + "FUNCTION": [], + "GRANT": [], + "GROUP BY": [], + "HAVING": [], + "HEADER": [], + "IDENTIFIED": [], + "IMMEDIATE": [], + "IN": [], + "INCREMENT": [], + "INDEX": [], + "INITIAL": [], 
+ "INSERT INTO": [], + "INTEGER": [], + "INTERSECT": [], + "INTERVAL": [], + "INTO": [], + "IS": [], + "JOIN": [], + "LANGUAGE": [], + "LEFT": [], + "LEVEL": [], + "LIKE": [], + "LIMIT": [], + "LOCK": [], + "LONG": [], + "MATERIALIZED VIEW": [], + "MAXEXTENTS": [], + "MINUS": [], + "MLSLABEL": [], + "MODE": [], + "MODIFY": [], + "NOT": [], + "NOAUDIT": [], + "NOTICE": [], + "NOCOMPRESS": [], + "NOWAIT": [], + "NULL": [], + "NUMBER": [], + "OIDS": [], + "OF": [], + "OFFLINE": [], + "ON": [], + "ONLINE": [], + "OPTION": [], + "OR": [], + "ORDER BY": [], + "OUTER": [], + "OWNER": [], + "PCTFREE": [], + "PRIMARY": [], + "PRIOR": [], + "PRIVILEGES": [], + "QUOTE": [], + "RAISE": [], + "RENAME": [], + "REPLACE": [], + "RESET": ["ALL"], + "RAW": [], + "REFRESH MATERIALIZED VIEW": [], + "RESOURCE": [], + "RETURNS": [], + "REVOKE": [], + "RIGHT": [], + "ROLLBACK": [], + "ROW": [], + "ROWID": [], + "ROWNUM": [], + "ROWS": [], + "SELECT": [], + "SESSION": [], + "SET": [], + "SHARE": [], + "SHOW": [], + "SIZE": [], + "SMALLINT": [], + "START": [], + "SUCCESSFUL": [], + "SYNONYM": [], + "SYSDATE": [], + "TABLE": [], + "TEMPLATE": [], + "THEN": [], + "TO": [], + "TRIGGER": [], + "TRUNCATE": [], + "UID": [], + "UNION": [], + "UNIQUE": [], + "UPDATE": [], + "USE": [], + "USER": [], + "USING": [], + "VALIDATE": [], + "VALUES": [], + "VARCHAR": [], + "VARCHAR2": [], + "VIEW": [], + "WHEN": [], + "WHENEVER": [], + "WHERE": [], + "WITH": [] + }, + "functions": [ + "ABBREV", + "ABS", + "AGE", + "AREA", + "ARRAY_AGG", + "ARRAY_APPEND", + "ARRAY_CAT", + "ARRAY_DIMS", + "ARRAY_FILL", + "ARRAY_LENGTH", + "ARRAY_LOWER", + "ARRAY_NDIMS", + "ARRAY_POSITION", + "ARRAY_POSITIONS", + "ARRAY_PREPEND", + "ARRAY_REMOVE", + "ARRAY_REPLACE", + "ARRAY_TO_STRING", + "ARRAY_UPPER", + "ASCII", + "AVG", + "BIT_AND", + "BIT_LENGTH", + "BIT_OR", + "BOOL_AND", + "BOOL_OR", + "BOUND_BOX", + "BOX", + "BROADCAST", + "BTRIM", + "CARDINALITY", + "CBRT", + "CEIL", + "CEILING", + "CENTER", + "CHAR_LENGTH", + "CHR", + "CIRCLE", + "CLOCK_TIMESTAMP", + "CONCAT", + "CONCAT_WS", + "CONVERT", + "CONVERT_FROM", + "CONVERT_TO", + "COUNT", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATE_PART", + "DATE_TRUNC", + "DECODE", + "DEGREES", + "DENSE_RANK", + "DIAMETER", + "DIV", + "ENCODE", + "ENUM_FIRST", + "ENUM_LAST", + "ENUM_RANGE", + "EVERY", + "EXP", + "EXTRACT", + "FAMILY", + "FIRST_VALUE", + "FLOOR", + "FORMAT", + "GET_BIT", + "GET_BYTE", + "HEIGHT", + "HOST", + "HOSTMASK", + "INET_MERGE", + "INET_SAME_FAMILY", + "INITCAP", + "ISCLOSED", + "ISFINITE", + "ISOPEN", + "JUSTIFY_DAYS", + "JUSTIFY_HOURS", + "JUSTIFY_INTERVAL", + "LAG", + "LAST_VALUE", + "LEAD", + "LEFT", + "LENGTH", + "LINE", + "LN", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOG", + "LOG10", + "LOWER", + "LPAD", + "LSEG", + "LTRIM", + "MAKE_DATE", + "MAKE_INTERVAL", + "MAKE_TIME", + "MAKE_TIMESTAMP", + "MAKE_TIMESTAMPTZ", + "MASKLEN", + "MAX", + "MD5", + "MIN", + "MOD", + "NETMASK", + "NETWORK", + "NOW", + "NPOINTS", + "NTH_VALUE", + "NTILE", + "NUM_NONNULLS", + "NUM_NULLS", + "OCTET_LENGTH", + "OVERLAY", + "PARSE_IDENT", + "PATH", + "PCLOSE", + "PERCENT_RANK", + "PG_CLIENT_ENCODING", + "PI", + "POINT", + "POLYGON", + "POPEN", + "POSITION", + "POWER", + "QUOTE_IDENT", + "QUOTE_LITERAL", + "QUOTE_NULLABLE", + "RADIANS", + "RADIUS", + "RANK", + "REGEXP_MATCH", + "REGEXP_MATCHES", + "REGEXP_REPLACE", + "REGEXP_SPLIT_TO_ARRAY", + "REGEXP_SPLIT_TO_TABLE", + "REPEAT", + "REPLACE", + "REVERSE", + "RIGHT", + "ROUND", + "ROW_NUMBER", + "RPAD", + "RTRIM", + "SCALE", 
+ "SET_BIT", + "SET_BYTE", + "SET_MASKLEN", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "SIGN", + "SPLIT_PART", + "SQRT", + "STARTS_WITH", + "STATEMENT_TIMESTAMP", + "STRING_TO_ARRAY", + "STRPOS", + "SUBSTR", + "SUBSTRING", + "SUM", + "TEXT", + "TIMEOFDAY", + "TO_ASCII", + "TO_CHAR", + "TO_DATE", + "TO_HEX", + "TO_NUMBER", + "TO_TIMESTAMP", + "TRANSACTION_TIMESTAMP", + "TRANSLATE", + "TRIM", + "TRUNC", + "UNNEST", + "UPPER", + "WIDTH", + "WIDTH_BUCKET", + "XMLAGG" + ], + "datatypes": [ + "ANY", + "ANYARRAY", + "ANYELEMENT", + "ANYENUM", + "ANYNONARRAY", + "ANYRANGE", + "BIGINT", + "BIGSERIAL", + "BIT", + "BIT VARYING", + "BOOL", + "BOOLEAN", + "BOX", + "BYTEA", + "CHAR", + "CHARACTER", + "CHARACTER VARYING", + "CIDR", + "CIRCLE", + "CSTRING", + "DATE", + "DECIMAL", + "DOUBLE PRECISION", + "EVENT_TRIGGER", + "FDW_HANDLER", + "FLOAT4", + "FLOAT8", + "INET", + "INT", + "INT2", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERVAL", + "JSON", + "JSONB", + "LANGUAGE_HANDLER", + "LINE", + "LSEG", + "MACADDR", + "MACADDR8", + "MONEY", + "NUMERIC", + "OID", + "OPAQUE", + "PATH", + "PG_LSN", + "POINT", + "POLYGON", + "REAL", + "RECORD", + "REGCLASS", + "REGCONFIG", + "REGDICTIONARY", + "REGNAMESPACE", + "REGOPER", + "REGOPERATOR", + "REGPROC", + "REGPROCEDURE", + "REGROLE", + "REGTYPE", + "SERIAL", + "SERIAL2", + "SERIAL4", + "SERIAL8", + "SMALLINT", + "SMALLSERIAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TRIGGER", + "TSQUERY", + "TSVECTOR", + "TXID_SNAPSHOT", + "UUID", + "VARBIT", + "VARCHAR", + "VOID", + "XML" + ], + "reserved": [ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "AUTHORIZATION", + "BINARY", + "COLLATION", + "CONCURRENTLY", + "CROSS", + "CURRENT_SCHEMA", + "FREEZE", + "FULL", + "ILIKE", + "INNER", + "IS", + "ISNULL", + "JOIN", + "LEFT", + "LIKE", + "NATURAL", + "NOTNULL", + "OUTER", + "OVERLAPS", + "RIGHT", + "SIMILAR", + "TABLESAMPLE", + "VERBOSE" + ] +} diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py new file mode 100644 index 0000000..e92dcbb --- /dev/null +++ b/pgcli/packages/prioritization.py @@ -0,0 +1,51 @@ +import re +import sqlparse +from sqlparse.tokens import Name +from collections import defaultdict +from .pgliterals.main import get_literals + + +white_space_regex = re.compile("\\s+", re.MULTILINE) + + +def _compile_regex(keyword): + # Surround the keyword with word boundaries and replace interior whitespace + # with whitespace wildcards + pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b" + return re.compile(pattern, re.MULTILINE | re.IGNORECASE) + + +keywords = get_literals("keywords") +keyword_regexs = dict((kw, _compile_regex(kw)) 
for kw in keywords)
+
+
+class PrevalenceCounter(object):
+    def __init__(self):
+        self.keyword_counts = defaultdict(int)
+        self.name_counts = defaultdict(int)
+
+    def update(self, text):
+        self.update_keywords(text)
+        self.update_names(text)
+
+    def update_names(self, text):
+        for parsed in sqlparse.parse(text):
+            for token in parsed.flatten():
+                if token.ttype in Name:
+                    self.name_counts[token.value] += 1
+
+    def clear_names(self):
+        self.name_counts = defaultdict(int)
+
+    def update_keywords(self, text):
+        # Count keywords. We can't rely on sqlparse for this, because it's
+        # database-agnostic.
+        for keyword, regex in keyword_regexs.items():
+            for _ in regex.finditer(text):
+                self.keyword_counts[keyword] += 1
+
+    def keyword_count(self, keyword):
+        return self.keyword_counts[keyword]
+
+    def name_count(self, name):
+        return self.name_counts[name]
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
new file mode 100644
index 0000000..3c58490
--- /dev/null
+++ b/pgcli/packages/prompt_utils.py
@@ -0,0 +1,35 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
+def confirm_destructive_query(queries):
+    """Check if the query is destructive and prompt the user to confirm.
+
+    Returns:
+    * None if the query is non-destructive or we can't prompt the user.
+    * True if the query is destructive and the user wants to proceed.
+    * False if the query is destructive and the user doesn't want to proceed.
+
+    """
+    prompt_text = (
+        "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
+    )
+    if is_destructive(queries) and sys.stdin.isatty():
+        return prompt(prompt_text, type=bool)
+
+
+def confirm(*args, **kwargs):
+    """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+    try:
+        return click.confirm(*args, **kwargs)
+    except click.Abort:
+        return False
+
+
+def prompt(*args, **kwargs):
+    """Prompt the user for input and handle any abort exceptions."""
+    try:
+        return click.prompt(*args, **kwargs)
+    except click.Abort:
+        return False
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
new file mode 100644
index 0000000..6ef8859
--- /dev/null
+++ b/pgcli/packages/sqlcompletion.py
@@ -0,0 +1,608 @@
+import sys
+import re
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
+from pgspecial.main import parse_special_command
+
+
+Special = namedtuple("Special", [])
+Database = namedtuple("Database", [])
+Schema = namedtuple("Schema", ["quoted"])
+Schema.__new__.__defaults__ = (False,)
+# FromClauseItem is a table/view/function used in the FROM clause
+# `table_refs` contains the list of tables/... already in the statement,
+# used to ensure that the alias we suggest is unique
+FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
+Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
+TableFormat = namedtuple("TableFormat", [])
+View = namedtuple("View", ["schema", "table_refs"])
+# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
+JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
+# Joins are suggested after JOIN, e.g.
'foo ON foo.barid = bar.barid' +Join = namedtuple("Join", ["table_refs", "schema"]) + +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) +# For convenience, don't require the `usage` argument in Function constructor +Function.__new__.__defaults__ = (None, tuple(), None) +Table.__new__.__defaults__ = (None, tuple(), tuple()) +View.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) + +Column = namedtuple( + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], +) +Column.__new__.__defaults__ = (None, None, tuple(), False, None) + +Keyword = namedtuple("Keyword", ["last_token"]) +Keyword.__new__.__defaults__ = (None,) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) + +Path = namedtuple("Path", []) + + +class SqlStatement(object): + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include="many_punctuations" + ) + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) + + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. + if self.word_before_cursor: + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" + + def is_insert(self): + return self.parsed.token_first().value.lower() == "insert" + + def get_tables(self, scope="full"): + """Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. 
+ """ + tables = extract_tables( + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self, n_skip=0): + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) + return prev_keyword + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + if full_text.startswith("\\i "): + return (Path(),) + + # This is a temporary hack; the exception handling + # here should be removed once sqlparse has been fixed + try: + stmt = SqlStatement(full_text, text_before_cursor) + except (TypeError, AttributeError): + return [] + + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value.startswith("\\"): + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub("", txt) + return txt + + +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + + +def _split_multiple_statements(full_text, text_before_cursor, parsed): + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds + # the current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(str(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single 
statement + statement = parsed[0] + else: + # The empty string + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == "FUNCTION": + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) + return full_text, text_before_cursor, statement + + +SPECIALS_SUGGESTION = { + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, +} + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return (Special(),) + + if cmd in ("\\c", "\\connect"): + return (Database(),) + + if cmd == "\\T": + return (TableFormat(),) + + if cmd == "\\dn": + return (Schema(),) + + if arg: + # Try to distinguish "\d name" from "\d schema.name" + # Note that this will fail to obtain a schema name if wildcards are + # used, e.g. "\d schema???.name" + parsed = sqlparse.parse(arg)[0].tokens[0] + try: + schema = parsed.get_parent_name() + except AttributeError: + schema = None + else: + schema = None + + if cmd[1:] == "d": + # \d can describe tables or views + if schema: + return (Table(schema=schema), View(schema=schema)) + else: + return (Schema(), Table(schema=None), View(schema=None)) + elif cmd[1:] in SPECIALS_SUGGESTION: + rel_type = SPECIALS_SUGGESTION[cmd[1:]] + if schema: + if rel_type == Function: + return (Function(schema=schema, usage="special"),) + return (rel_type(schema=schema),) + else: + if rel_type == Function: + return (Schema(), Function(schema=None, usage="special")) + return (Schema(), rel_type(schema=None)) + + if cmd in ["\\n", "\\ns", "\\nd"]: + return (NamedQuery(),) + + return (Keyword(), Special()) + + +def suggest_based_on_last_token(token, stmt): + + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. 
We need to look "inside" token.tokens to handle
+        # suggestions in complicated where clauses correctly
+        prev_keyword = stmt.reduce_to_prev_keyword()
+        return suggest_based_on_last_token(prev_keyword, stmt)
+    elif isinstance(token, Identifier):
+        # If the previous token is an identifier, we can suggest datatypes if
+        # we're in a parenthesized column/field list, e.g.:
+        #       CREATE TABLE foo (Identifier
+        #       CREATE FUNCTION foo (Identifier
+        # If we're not in a parenthesized list, the most likely scenario is the
+        # user is about to specify an alias, e.g.:
+        #       SELECT Identifier
+        #       SELECT foo FROM Identifier
+        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
+        if prev_keyword and prev_keyword.value == "(":
+            # Suggest datatypes
+            return suggest_based_on_last_token("type", stmt)
+        else:
+            return (Keyword(),)
+    else:
+        token_v = token.value.lower()
+
+    if not token:
+        return (Keyword(), Special())
+    elif token_v.endswith("("):
+        p = sqlparse.parse(stmt.text_before_cursor)[0]
+
+        if p.tokens and isinstance(p.tokens[-1], Where):
+            # Four possibilities:
+            #   1 - Parenthesized clause like "WHERE foo AND ("
+            #       Suggest columns/functions
+            #   2 - Function call like "WHERE foo("
+            #       Suggest columns/functions
+            #   3 - Subquery expression like "WHERE EXISTS ("
+            #       Suggest keywords, in order to do a subquery
+            #   4 - Subquery OR array comparison like "WHERE foo = ANY("
+            #       Suggest columns/functions AND keywords. (If we wanted to be
+            #       really fancy, we could suggest only array-typed columns)
+
+            column_suggestions = suggest_based_on_last_token("where", stmt)
+
+            # Check for a subquery expression (cases 3 & 4)
+            where = p.tokens[-1]
+            prev_tok = where.token_prev(len(where.tokens) - 1)[1]
+
+            if isinstance(prev_tok, Comparison):
+                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+                prev_tok = prev_tok.tokens[-1]
+
+            prev_tok = prev_tok.value.lower()
+            if prev_tok == "exists":
+                return (Keyword(),)
+            else:
+                return column_suggestions
+
+        # Get the token before the parens
+        prev_tok = p.token_prev(len(p.tokens) - 1)[1]
+
+        if (
+            prev_tok
+            and prev_tok.value
+            and prev_tok.value.lower().split(" ")[-1] == "using"
+        ):
+            # tbl1 INNER JOIN tbl2 USING (col1, col2)
+            tables = stmt.get_tables("before")
+
+            # suggest columns that are present in more than one table
+            return (
+                Column(
+                    table_refs=tables,
+                    require_last_table=True,
+                    local_tables=stmt.local_tables,
+                ),
+            )
+
+        elif p.token_first().value.lower() == "select":
+            # If the lparen is preceded by a space, chances are we're about to
+            # do a sub-select.
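+            # Editor's note (hypothetical input): for 'SELECT * FROM (',
+            # last_word(..., "all_punctuations") returns '(', so keywords are
+            # suggested to start the subquery.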
+ if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) + # We're probably in a function argument list + return _suggest_expression(token_v, stmt) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): + return _suggest_expression(token_v, stmt) + elif token_v == "as": + # Don't suggest anything for aliases + return () + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): + + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith("join") and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables("before") + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == "function": + schema = stmt.get_identifier_schema() + + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in ("drop", "alter", "create", "create or replace"): + + # Suggest functions from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + suggest.append(Function(schema=schema, usage="signature")) + return tuple(suggest) + + except ValueError: + pass + return tuple() + + elif token_v in ("table", "view"): + # E.g. 'ALTER TABLE ' + rel_type = {"table": Table, "view": View, "function": Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == "column": + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == "on": + tables = stmt.get_tables("before") + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent." 
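+            #   (editor's examples, hypothetical: 'JOIN bar b ON b.' or
+            #   'JOIN sch.bar ON sch.')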
+ # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) + else: + return (Alias(aliases=aliases),) + + elif token_v in ("c", "use", "database", "template"): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return (Database(),) + elif token_v == "schema": + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == "set" + return (Schema(quoted),) + elif token_v.endswith(",") or token_v in ("=", "and", "or"): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ("type", "::"): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {"alter", "create", "drop"}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def _suggest_expression(token_v, stmt): + """ + Return suggestions for an expression, taking account of any partially-typed + identifier's parent, which may be a table alias or schema name. + """ + parent = stmt.identifier.get_parent_name() if stmt.identifier else [] + tables = stmt.get_tables() + + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) + + return ( + Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()), + ) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. 
+ + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ("on", "and", "or") + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py new file mode 100644 index 0000000..706ed25 --- /dev/null +++ b/pgcli/pgbuffer.py @@ -0,0 +1,50 @@ +import logging + +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app +from .packages.parseutils.utils import is_open_quote + +_logger = logging.getLogger(__name__) + + +def _is_complete(sql): + # A complete command is an sql statement that ends with a semicolon, unless + # there's an open quote surrounding it, as is common when writing a + # CREATE FUNCTION command + return sql.endswith(";") and not is_open_quote(sql) + + +""" +Returns True if the buffer contents should be handled (i.e. the query/command +executed) immediately. This is necessary as we use prompt_toolkit in multiline +mode, which by default will insert new lines on Enter. +""" + + +def buffer_should_be_handled(pgcli): + @Condition + def cond(): + if not pgcli.multi_line: + _logger.debug("Not in multi-line mode. Handle the buffer.") + return True + + if pgcli.multiline_mode == "safe": + _logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.") + return False + + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + text = doc.text.strip() + + return ( + text.startswith("\\") # Special Command + or text.endswith(r"\e") # Special Command + or text.endswith(r"\G") # Ended with \e which should launch the editor + or _is_complete(text) # A complete SQL command + or (text == "exit") # Exit doesn't need semi-colon + or (text == "quit") # Quit doesn't need semi-colon + or (text == ":q") # To all the vim fans out there + or (text == "") # Just a plain enter without any text + ) + + return cond diff --git a/pgcli/pgclirc b/pgcli/pgclirc new file mode 100644 index 0000000..e97afda --- /dev/null +++ b/pgcli/pgclirc @@ -0,0 +1,195 @@ +# vi: ft=dosini +[main] + +# Enables context sensitive auto-completion. If this is disabled, all +# possible completions will be listed. +smart_completion = True + +# Display the completions in several columns. (More completions will be +# visible.) +wider_completion_menu = False + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute +# the current input if the input ends in a semicolon. 
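+# (For example, in "psql" mode typing "SELECT 1;" and pressing [Enter] runs
+# the statement immediately, while "SELECT 1" just inserts a newline.)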
+# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always +# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute +# a command. +multi_line_mode = psql + +# Destructive warning mode will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database" +# or "shutdown". +destructive_warning = True + +# Enables expand mode, which is similar to `\x` in psql. +expand = False + +# Enables auto expand mode, which is similar to `\x auto` in psql. +auto_expand = False + +# If set to True, table suggestions will include a table alias +generate_aliases = False + +# log_file location. +# In Unix/Linux: ~/.config/pgcli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# keyword casing preference. Possible values: "lower", "upper", "auto" +keyword_casing = auto + +# casing_file location. +# In Unix/Linux: ~/.config/pgcli/casing +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing +# %USERPROFILE% is typically C:\Users\{username} +casing_file = default + +# If generate_casing_file is set to True and there is no file in the above +# location, one will be generated based on usage in SQL/PLPGSQL functions. +generate_casing_file = False + +# Casing of column headers based on the casing_file described above +case_column_headers = True + +# history_file location. +# In Unix/Linux: ~/.config/pgcli/history +# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history +# %USERPROFILE% is typically C:\Users\{username} +history_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Order of columns when expanding * to column list +# Possible values: "table_order" and "alphabetic" +asterisk_column_order = table_order + +# Whether to qualify with table alias/name when suggesting columns +# Possible values: "always", "never" and "if_more_than_one_table" +qualify_columns = if_more_than_one_table + +# When no schema is entered, only suggest objects in search_path +search_path_filter = False + +# Default pager. +# By default 'PAGER' environment variable is used +# pager = less -SRXF + +# Timing of sql statements and table rendering. +timing = True + +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + +# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, +# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, +# textile, moinmoin, jira, vertical, tsv, csv. +# Recommended: psql, fancy_grid and grid. +table_format = psql + +# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt, +# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark, +# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity +syntax_style = default + +# Keybindings: +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E +# for end are available in the REPL. +vi = False + +# Error handling +# When one of multiple SQL statements causes an error, choose to either +# continue executing the remaining statements, or stopping +# Possible values "STOP" or "RESUME" +on_error = STOP + +# Set threshold for row limit. Use 0 to disable limiting. 
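+# e.g. "row_limit = 0" turns the limit off entirely (editor's example).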
+row_limit = 1000 + +# Skip intro on startup and goodbye on exit +less_chatty = False + +# Postgres prompt +# \t - Current date and time +# \u - Username +# \h - Short hostname of the server (up to first '.') +# \H - Hostname of the server +# \d - Database name +# \p - Database port +# \i - Postgres PID +# \# - "@" sign if logged in as superuser, '>' in other case +# \n - Newline +# \dsn_alias - name of dsn alias if -D option is used (empty otherwise) +# \x1b[...m - insert ANSI escape sequence +# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' +prompt = '\u@\h:\d> ' + +# Number of lines to reserve for the suggestion menu +min_num_menu_lines = 4 + +# Character used to left pad multi-line queries to match the prompt size. +multiline_continuation_char = '' + +# The string used in place of a null value. +null_string = '' + +# manage pager on startup +enable_pager = True + +# Use keyring to automatically save and load password in a secure manner +keyring = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' +literal.string = '#ba2121' +literal.number = '#666666' +keyword = 'bold #008000' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +output.null = "#808080" + +# Named queries are queries you can execute by name. 
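+# e.g. save a query with the \ns special command and run it by name with \n:
+#   \ns simple SELECT 1
+#   \n simple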
+[named queries] + +# DSN to call by -D option +[alias_dsn] +# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] + +# Format for number representation +# for decimal "d" - 12345678, ",d" - 12,345,678 +# for float "g" - 123456.78, ",g" - 123,456.78 +[data_formats] +decimal = "" +float = "" diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py new file mode 100644 index 0000000..9c95a01 --- /dev/null +++ b/pgcli/pgcompleter.py @@ -0,0 +1,1046 @@ +import logging +import re +from itertools import count, repeat, chain +import operator +from collections import namedtuple, defaultdict, OrderedDict +from cli_helpers.tabular_output import TabularOutputFormatter +from pgspecial.namedqueries import NamedQueries +from prompt_toolkit.completion import Completer, Completion, PathCompleter +from prompt_toolkit.document import Document +from .packages.sqlcompletion import ( + FromClauseItem, + suggest_type, + Special, + Database, + Schema, + Table, + TableFormat, + Function, + Column, + View, + Keyword, + NamedQuery, + Datatype, + Alias, + Path, + JoinCondition, + Join, +) +from .packages.parseutils.meta import ColumnMetadata, ForeignKey +from .packages.parseutils.utils import last_word +from .packages.parseutils.tables import TableReference +from .packages.pgliterals.main import get_literals +from .packages.prioritization import PrevalenceCounter +from .config import load_config, config_location + +_logger = logging.getLogger(__name__) + +Match = namedtuple("Match", ["completion", "priority"]) + +_SchemaObject = namedtuple("SchemaObject", "name schema meta") + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) + + +_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") + +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' + + +def generate_alias(tbl): + """Generate a table alias, consisting of all upper-case letters in + the table name, or, if there are no upper-case letters, the first letter + + all letters preceded by _ + param tbl - unescaped name of the table to alias + """ + return "".join( + [l for l in tbl if l.isupper()] + or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] + ) + + +class PGCompleter(Completer): + # keywords_tree: A dict mapping keywords to well known following keywords. + # e.g. 
'CREATE': ['TABLE', 'USER', ...], + keywords_tree = get_literals("keywords", type_=dict) + keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) + functions = get_literals("functions") + datatypes = get_literals("datatypes") + reserved_words = set(get_literals("reserved")) + + def __init__(self, smart_completion=True, pgspecial=None, settings=None): + super(PGCompleter, self).__init__() + self.smart_completion = smart_completion + self.pgspecial = pgspecial + self.prioritizer = PrevalenceCounter() + settings = settings or {} + self.signature_arg_style = settings.get( + "signature_arg_style", "{arg_name} {arg_type}" + ) + self.call_arg_style = settings.get( + "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" + ) + self.call_arg_display_style = settings.get( + "call_arg_display_style", "{arg_name}" + ) + self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) + self.search_path_filter = settings.get("search_path_filter") + self.generate_aliases = settings.get("generate_aliases") + self.casing_file = settings.get("casing_file") + self.insert_col_skip_patterns = [ + re.compile(pattern) + for pattern in settings.get( + "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] + ) + ] + self.generate_casing_file = settings.get("generate_casing_file") + self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") + self.asterisk_column_order = settings.get( + "asterisk_column_order", "table_order" + ) + + keyword_casing = settings.get("keyword_casing", "upper").lower() + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "upper" + self.keyword_casing = keyword_casing + self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") + + self.databases = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.search_path = [] + self.casing = {} + + self.all_completions = set(self.keywords + self.functions) + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = '"%s"' % name + + return name + + def escape_schema(self, name): + return "'{}'".format(self.unescape_name(name)) + + def unescape_name(self, name): + """ Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schemata): + + # schemata is a list of schema names + schemata = self.escaped_names(schemata) + metadata = self.dbmetadata["tables"] + for schema in schemata: + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + for schema in schemata: + metadata[schema] = {} + + self.all_completions.update(schemata) + + def extend_casing(self, words): + """extend casing data + + :return: + """ + # casing should be a dict {lowercasename:PreferredCasingName} + self.casing = dict((word.lower(), word) for word in words) + + def extend_relations(self, data, kind): + """extend metadata for tables or views. 
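+        e.g. data = [("public", "users"), ("public", "orders")] with
+        kind = "tables" (illustrative values)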
+ + :param data: list of (schema_name, rel_name) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + + data = [self.escaped_names(d) for d in data] + + # dbmetadata['tables']['schema_name']['table_name'] should be an + # OrderedDict {column_name:ColumnMetaData}. + metadata = self.dbmetadata[kind] + for schema, relname in data: + try: + metadata[schema][relname] = OrderedDict() + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", kind, relname, schema + ) + self.all_completions.add(relname) + + def extend_columns(self, column_data, kind): + """extend column metadata. + + :param column_data: list of (schema_name, rel_name, column_name, + column_type, has_default, default) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + metadata = self.dbmetadata[kind] + for schema, relname, colname, datatype, has_default, default in column_data: + (schema, relname, colname) = self.escaped_names([schema, relname, colname]) + column = ColumnMetadata( + name=colname, + datatype=datatype, + has_default=has_default, + default=default, + ) + metadata[schema][relname][colname] = column + self.all_completions.add(colname) + + def extend_functions(self, func_data): + + # func_data is a list of function metadata namedtuples + + # dbmetadata['schema_name']['functions']['function_name'] should return + # the function metadata namedtuple for the corresponding function + metadata = self.dbmetadata["functions"] + + for f in func_data: + schema, func = self.escaped_names([f.schema_name, f.func_name]) + + if func in metadata[schema]: + metadata[schema][func].append(f) + else: + metadata[schema][func] = [f] + + self.all_completions.add(func) + + self._refresh_arg_list_cache() + + def _refresh_arg_list_cache(self): + # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} + # This is used when suggesting functions, to avoid the latency that would result + # if we'd recalculate the arg lists each time we suggest functions (in large DBs) + self._arg_list_cache = { + usage: { + meta: self._arg_list(meta, usage) + for sch, funcs in self.dbmetadata["functions"].items() + for func, metas in funcs.items() + for meta in metas + } + for usage in ("call", "call_display", "signature") + } + + def extend_foreignkeys(self, fk_data): + + # fk_data is a list of ForeignKey namedtuples, with fields + # parentschema, childschema, parenttable, childtable, + # parentcolumns, childcolumns + + # These are added as a list of ForeignKey namedtuples to the + # ColumnMetadata namedtuple for both the child and parent + meta = self.dbmetadata["tables"] + + for fk in fk_data: + e = self.escaped_names + parentschema, childschema = e([fk.parentschema, fk.childschema]) + parenttable, childtable = e([fk.parenttable, fk.childtable]) + childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) + childcolmeta = meta[childschema][childtable][childcol] + parcolmeta = meta[parentschema][parenttable][parcol] + fk = ForeignKey( + parentschema, parenttable, parcol, childschema, childtable, childcol + ) + childcolmeta.foreignkeys.append((fk)) + parcolmeta.foreignkeys.append((fk)) + + def extend_datatypes(self, type_data): + + # dbmetadata['datatypes'][schema_name][type_name] should store type + # metadata, such as composite type field names. 
+        # Currently, we're not storing any metadata beyond typename, so
+        # just store None
+        meta = self.dbmetadata["datatypes"]
+
+        for t in type_data:
+            schema, type_name = self.escaped_names(t)
+            meta[schema][type_name] = None
+            self.all_completions.add(type_name)
+
+    def extend_query_history(self, text, is_init=False):
+        if is_init:
+            # During completer initialization, only load keyword preferences,
+            # not names
+            self.prioritizer.update_keywords(text)
+        else:
+            self.prioritizer.update(text)
+
+    def set_search_path(self, search_path):
+        self.search_path = self.escaped_names(search_path)
+
+    def reset_completions(self):
+        self.databases = []
+        self.special_commands = []
+        self.search_path = []
+        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+        self.all_completions = set(self.keywords + self.functions)
+
+    def find_matches(self, text, collection, mode="fuzzy", meta=None):
+        """Find completion matches for the given text.
+
+        Given the user's input text and a collection of available
+        completions, find completions matching the last word of the
+        text.
+
+        `collection` can be either a list of strings or a list of Candidate
+        namedtuples.
+        `mode` can be either 'fuzzy' or 'strict'
+            'fuzzy': fuzzy matching, ties broken by name prevalence
+            'strict': matches only at the start, ties broken by keyword
+                prevalence
+
+        Yields prompt_toolkit Completion instances for any matches found
+        in the collection of available completions.
+
+        """
+        if not collection:
+            return []
+        prio_order = [
+            "keyword",
+            "function",
+            "view",
+            "table",
+            "datatype",
+            "database",
+            "schema",
+            "column",
+            "table alias",
+            "join",
+            "name join",
+            "fk join",
+            "table format",
+        ]
+        type_priority = prio_order.index(meta) if meta in prio_order else -1
+        text = last_word(text, include="most_punctuations").lower()
+        text_len = len(text)
+
+        if text and text[0] == '"':
+            # text starts with double quote; user is manually escaping a name
+            # Match on everything that follows the double-quote. Note that
+            # text_len is calculated before removing the quote, so the
+            # Completion.position value is correct
+            text = text[1:]
+
+        if mode == "fuzzy":
+            fuzzy = True
+            priority_func = self.prioritizer.name_count
+        else:
+            fuzzy = False
+            priority_func = self.prioritizer.keyword_count
+
+        # Construct a `_match` function for either fuzzy or non-fuzzy matching
+        # The match function returns a 2-tuple used for sorting the matches,
+        # or None if the item doesn't match
+        # Note: higher priority values mean more important, so use negative
+        # signs to flip the direction of the tuple
+        if fuzzy:
+            regex = ".*?".join(map(re.escape, text))
+            pat = re.compile("(%s)" % regex)
+
+            def _match(item):
+                if item.lower()[: len(text) + 1] in (text, text + " "):
+                    # Exact match of first word in suggestion
+                    # This is to get exact alias matches to the top
+                    # E.g. for input `e`, 'Entries E' should be on top
+                    # (before e.g.
`EndUsers EU`) + return float("Infinity"), -1 + r = pat.search(self.unescape_name(item.lower())) + if r: + return -len(r.group()), -r.start() + + else: + match_end_limit = len(text) + + def _match(item): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + # Use negative infinity to force keywords to sort after all + # fuzzy matches + return -float("Infinity"), -match_point + + matches = [] + for cand in collection: + if isinstance(cand, _Candidate): + item, prio, display_meta, synonyms, prio2, display = cand + if display_meta is None: + display_meta = meta + syn_matches = (_match(x) for x in synonyms) + # Nones need to be removed to avoid max() crashing in Python 3 + syn_matches = [m for m in syn_matches if m] + sort_key = max(syn_matches) if syn_matches else None + else: + item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand + sort_key = _match(cand) + + if sort_key: + if display_meta and len(display_meta) > 50: + # Truncate meta-text to 50 characters, if necessary + display_meta = display_meta[:47] + "..." + + # Lexical order of items in the collection, used for + # tiebreaking items with the same match group length and start + # position. Since we use *higher* priority to mean "more + # important," we use -ord(c) to prioritize "aa" > "ab" and end + # with 1 to prioritize shorter strings (ie "user" > "users"). + # We first do a case-insensitive sort and then a + # case-sensitive one as a tie breaker. + # We also use the unescape_name to make sure quoted names have + # the same priority as unquoted names. + lexical_priority = ( + tuple( + 0 if c in (" _") else -ord(c) + for c in self.unescape_name(item.lower()) + ) + + (1,) + + tuple(c for c in item) + ) + + item = self.case(item) + display = self.case(display) + priority = ( + sort_key, + type_priority, + prio, + priority_func(item), + prio2, + lexical_priority, + ) + matches.append( + Match( + completion=Completion( + text=item, + start_position=-text_len, + display_meta=display_meta, + display=display, + ), + priority=priority, + ) + ) + return matches + + def case(self, word): + return self.casing.get(word, word) + + def get_completions(self, document, complete_event, smart_completion=None): + word_before_cursor = document.get_word_before_cursor(WORD=True) + + if smart_completion is None: + smart_completion = self.smart_completion + + # If smart_completion is off then match any word that starts with + # 'word_before_cursor'. + if not smart_completion: + matches = self.find_matches( + word_before_cursor, self.all_completions, mode="strict" + ) + completions = [m.completion for m in matches] + return sorted(completions, key=operator.attrgetter("text")) + + matches = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + suggestion_type = type(suggestion) + _logger.debug("Suggestion type: %r", suggestion_type) + + # Map suggestion type to method + # e.g. 
'table' -> self.get_table_matches + matcher = self.suggestion_matchers[suggestion_type] + matches.extend(matcher(self, suggestion, word_before_cursor)) + + # Sort matches so highest priorities are first + matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True) + + return [m.completion for m in matches] + + def get_column_matches(self, suggestion, word_before_cursor): + tables = suggestion.table_refs + do_qualify = suggestion.qualifiable and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + qualify = lambda col, tbl: ( + (tbl + "." + self.case(col)) if do_qualify else self.case(col) + ) + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + + def make_cand(name, ref): + synonyms = (name, generate_alias(self.case(name))) + return Candidate(qualify(name, ref), 0, "column", synonyms) + + def flat_cols(): + return [ + make_cand(c.name, t.ref) + for t, cols in scoped_cols.items() + for c in cols + ] + + if suggestion.require_last_table: + # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should + # suggest only columns that appear in the last table and one more + ltbl = tables[-1].ref + other_tbl_cols = set( + c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs + ) + scoped_cols = { + t: [col for col in cols if col.name in other_tbl_cols] + for t, cols in scoped_cols.items() + if t.ref == ltbl + } + lastword = last_word(word_before_cursor, include="most_punctuations") + if lastword == "*": + if suggestion.context == "insert": + + def filter(col): + if not col.has_default: + return True + return not any( + p.match(col.default) for p in self.insert_col_skip_patterns + ) + + scoped_cols = { + t: [col for col in cols if filter(col)] + for t, cols in scoped_cols.items() + } + if self.asterisk_column_order == "alphabetic": + for cols in scoped_cols.values(): + cols.sort(key=operator.attrgetter("name")) + if ( + lastword != word_before_cursor + and len(tables) == 1 + and word_before_cursor[-len(lastword) - 1] == "." + ): + # User typed x.*; replicate "x." 
for all columns except the + # first, which gets the original (as we only replace the "*"") + sep = ", " + word_before_cursor[:-1] + collist = sep.join(self.case(c.completion) for c in flat_cols()) + else: + collist = ", ".join( + qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs + ) + + return [ + Match( + completion=Completion( + collist, -1, display_meta="columns", display="*" + ), + priority=(1, 1, 1), + ) + ] + + return self.find_matches(word_before_cursor, flat_cols(), meta="column") + + def alias(self, tbl, tbls): + """Generate a unique table alias + tbl - name of the table to alias, quoted if it needs to be + tbls - TableReference iterable of tables already in query + """ + tbl = self.case(tbl) + tbls = set(normalize_ref(t.ref) for t in tbls) + if self.generate_aliases: + tbl = generate_alias(self.unescape_name(tbl)) + if normalize_ref(tbl) not in tbls: + return tbl + elif tbl[0] == '"': + aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2)) + else: + aliases = (tbl + str(i) for i in count(2)) + return next(a for a in aliases if normalize_ref(a) not in tbls) + + def get_join_matches(self, suggestion, word_before_cursor): + tbls = suggestion.table_refs + cols = self.populate_scoped_cols(tbls) + # Set up some data structures for efficient access + qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) + ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) + refs = set(normalize_ref(t.ref) for t in tbls) + other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) + joins = [] + # Iterate over FKs in existing tables to find potential joins + fks = ( + (fk, rtbl, rcol) + for rtbl, rcols in cols.items() + for rcol in rcols + for fk in rcol.foreignkeys + ) + col = namedtuple("col", "schema tbl col") + for fk, rtbl, rcol in fks: + right = col(rtbl.schema, rtbl.name, rcol.name) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left = child if parent == right else parent + if suggestion.schema and left.schema != suggestion.schema: + continue + c = self.case + if self.generate_aliases or normalize_ref(left.tbl) in refs: + lref = self.alias(left.tbl, suggestion.table_refs) + join = "{0} {4} ON {4}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref + ) + else: + join = "{0} ON {0}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col) + ) + alias = generate_alias(self.case(left.tbl)) + synonyms = [ + join, + "{0} ON {0}.{1} = {2}.{3}".format( + alias, c(left.col), rtbl.ref, c(right.col) + ), + ] + # Schema-qualify if (1) new table in same schema as old, and old + # is schema-qualified, or (2) new in other schema, except public + if not suggestion.schema and ( + qualified[normalize_ref(rtbl.ref)] + and left.schema == right.schema + or left.schema not in (right.schema, "public") + ): + join = left.schema + "." 
+ join + prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( + 0 if (left.schema, left.tbl) in other_tbls else 1 + ) + joins.append(Candidate(join, prio, "join", synonyms=synonyms)) + + return self.find_matches(word_before_cursor, joins, meta="join") + + def get_join_condition_matches(self, suggestion, word_before_cursor): + col = namedtuple("col", "schema tbl col") + tbls = self.populate_scoped_cols(suggestion.table_refs).items + cols = [(t, c) for t, cs in tbls() for c in cs] + try: + lref = (suggestion.parent or suggestion.table_refs[-1]).ref + ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] + except IndexError: # The user typed an incorrect table qualifier + return [] + conds, found_conds = [], set() + + def add_cond(lcol, rcol, rref, prio, meta): + prefix = "" if suggestion.parent else ltbl.ref + "." + case = self.case + cond = prefix + case(lcol) + " = " + rref + "." + case(rcol) + if cond not in found_conds: + found_conds.add(cond) + conds.append(Candidate(cond, prio + ref_prio[rref], meta)) + + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + d = defaultdict(list) + for pair in pairs: + d[pair[0]].append(pair[1]) + return d + + # Tables that are closer to the cursor get higher prio + ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) + # Map (schema, table, col) to tables + coldict = list_dict( + ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref + ) + # For each fk from the left table, generate a join condition if + # the other table is also in the scope + fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) + for fk, lcol in fks: + left = col(ltbl.schema, ltbl.name, lcol) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left, right = (child, par) if left == child else (par, child) + for rtbl in coldict[right]: + add_cond(left.col, right.col, rtbl.ref, 2000, "fk join") + # For name matching, use a {(colname, coltype): TableReference} dict + coltyp = namedtuple("coltyp", "name datatype") + col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) + # Find all name-match join conditions + for c in (coltyp(c.name, c.datatype) for c in lcols): + for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): + prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0 + add_cond(c.name, c.name, rtbl.ref, prio, "name join") + + return self.find_matches(word_before_cursor, conds, meta="join") + + def get_function_matches(self, suggestion, word_before_cursor, alias=False): + + if suggestion.usage == "from": + # Only suggest functions allowed in FROM clause + + def filt(f): + return ( + not f.is_aggregate + and not f.is_window + and not f.is_extension + and (f.is_public or f.schema_name == suggestion.schema) + ) + + else: + alias = False + + def filt(f): + return not f.is_extension and ( + f.is_public or f.schema_name == suggestion.schema + ) + + arg_mode = {"signature": "signature", "special": None}.get( + suggestion.usage, "call" + ) + + # Function overloading means we way have multiple functions of the same + # name at this point, so keep unique names only + all_functions = self.populate_functions(suggestion.schema, filt) + funcs = set( + self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions + ) + + matches = self.find_matches(word_before_cursor, funcs, meta="function") + + if not suggestion.schema and not suggestion.usage: + # also suggest hardcoded functions using startswith 
+            # matching
+            predefined_funcs = self.find_matches(
+                word_before_cursor, self.functions, mode="strict", meta="function"
+            )
+            matches.extend(predefined_funcs)
+
+        return matches
+
+    def get_schema_matches(self, suggestion, word_before_cursor):
+        schema_names = self.dbmetadata["tables"].keys()
+
+        # Unless we're sure the user really wants them, hide schema names
+        # starting with pg_, which are mostly temporary schemas
+        if not word_before_cursor.startswith("pg_"):
+            schema_names = [s for s in schema_names if not s.startswith("pg_")]
+
+        if suggestion.quoted:
+            schema_names = [self.escape_schema(s) for s in schema_names]
+
+        return self.find_matches(word_before_cursor, schema_names, meta="schema")
+
+    def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+        alias = self.generate_aliases
+        s = suggestion
+        t_sug = Table(s.schema, s.table_refs, s.local_tables)
+        v_sug = View(s.schema, s.table_refs)
+        f_sug = Function(s.schema, s.table_refs, usage="from")
+        return (
+            self.get_table_matches(t_sug, word_before_cursor, alias)
+            + self.get_view_matches(v_sug, word_before_cursor, alias)
+            + self.get_function_matches(f_sug, word_before_cursor, alias)
+        )
+
+    def _arg_list(self, func, usage):
+        """Returns an arg list string, e.g. `(_foo:=23)`, for a function.
+
+        :param func is a FunctionMetadata object
+        :param usage is 'call', 'call_display' or 'signature'
+
+        """
+        template = {
+            "call": self.call_arg_style,
+            "call_display": self.call_arg_display_style,
+            "signature": self.signature_arg_style,
+        }[usage]
+        args = func.args()
+        if not template:
+            return "()"
+        elif usage == "call" and len(args) < 2:
+            return "()"
+        elif usage == "call" and func.has_variadic():
+            return "()"
+        multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
+        max_arg_len = max(len(a.name) for a in args) if multiline else 0
+        args = (
+            self._format_arg(template, arg, arg_num + 1, max_arg_len)
+            for arg_num, arg in enumerate(args)
+        )
+        if multiline:
+            return "(" + ",".join("\n    " + a for a in args if a) + "\n)"
+        else:
+            return "(" + ", ".join(a for a in args if a) + ")"
+
+    def _format_arg(self, template, arg, arg_num, max_arg_len):
+        if not template:
+            return None
+        if arg.has_default:
+            arg_default = "NULL" if arg.default is None else arg.default
+            # Remove trailing ::(schema.)type
+            arg_default = arg_default_type_strip_regex.sub("", arg_default)
+        else:
+            arg_default = ""
+        return template.format(
+            max_arg_len=max_arg_len,
+            arg_name=self.case(arg.name),
+            arg_num=arg_num,
+            arg_type=arg.datatype,
+            arg_default=arg_default,
+        )
+
+    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
+        """Returns a Candidate namedtuple.
+
+        :param tbl is a SchemaObject
+        :param arg_mode determines what type of arg list to suffix for functions.
+ Possible values: call, signature + + """ + cased_tbl = self.case(tbl.name) + if do_alias: + alias = self.alias(cased_tbl, suggestion.table_refs) + synonyms = (cased_tbl, generate_alias(cased_tbl)) + maybe_alias = (" " + alias) if do_alias else "" + maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" + if arg_mode == "call": + display_suffix = self._arg_list_cache["call_display"][tbl.meta] + elif arg_mode == "signature": + display_suffix = self._arg_list_cache["signature"][tbl.meta] + else: + display_suffix = "" + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias + prio2 = 0 if tbl.schema else 1 + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) + + def get_table_matches(self, suggestion, word_before_cursor, alias=False): + tables = self.populate_schema_objects(suggestion.schema, "tables") + tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) + + # Unless we're sure the user really wants them, don't suggest the + # pg_catalog tables that are implicitly on the search path + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + tables = [t for t in tables if not t.name.startswith("pg_")] + tables = [self._make_cand(t, alias, suggestion) for t in tables] + return self.find_matches(word_before_cursor, tables, meta="table") + + def get_table_formats(self, _, word_before_cursor): + formats = TabularOutputFormatter().supported_formats + return self.find_matches(word_before_cursor, formats, meta="table format") + + def get_view_matches(self, suggestion, word_before_cursor, alias=False): + views = self.populate_schema_objects(suggestion.schema, "views") + + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + views = [v for v in views if not v.name.startswith("pg_")] + views = [self._make_cand(v, alias, suggestion) for v in views] + return self.find_matches(word_before_cursor, views, meta="view") + + def get_alias_matches(self, suggestion, word_before_cursor): + aliases = suggestion.aliases + return self.find_matches(word_before_cursor, aliases, meta="table alias") + + def get_database_matches(self, _, word_before_cursor): + return self.find_matches(word_before_cursor, self.databases, meta="database") + + def get_keyword_matches(self, suggestion, word_before_cursor): + keywords = self.keywords_tree.keys() + # Get well known following keywords for the last token. If any, narrow + # candidates to this list. 
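+        # e.g. a last_token of "ORDER" would narrow the candidates to just
+        # "BY" (assuming the keywords_tree data in pgliterals.json maps
+        # "ORDER" -> ["BY"])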
+ next_keywords = self.keywords_tree.get(suggestion.last_token, []) + if next_keywords: + keywords = next_keywords + + casing = self.keyword_casing + if casing == "auto": + if word_before_cursor and word_before_cursor[-1].islower(): + casing = "lower" + else: + casing = "upper" + + if casing == "upper": + keywords = [k.upper() for k in keywords] + else: + keywords = [k.lower() for k in keywords] + + return self.find_matches( + word_before_cursor, keywords, mode="strict", meta="keyword" + ) + + def get_path_matches(self, _, word_before_cursor): + completer = PathCompleter(expanduser=True) + document = Document( + text=word_before_cursor, cursor_position=len(word_before_cursor) + ) + for c in completer.get_completions(document, None): + yield Match(completion=c, priority=(0,)) + + def get_special_matches(self, _, word_before_cursor): + if not self.pgspecial: + return [] + + commands = self.pgspecial.commands + cmds = commands.keys() + cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] + return self.find_matches(word_before_cursor, cmds, mode="strict") + + def get_datatype_matches(self, suggestion, word_before_cursor): + # suggest custom datatypes + types = self.populate_schema_objects(suggestion.schema, "datatypes") + types = [self._make_cand(t, False, suggestion) for t in types] + matches = self.find_matches(word_before_cursor, types, meta="datatype") + + if not suggestion.schema: + # Also suggest hardcoded types + matches.extend( + self.find_matches( + word_before_cursor, self.datatypes, mode="strict", meta="datatype" + ) + ) + + return matches + + def get_namedquery_matches(self, _, word_before_cursor): + return self.find_matches( + word_before_cursor, NamedQueries.instance.list(), meta="named query" + ) + + suggestion_matchers = { + FromClauseItem: get_from_clause_item_matches, + JoinCondition: get_join_condition_matches, + Join: get_join_matches, + Column: get_column_matches, + Function: get_function_matches, + Schema: get_schema_matches, + Table: get_table_matches, + TableFormat: get_table_formats, + View: get_view_matches, + Alias: get_alias_matches, + Database: get_database_matches, + Keyword: get_keyword_matches, + Special: get_special_matches, + Datatype: get_datatype_matches, + NamedQuery: get_namedquery_matches, + Path: get_path_matches, + } + + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): + """Find all columns in a set of scoped_tables. 
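+        e.g. for "SELECT * FROM users u JOIN orders o" the scoped tables
+        are users (ref u) and orders (ref o) (illustrative example)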
+ + :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) + :return: {TableReference:{colname:ColumnMetaData}} + + """ + ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) + columns = OrderedDict() + meta = self.dbmetadata + + def addcols(schema, rel, alias, reltype, cols): + tbl = TableReference(schema, rel, alias, reltype == "functions") + if tbl not in columns: + columns[tbl] = [] + columns[tbl].extend(cols) + + for tbl in scoped_tbls: + # Local tables should shadow database tables + if tbl.schema is None and normalize_ref(tbl.name) in ctes: + cols = ctes[normalize_ref(tbl.name)] + addcols(None, tbl.name, "CTE", tbl.alias, cols) + continue + schemas = [tbl.schema] if tbl.schema else self.search_path + for schema in schemas: + relname = self.escape_name(tbl.name) + schema = self.escape_name(schema) + if tbl.is_function: + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta["functions"].get(schema, {}).get(relname) + for func in functions or []: + # func is a FunctionMetadata object + cols = func.fields() + addcols(schema, relname, tbl.alias, "functions", cols) + else: + for reltype in ("tables", "views"): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + cols = cols.values() + addcols(schema, relname, tbl.alias, reltype, cols) + break + + return columns + + def _get_schemas(self, obj_typ, schema): + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + + """ + metadata = self.dbmetadata[obj_typ] + if schema: + schema = self.escape_name(schema) + return [schema] if schema in metadata else [] + return self.search_path if self.search_path_filter else metadata.keys() + + def _maybe_schema(self, schema, parent): + return None if parent or schema in self.search_path else schema + + def populate_schema_objects(self, schema, obj_type): + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + + """ + + return [ + SchemaObject( + name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) + ) + for sch in self._get_schemas(obj_type, schema) + for obj in self.dbmetadata[obj_type][sch].keys() + ] + + def populate_functions(self, schema, filter_func): + """Returns a list of function SchemaObjects. 
+ + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded + + """ + + # Because of multiple dispatch, we can have multiple functions + # with the same name, which is why `for meta in metas` is necessary + # in the comprehensions below + return [ + SchemaObject( + name=func, + schema=(self._maybe_schema(schema=sch, parent=schema)), + meta=meta, + ) + for sch in self._get_schemas("functions", schema) + for (func, metas) in self.dbmetadata["functions"][sch].items() + for meta in metas + if filter_func(meta) + ] diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py new file mode 100644 index 0000000..d34bf26 --- /dev/null +++ b/pgcli/pgexecute.py @@ -0,0 +1,857 @@ +import traceback +import logging +import psycopg2 +import psycopg2.extras +import psycopg2.errorcodes +import psycopg2.extensions as ext +import sqlparse +import pgspecial as special +import select +from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn +from .packages.parseutils.meta import FunctionMetadata, ForeignKey + +_logger = logging.getLogger(__name__) + +# Cast all database input to unicode automatically. +# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info. +ext.register_type(ext.UNICODE) +ext.register_type(ext.UNICODEARRAY) +ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE)) +# See https://github.com/dbcli/pgcli/issues/426 for more details. +# This registers a unicode type caster for datatype 'RECORD'. +ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE)) + +# Cast bytea fields to text. By default, this will render as hex strings with +# Postgres 9+ and as escaped binary in earlier versions. +ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) + +# TODO: Get default timeout from pgclirc? +_WAIT_SELECT_TIMEOUT = 1 + + +def _wait_select(conn): + """ + copy-pasted from psycopg2.extras.wait_select + the default implementation doesn't define a timeout in the select calls + """ + while 1: + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + except select.error as e: + errno = e.args[0] + if errno != 4: + raise + + +# When running a query, make pressing CTRL+C raise a KeyboardInterrupt +# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ +# See also https://github.com/psycopg/psycopg2/issues/468 +ext.set_wait_callback(_wait_select) + + +def register_date_typecasters(connection): + """ + Casts date and timestamp values to string, resolves issues with out of + range dates (e.g. 
BC) which psycopg2 can't handle + """ + + def cast_date(value, cursor): + return value + + cursor = connection.cursor() + cursor.execute("SELECT NULL::date") + date_oid = cursor.description[0][1] + cursor.execute("SELECT NULL::timestamp") + timestamp_oid = cursor.description[0][1] + cursor.execute("SELECT NULL::timestamp with time zone") + timestamptz_oid = cursor.description[0][1] + oids = (date_oid, timestamp_oid, timestamptz_oid) + new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date) + psycopg2.extensions.register_type(new_type) + + +def register_json_typecasters(conn, loads_fn): + """Set the function for converting JSON data for a connection. + + Use the supplied function to decode JSON data returned from the database + via the given connection. The function should accept a single argument of + the data as a string encoded in the database's character encoding. + psycopg2's default handler for JSON data is json.loads. + http://initd.org/psycopg/docs/extras.html#json-adaptation + + This function attempts to register the typecaster for both JSON and JSONB + types. + + Returns a set that is a subset of {'json', 'jsonb'} indicating which types + (if any) were successfully registered. + """ + available = set() + + for name in ["json", "jsonb"]: + try: + psycopg2.extras.register_json(conn, loads=loads_fn, name=name) + available.add(name) + except psycopg2.ProgrammingError: + pass + + return available + + +def register_hstore_typecaster(conn): + """ + Instead of using register_hstore() which converts hstore into a python + dict, we query the 'oid' of hstore which will be different for each + database and register a type caster that converts it to unicode. + http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore + """ + with conn.cursor() as cur: + try: + cur.execute( + "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined" + ) + oid = cur.fetchone()[0] + ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) + except Exception: + pass + + +class PGExecute(object): + + # The boolean argument to the current_schemas function indicates whether + # implicit schemas, e.g. 
pg_catalog + search_path_query = """ + SELECT * FROM unnest(current_schemas(true))""" + + schemata_query = """ + SELECT nspname + FROM pg_catalog.pg_namespace + ORDER BY 1 """ + + tables_query = """ + SELECT n.nspname schema_name, + c.relname table_name + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = c.relnamespace + WHERE c.relkind = ANY(%s) + ORDER BY 1,2;""" + + databases_query = """ + SELECT d.datname + FROM pg_catalog.pg_database d + ORDER BY 1""" + + full_databases_query = """ + SELECT d.datname as "Name", + pg_catalog.pg_get_userbyid(d.datdba) as "Owner", + pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding", + d.datcollate as "Collate", + d.datctype as "Ctype", + pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges" + FROM pg_catalog.pg_database d + ORDER BY 1""" + + socket_directory_query = """ + SELECT setting + FROM pg_settings + WHERE name = 'unix_socket_directories' + """ + + view_definition_query = """ + WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid) + SELECT nspname, relname, relkind, + pg_catalog.pg_get_viewdef(c.oid, true), + array_remove(array_remove(c.reloptions,'check_option=local'), + 'check_option=cascaded') AS reloptions, + CASE + WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text + WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text + ELSE NULL + END AS checkoption + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid) + JOIN v ON (c.oid = v.v_oid)""" + + function_definition_query = """ + WITH f AS + (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid) + SELECT pg_catalog.pg_get_functiondef(f.f_oid) + FROM f""" + + version_query = "SELECT version();" + + def __init__( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + self._conn_params = {} + self.conn = None + self.dbname = None + self.user = None + self.password = None + self.host = None + self.port = None + self.server_version = None + self.extra_args = None + self.connect(database, user, password, host, port, dsn, **kwargs) + self.reset_expanded = None + + def copy(self): + """Returns a clone of the current executor.""" + return self.__class__(**self._conn_params) + + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + dsn=None, + **kwargs, + ): + + conn_params = self._conn_params.copy() + + new_params = { + "database": database, + "user": user, + "password": password, + "host": host, + "port": port, + "dsn": dsn, + } + new_params.update(kwargs) + + if new_params["dsn"]: + new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} + + if new_params["password"]: + new_params["dsn"] = make_dsn( + new_params["dsn"], password=new_params.pop("password") + ) + + conn_params.update({k: v for k, v in new_params.items() if v}) + + conn = psycopg2.connect(**conn_params) + cursor = conn.cursor() + conn.set_client_encoding("utf8") + + self._conn_params = conn_params + if self.conn: + self.conn.close() + self.conn = conn + self.conn.autocommit = True + + # When we connect using a DSN, we don't really know what db, + # user, etc. we connected to. Let's read it. + # Note: moved this after setting autocommit because of #664. 
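+        # psycopg2.__libpq_version__ encodes the libpq version as an integer,
+        # e.g. 90603 for libpq 9.6.3 and 110002 for 11.2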
+ libpq_version = psycopg2.__libpq_version__ + dsn_parameters = {} + if libpq_version >= 93000: + # use actual connection info from psycopg2.extensions.Connection.info + # as libpq_version > 9.3 is available and required dependency + dsn_parameters = conn.info.dsn_parameters + else: + try: + dsn_parameters = conn.get_dsn_parameters() + except Exception as x: + # https://github.com/dbcli/pgcli/issues/1110 + # PQconninfo not available in libpq < 9.3 + _logger.info("Exception in get_dsn_parameters: %r", x) + + if dsn_parameters: + self.dbname = dsn_parameters.get("dbname") + self.user = dsn_parameters.get("user") + self.host = dsn_parameters.get("host") + self.port = dsn_parameters.get("port") + else: + self.dbname = conn_params.get("database") + self.user = conn_params.get("user") + self.host = conn_params.get("host") + self.port = conn_params.get("port") + + self.password = password + self.extra_args = kwargs + + if not self.host: + self.host = self.get_socket_directory() + + pid = self._select_one(cursor, "select pg_backend_pid()")[0] + self.pid = pid + self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") + self.server_version = conn.get_parameter_status("server_version") + + register_date_typecasters(conn) + register_json_typecasters(self.conn, self._json_typecaster) + register_hstore_typecaster(self.conn) + + @property + def short_host(self): + if "," in self.host: + host, _, _ = self.host.partition(",") + else: + host = self.host + short_host, _, _ = host.partition(".") + return short_host + + def _select_one(self, cur, sql): + """ + Helper method to run a select and retrieve a single field value + :param cur: cursor + :param sql: string + :return: string + """ + cur.execute(sql) + return cur.fetchone() + + def _json_typecaster(self, json_data): + """Interpret incoming JSON data as a string. + + The raw data is decoded using the connection's encoding, which defaults + to the database's encoding. + + See http://initd.org/psycopg/docs/connection.html#connection.encoding + """ + + return json_data + + def failed_transaction(self): + status = self.conn.get_transaction_status() + return status == ext.TRANSACTION_STATUS_INERROR + + def valid_transaction(self): + status = self.conn.get_transaction_status() + return ( + status == ext.TRANSACTION_STATUS_ACTIVE + or status == ext.TRANSACTION_STATUS_INTRANS + ) + + def run( + self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False + ): + """Execute the sql in the database and return the results. + + :param statement: A string containing one or more sql statements + :param pgspecial: PGSpecial object + :param exception_formatter: A callable that accepts an Exception and + returns a formatted (title, rows, headers, status) tuple that can + act as a query result. If an exception_formatter is not supplied, + psycopg2 exceptions are always raised. + :param on_error_resume: Bool. If true, queries following an exception + (assuming exception_formatter has been supplied) continue to + execute. + + :return: Generator yielding tuples containing + (title, rows, headers, status, query, success, is_special) + """ + + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None, statement, False, False) + + # Split the sql into separate queries and run each one. + for sql in sqlparse.split(statement): + # Remove spaces, eol and semi-colons. 
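+            # (sqlparse.split() keeps the trailing semicolon, so it is
+            # stripped here before trying the pgspecial matching below)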
+            sql = sql.rstrip(";")
+            sql = sqlparse.format(sql, strip_comments=True).strip()
+            if not sql:
+                continue
+            try:
+                if pgspecial:
+                    # \G is treated specially since we have to set the expanded output.
+                    if sql.endswith("\\G"):
+                        if not pgspecial.expanded_output:
+                            pgspecial.expanded_output = True
+                            self.reset_expanded = True
+                        sql = sql[:-2].strip()
+
+                    # First try to run each query as special
+                    _logger.debug("Trying a pgspecial command. sql: %r", sql)
+                    try:
+                        cur = self.conn.cursor()
+                    except psycopg2.InterfaceError:
+                        # edge case when connection is already closed, but we
+                        # don't need cursor for special_cmd.arg_type == NO_QUERY.
+                        # See https://github.com/dbcli/pgcli/issues/1014.
+                        cur = None
+                    try:
+                        for result in pgspecial.execute(cur, sql):
+                            # e.g. execute_from_file already appends these
+                            if len(result) < 7:
+                                yield result + (sql, True, True)
+                            else:
+                                yield result
+                        continue
+                    except special.CommandNotFound:
+                        pass
+
+                # Not a special command, so execute as normal sql
+                yield self.execute_normal_sql(sql) + (sql, True, False)
+            except psycopg2.DatabaseError as e:
+                _logger.error("sql: %r, error: %r", sql, e)
+                _logger.error("traceback: %r", traceback.format_exc())
+
+                if self._must_raise(e) or not exception_formatter:
+                    raise
+
+                yield None, None, None, exception_formatter(e), sql, False, False
+
+                if not on_error_resume:
+                    break
+            finally:
+                if self.reset_expanded:
+                    pgspecial.expanded_output = False
+                    self.reset_expanded = None
+
+    def _must_raise(self, e):
+        """Return true if e is an error that should not be caught in ``run``.
+
+        An uncaught error will prompt the user to reconnect; as long as we
+        detect that the connection is still open, we catch the error, as
+        reconnecting won't solve that problem.
+
+        :param e: DatabaseError. An exception raised while executing a query.
+
+        :return: Bool. True if ``run`` must raise this exception.
+
+        """
+        return self.conn.closed != 0
+
+    def execute_normal_sql(self, split_sql):
+        """Returns tuple (title, rows, headers, status)"""
+        _logger.debug("Regular sql statement. sql: %r", split_sql)
+        cur = self.conn.cursor()
+        cur.execute(split_sql)
+
+        # conn.notices persist between queries, we use pop to clear out the list
+        title = ""
+        while len(self.conn.notices) > 0:
+            title = self.conn.notices.pop() + title
+
+        # cur.description will be None for operations that do not return
+        # rows.
+        if cur.description:
+            headers = [x[0] for x in cur.description]
+            return title, cur, headers, cur.statusmessage
+        else:
+            _logger.debug("No rows in result.")
+            return title, None, None, cur.statusmessage
+
+    def search_path(self):
+        """Returns the current search path as a list of schema names"""
+
+        try:
+            with self.conn.cursor() as cur:
+                _logger.debug("Search path query. sql: %r", self.search_path_query)
+                cur.execute(self.search_path_query)
+                return [x[0] for x in cur.fetchall()]
+        except psycopg2.ProgrammingError:
+            fallback = "SELECT * FROM current_schemas(true)"
+            with self.conn.cursor() as cur:
+                _logger.debug("Search path query. sql: %r", fallback)
+                cur.execute(fallback)
+                return cur.fetchone()[0]
+
+    def view_definition(self, spec):
+        """Returns the SQL defining views described by `spec`"""
+
+        template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
+        # 2: relkind, v or m (materialized)
+        # 4: reloptions, null
+        # 5: checkoption: local or cascaded
+        with self.conn.cursor() as cur:
+            sql = self.view_definition_query
+            _logger.debug("View Definition Query.
sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + except psycopg2.ProgrammingError: + raise RuntimeError("View {} does not exist.".format(spec)) + result = cur.fetchone() + view_type = "MATERIALIZED" if result[2] == "m" else "" + return template.format(*result + (view_type,)) + + def function_definition(self, spec): + """Returns the SQL defining functions described by `spec`""" + + with self.conn.cursor() as cur: + sql = self.function_definition_query + _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec) + try: + cur.execute(sql, (spec,)) + result = cur.fetchone() + return result[0] + except psycopg2.ProgrammingError: + raise RuntimeError("Function {} does not exist.".format(spec)) + + def schemata(self): + """Returns a list of schema names in the database""" + + with self.conn.cursor() as cur: + _logger.debug("Schemata Query. sql: %r", self.schemata_query) + cur.execute(self.schemata_query) + return [x[0] for x in cur.fetchall()] + + def _relations(self, kinds=("r", "p", "f", "v", "m")): + """Get table or view name metadata + + :param kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: (schema_name, rel_name) tuples + """ + + with self.conn.cursor() as cur: + sql = cur.mogrify(self.tables_query, [kinds]) + _logger.debug("Tables Query. sql: %r", sql) + cur.execute(sql) + for row in cur: + yield row + + def tables(self): + """Yields (schema_name, table_name) tuples""" + for row in self._relations(kinds=["r", "p", "f"]): + yield row + + def views(self): + """Yields (schema_name, view_name) tuples. + + Includes both views and and materialized views + """ + for row in self._relations(kinds=["v", "m"]): + yield row + + def _columns(self, kinds=("r", "p", "f", "v", "m")): + """Get column metadata for tables and views + + :param kinds: kinds: list of postgres relkind filters: + 'r' - table + 'p' - partitioned table + 'f' - foreign table + 'v' - view + 'm' - materialized view + :return: list of (schema_name, relation_name, column_name, column_type) tuples + """ + + if self.conn.server_version >= 80400: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + att.atttypid::regtype::text type_name, + att.atthasdef AS has_default, + pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + LEFT OUTER JOIN pg_attrdef def + ON def.adrelid = att.attrelid + AND def.adnum = att.attnum + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + else: + columns_query = """ + SELECT nsp.nspname schema_name, + cls.relname table_name, + att.attname column_name, + typ.typname type_name, + NULL AS has_default, + NULL AS default + FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + INNER JOIN pg_catalog.pg_type typ + ON typ.oid = att.atttypid + WHERE cls.relkind = ANY(%s) + AND NOT att.attisdropped + AND att.attnum > 0 + ORDER BY 1, 2, att.attnum""" + + with self.conn.cursor() as cur: + sql = cur.mogrify(columns_query, [kinds]) + _logger.debug("Columns Query. 
sql: %r", sql) + cur.execute(sql) + for row in cur: + yield row + + def table_columns(self): + for row in self._columns(kinds=["r", "p", "f"]): + yield row + + def view_columns(self): + for row in self._columns(kinds=["v", "m"]): + yield row + + def databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + cur.execute(self.databases_query) + return [x[0] for x in cur.fetchall()] + + def full_databases(self): + with self.conn.cursor() as cur: + _logger.debug("Databases Query. sql: %r", self.full_databases_query) + cur.execute(self.full_databases_query) + headers = [x[0] for x in cur.description] + return cur.fetchall(), headers, cur.statusmessage + + def get_socket_directory(self): + with self.conn.cursor() as cur: + _logger.debug( + "Socket directory Query. sql: %r", self.socket_directory_query + ) + cur.execute(self.socket_directory_query) + result = cur.fetchone() + return result[0] if result else "" + + def foreignkeys(self): + """Yields ForeignKey named tuples""" + + if self.conn.server_version < 90000: + return + + with self.conn.cursor() as cur: + query = """ + SELECT s_p.nspname AS parentschema, + t_p.relname AS parenttable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.confrelid + )) AS parentcolumn, + s_c.nspname AS childschema, + t_c.relname AS childtable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.conrelid + )) AS childcolumn + FROM pg_catalog.pg_constraint fk + JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid + JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace + JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid + JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace + WHERE fk.contype = 'f'; + """ + _logger.debug("Functions Query. 
sql: %r", query) + cur.execute(query) + for row in cur: + yield ForeignKey(*row) + + def functions(self): + """Yields FunctionMetadata named tuples""" + + if self.conn.server_version >= 110000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.prokind = 'a' is_aggregate, + p.prokind = 'w' is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text return_type, + p.proisagg is_aggregate, + p.proiswindow is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + pg_get_expr(proargdefaults, 0) AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + elif self.conn.server_version >= 80400: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[], + p.proargmodes, + prorettype::regtype::text, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + else: + query = """ + SELECT n.nspname schema_name, + p.proname func_name, + p.proargnames, + NULL arg_types, + NULL arg_modes, + '' ret_type, + p.proisagg is_aggregate, + false is_window, + p.proretset is_set_returning, + d.deptype = 'e' is_extension, + NULL AS arg_defaults + FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = p.pronamespace + LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e' + WHERE p.prorettype::regtype != 'trigger'::regtype + ORDER BY 1, 2 + """ + + with self.conn.cursor() as cur: + _logger.debug("Functions Query. 
sql: %r", query) + cur.execute(query) + for row in cur: + yield FunctionMetadata(*row) + + def datatypes(self): + """Yields tuples of (schema_name, type_name)""" + + with self.conn.cursor() as cur: + if self.conn.server_version > 90000: + query = """ + SELECT n.nspname schema_name, + t.typname type_name + FROM pg_catalog.pg_type t + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = t.typnamespace + WHERE ( t.typrelid = 0 -- non-composite types + OR ( -- composite type, but not a table + SELECT c.relkind = 'c' + FROM pg_catalog.pg_class c + WHERE c.oid = t.typrelid + ) + ) + AND NOT EXISTS( -- ignore array types + SELECT 1 + FROM pg_catalog.pg_type el + WHERE el.oid = t.typelem AND el.typarray = t.oid + ) + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + ORDER BY 1, 2; + """ + else: + query = """ + SELECT n.nspname schema_name, + pg_catalog.format_type(t.oid, NULL) type_name + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) + AND t.typname !~ '^_' + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + AND pg_catalog.pg_type_is_visible(t.oid) + ORDER BY 1, 2; + """ + _logger.debug("Datatypes Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield row + + def casing(self): + """Yields the most common casing for names used in db functions""" + with self.conn.cursor() as cur: + query = r""" + WITH Words AS ( + SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1) + FROM pg_catalog.pg_proc P + JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace + JOIN pg_catalog.pg_language L ON L.oid = P.prolang + WHERE L.lanname IN ('sql', 'plpgsql') + AND N.nspname NOT IN ('pg_catalog', 'information_schema') + GROUP BY Word + ), + OrderWords AS ( + SELECT Word, + ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC) + FROM Words + WHERE Word ~* '.*[a-z].*' + ), + Names AS ( + --Column names + SELECT attname AS Name + FROM pg_catalog.pg_attribute + UNION -- Table/view names + SELECT relname + FROM pg_catalog.pg_class + UNION -- Function names + SELECT proname + FROM pg_catalog.pg_proc + UNION -- Type names + SELECT typname + FROM pg_catalog.pg_type + UNION -- Schema names + SELECT nspname + FROM pg_catalog.pg_namespace + UNION -- Parameter names + SELECT unnest(proargnames) + FROM pg_proc + ) + SELECT Word + FROM OrderWords + WHERE LOWER(Word) IN (SELECT Name FROM Names) + AND Row_Number = 1; + """ + _logger.debug("Casing Query. sql: %r", query) + cur.execute(query) + for row in cur: + yield row[0] diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py new file mode 100644 index 0000000..8229037 --- /dev/null +++ b/pgcli/pgstyle.py @@ -0,0 +1,116 @@ +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). 
diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py
new file mode 100644
index 0000000..8229037
--- /dev/null
+++ b/pgcli/pgstyle.py
@@ -0,0 +1,116 @@
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+    Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
+    Token.Menu.Completions.Completion: "completion-menu.completion",
+    Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
+    Token.Menu.Completions.Meta: "completion-menu.meta.completion",
+    Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
+    Token.Menu.Completions.ProgressButton: "scrollbar.arrow",  # best guess
+    Token.Menu.Completions.ProgressBar: "scrollbar",  # best guess
+    Token.SelectedText: "selected",
+    Token.SearchMatch: "search",
+    Token.SearchMatch.Current: "search.current",
+    Token.Toolbar: "bottom-toolbar",
+    Token.Toolbar.Off: "bottom-toolbar.off",
+    Token.Toolbar.On: "bottom-toolbar.on",
+    Token.Toolbar.Search: "search-toolbar",
+    Token.Toolbar.Search.Text: "search-toolbar.text",
+    Token.Toolbar.System: "system-toolbar",
+    Token.Toolbar.Arg: "arg-toolbar",
+    Token.Toolbar.Arg.Text: "arg-toolbar.text",
+    Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
+    Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
+    Token.Output.Header: "output.header",
+    Token.Output.OddRow: "output.odd-row",
+    Token.Output.EvenRow: "output.even-row",
+    Token.Output.Null: "output.null",
+    Token.Literal.String: "literal.string",
+    Token.Literal.Number: "literal.number",
+    Token.Keyword: "keyword",
+    Token.Prompt: "prompt",
+    Token.Continuation: "continuation",
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+
+
+def parse_pygments_style(token_name, style_object, style_dict):
+    """Parse token type and style string.
+
+    :param token_name: str name of Pygments token. Example: "Token.String"
+    :param style_object: pygments.style.Style instance to use as base
+    :param style_dict: dict of token names and their styles, customized to this cli
+
+    """
+    token_type = string_to_tokentype(token_name)
+    try:
+        other_token_type = string_to_tokentype(style_dict[token_name])
+        return token_type, style_object.styles[other_token_type]
+    except AttributeError:
+        return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+    try:
+        style = pygments.styles.get_style_by_name(name)
+    except ClassNotFound:
+        style = pygments.styles.get_style_by_name("native")
+
+    prompt_styles = []
+    # prompt-toolkit used pygments tokens for styling before, switched to style
+    # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+    for token in cli_style:
+        if token.startswith("Token."):
+            # treat as pygments token (1.0)
+            token_type, style_value = parse_pygments_style(token, style, cli_style)
+            if token_type in TOKEN_TO_PROMPT_STYLE:
+                prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+                prompt_styles.append((prompt_style, style_value))
+            else:
+                # we don't want to support tokens anymore
+                logger.error("Unhandled style / class name: %s", token)
+        else:
+            # treat as prompt style name (2.0). See default style names here:
+            # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+            prompt_styles.append((token, cli_style[token]))
+
+    override_style = Style([("bottom-toolbar", "noreverse")])
+    return merge_styles(
+        [style_from_pygments_cls(style), override_style, Style(prompt_styles)]
+    )
+
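[Editor's aside, not part of the patch: a hedged usage sketch. A user's config can mix legacy Pygments token keys with prompt_toolkit 2.0 class names, and style_factory folds both into one merged Style; the style strings below are made-up placeholders:]

    from pgcli.pgstyle import style_factory

    cli_style = {
        "Token.Toolbar": "bg:#222222 #aaaaaa",       # ptk 1.0 token key
        "completion-menu.completion": "bg:#008888",  # ptk 2.0 class name
    }
    # An unknown Pygments style name falls back to "native".
    style = style_factory("default", cli_style)
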
+
+def style_factory_output(name, cli_style):
+    try:
+        style = pygments.styles.get_style_by_name(name).styles
+    except ClassNotFound:
+        style = pygments.styles.get_style_by_name("native").styles
+
+    for token in cli_style:
+        if token.startswith("Token."):
+            token_type, style_value = parse_pygments_style(token, style, cli_style)
+            style.update({token_type: style_value})
+        elif token in PROMPT_STYLE_TO_TOKEN:
+            token_type = PROMPT_STYLE_TO_TOKEN[token]
+            style.update({token_type: cli_style[token]})
+        else:
+            # TODO: cli helpers will have to switch to ptk.Style
+            logger.error("Unhandled style / class name: %s", token)
+
+    class OutputStyle(PygmentsStyle):
+        default_style = ""
+        styles = style
+
+    return OutputStyle
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
new file mode 100644
index 0000000..f4289a1
--- /dev/null
+++ b/pgcli/pgtoolbar.py
@@ -0,0 +1,62 @@
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.application import get_app
+
+
+def _get_vi_mode():
+    return {
+        InputMode.INSERT: "I",
+        InputMode.NAVIGATION: "N",
+        InputMode.REPLACE: "R",
+        InputMode.REPLACE_SINGLE: "R",
+        InputMode.INSERT_MULTIPLE: "M",
+    }[get_app().vi_state.input_mode]
+
+
+def create_toolbar_tokens_func(pgcli):
+    """Return a function that generates the toolbar tokens."""
+
+    def get_toolbar_tokens():
+        result = []
+        result.append(("class:bottom-toolbar", " "))
+
+        if pgcli.completer.smart_completion:
+            result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON  "))
+        else:
+            result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF  "))
+
+        if pgcli.multi_line:
+            result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON  "))
+        else:
+            result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF  "))
+
+        if pgcli.multi_line:
+            if pgcli.multiline_mode == "safe":
+                result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute) "))
+            else:
+                result.append(
+                    ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
+                )
+
+        if pgcli.vi_mode:
+            result.append(
+                ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")")
+            )
+        else:
+            result.append(("class:bottom-toolbar", "[F4] Emacs-mode"))
+
+        if pgcli.pgexecute.failed_transaction():
+            result.append(
+                ("class:bottom-toolbar.transaction.failed", "     Failed transaction")
+            )
+
+        if pgcli.pgexecute.valid_transaction():
+            result.append(
+                ("class:bottom-toolbar.transaction.valid", "     Transaction")
+            )
+
+        if pgcli.completion_refresher.is_refreshing():
+            result.append(("class:bottom-toolbar", "     Refreshing completions..."))
+
+        return result
+
+    return get_toolbar_tokens
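[Editor's aside, not part of the patch: a hedged sketch of how a toolbar factory like the one above is typically wired into prompt_toolkit, with `pgcli` standing in for a configured PGCli instance. prompt_toolkit re-invokes bottom_toolbar on every redraw, so the token list always reflects the current state:]

    from prompt_toolkit import PromptSession

    def make_session(pgcli):
        return PromptSession(
            bottom_toolbar=create_toolbar_tokens_func(pgcli),
            multiline=pgcli.multi_line,
        )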