summaryrefslogtreecommitdiffstats
path: root/pgcli
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 10:31:05 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 10:31:05 +0000
commit6884720fae8a2622b14e93d9e35ca5fcc2283b40 (patch)
treedf6f736bb623cdd7932bbe2256101a6ac4ef7f35 /pgcli
parentInitial commit. (diff)
downloadpgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.tar.xz
pgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.zip
Adding upstream version 3.1.0.upstream/3.1.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli')
-rw-r--r--pgcli/__init__.py1
-rw-r--r--pgcli/__main__.py9
-rw-r--r--pgcli/completion_refresher.py150
-rw-r--r--pgcli/config.py64
-rw-r--r--pgcli/key_bindings.py127
-rw-r--r--pgcli/magic.py67
-rw-r--r--pgcli/main.py1516
-rw-r--r--pgcli/packages/__init__.py0
-rw-r--r--pgcli/packages/parseutils/__init__.py22
-rw-r--r--pgcli/packages/parseutils/ctes.py141
-rw-r--r--pgcli/packages/parseutils/meta.py170
-rw-r--r--pgcli/packages/parseutils/tables.py170
-rw-r--r--pgcli/packages/parseutils/utils.py140
-rw-r--r--pgcli/packages/pgliterals/__init__.py0
-rw-r--r--pgcli/packages/pgliterals/main.py15
-rw-r--r--pgcli/packages/pgliterals/pgliterals.json629
-rw-r--r--pgcli/packages/prioritization.py51
-rw-r--r--pgcli/packages/prompt_utils.py35
-rw-r--r--pgcli/packages/sqlcompletion.py608
-rw-r--r--pgcli/pgbuffer.py50
-rw-r--r--pgcli/pgclirc195
-rw-r--r--pgcli/pgcompleter.py1046
-rw-r--r--pgcli/pgexecute.py857
-rw-r--r--pgcli/pgstyle.py116
-rw-r--r--pgcli/pgtoolbar.py62
25 files changed, 6241 insertions, 0 deletions
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
new file mode 100644
index 0000000..f5f41e5
--- /dev/null
+++ b/pgcli/__init__.py
@@ -0,0 +1 @@
+__version__ = "3.1.0"  # Package version — presumably kept in sync with release metadata; confirm.
diff --git a/pgcli/__main__.py b/pgcli/__main__.py
new file mode 100644
index 0000000..ddf1662
--- /dev/null
+++ b/pgcli/__main__.py
@@ -0,0 +1,9 @@
+"""
+pgcli package main entry point
+"""
+
+from .main import cli
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
new file mode 100644
index 0000000..cf0879f
--- /dev/null
+++ b/pgcli/completion_refresher.py
@@ -0,0 +1,150 @@
+import threading
+import os
+from collections import OrderedDict
+
+from .pgcompleter import PGCompleter
+from .pgexecute import PGExecute
+
+
+class CompletionRefresher(object):
+
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, special, callbacks, history=None, settings=None):
+ """
+ Creates a PGCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - PGExecute object, used to extract the credentials to connect
+ to the database.
+ special - PGSpecial object used for creating a new completion object.
+ settings - dict of settings for completer object
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ """
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, "Auto-completion refresh restarted.")]
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, special, callbacks, history, settings),
+ name="completion_refresh",
+ )
+ self._completer_thread.setDaemon(True)
+ self._completer_thread.start()
+ return [
+ (None, None, None, "Auto-completion refresh started in the background.")
+ ]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
+ settings = settings or {}
+ completer = PGCompleter(
+ smart_completion=True, pgspecial=special, settings=settings
+ )
+
+ if settings.get("single_connection"):
+ executor = pgexecute
+ else:
+ # Create a new pgexecute method to populate the completions.
+ executor = pgexecute.copy()
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes natually
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ # Load history into pgcompleter so it can learn user preferences
+ n_recent = 100
+ if history:
+ for recent in history.get_strings()[-n_recent:]:
+ completer.extend_query_history(recent, is_init=True)
+
+ for callback in callbacks:
+ callback(completer)
+
+ if not settings.get("single_connection") and executor.conn:
+ # close connection established with pgexecute.copy()
+ executor.conn.close()
+
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+    """Decorator to populate the dictionary of refreshers with the current
+    function.
+    """
+
+    def wrapper(wrapped):
+        # Register under the given name; the function itself is returned
+        # unmodified so it stays directly callable.
+        refreshers[name] = wrapped
+        return wrapped
+
+    return wrapper
+
+
+@refresher("schemata")
+def refresh_schemata(completer, executor):
+ completer.set_search_path(executor.search_path())
+ completer.extend_schemata(executor.schemata())
+
+
+@refresher("tables")
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind="tables")
+ completer.extend_columns(executor.table_columns(), kind="tables")
+ completer.extend_foreignkeys(executor.foreignkeys())
+
+
+@refresher("views")
+def refresh_views(completer, executor):
+ completer.extend_relations(executor.views(), kind="views")
+ completer.extend_columns(executor.view_columns(), kind="views")
+
+
+@refresher("types")
+def refresh_types(completer, executor):
+ completer.extend_datatypes(executor.datatypes())
+
+
+@refresher("databases")
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+
+@refresher("casing")
+def refresh_casing(completer, executor):
+ casing_file = completer.casing_file
+ if not casing_file:
+ return
+ generate_casing_file = completer.generate_casing_file
+ if generate_casing_file and not os.path.isfile(casing_file):
+ casing_prefs = "\n".join(executor.casing())
+ with open(casing_file, "w") as f:
+ f.write(casing_prefs)
+ if os.path.isfile(casing_file):
+ with open(casing_file, "r") as f:
+ completer.extend_casing([line.strip() for line in f])
+
+
+@refresher("functions")
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
diff --git a/pgcli/config.py b/pgcli/config.py
new file mode 100644
index 0000000..0fc42dd
--- /dev/null
+++ b/pgcli/config.py
@@ -0,0 +1,64 @@
+import errno
+import shutil
+import os
+import platform
+from os.path import expanduser, exists, dirname
+from configobj import ConfigObj
+
+
+def config_location():
+    """Return the directory (with trailing separator) holding pgcli's config.
+
+    Honors $XDG_CONFIG_HOME when set, otherwise a per-platform default.
+    """
+    if "XDG_CONFIG_HOME" in os.environ:
+        return "%s/pgcli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
+    elif platform.system() == "Windows":
+        # NOTE(review): assumes USERPROFILE is set; os.getenv returns None
+        # otherwise and the concatenation would raise — confirm.
+        return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\pgcli\\"
+    else:
+        return expanduser("~/.config/pgcli/")
+
+
+def load_config(usr_cfg, def_cfg=None):
+    """Read the user config layered over the default config.
+
+    User values override defaults; the returned ConfigObj writes back
+    to the (tilde-expanded) user path.
+    """
+    cfg = ConfigObj()
+    cfg.merge(ConfigObj(def_cfg, interpolation=False))
+    cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+    cfg.filename = expanduser(usr_cfg)
+
+    return cfg
+
+
+def ensure_dir_exists(path):
+    """Create the parent directory of *path* if it does not exist yet."""
+    parent_dir = expanduser(dirname(path))
+    os.makedirs(parent_dir, exist_ok=True)
+
+
+def write_default_config(source, destination, overwrite=False):
+    """Copy the bundled default config file to *destination*.
+
+    An existing file is left untouched unless *overwrite* is True.
+    """
+    destination = expanduser(destination)
+    if not overwrite and exists(destination):
+        return
+
+    ensure_dir_exists(destination)
+
+    shutil.copyfile(source, destination)
+
+
+def upgrade_config(config, def_config):
+    """Rewrite *config* on disk after merging in keys from *def_config*."""
+    cfg = load_config(config, def_config)
+    cfg.write()
+
+
+def get_config(pgclirc_file=None):
+    """Load (creating from the bundled default if necessary) the pgclirc config."""
+    # Imported here to locate the installed package directory that ships
+    # the default "pgclirc" file.
+    from pgcli import __file__ as package_root
+
+    package_root = os.path.dirname(package_root)
+
+    # config_location() ends with a separator, so concatenation yields
+    # ".../config".
+    pgclirc_file = pgclirc_file or "%sconfig" % config_location()
+
+    default_config = os.path.join(package_root, "pgclirc")
+    write_default_config(default_config, pgclirc_file)
+
+    return load_config(pgclirc_file, default_config)
+
+
+def get_casing_file(config):
+    """Return the path of the identifier-casing preferences file."""
+    casing_file = config["main"]["casing_file"]
+    # "default" is a sentinel meaning "<config dir>/casing".
+    if casing_file == "default":
+        casing_file = config_location() + "casing"
+    return casing_file
diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py
new file mode 100644
index 0000000..23174b6
--- /dev/null
+++ b/pgcli/key_bindings.py
@@ -0,0 +1,127 @@
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.key_binding import KeyBindings
+from prompt_toolkit.filters import (
+ completion_is_selected,
+ is_searching,
+ has_completions,
+ has_selection,
+ vi_mode,
+)
+
+from .pgbuffer import buffer_should_be_handled
+
+_logger = logging.getLogger(__name__)
+
+
+def pgcli_bindings(pgcli):
+    """Custom key bindings for pgcli."""
+    kb = KeyBindings()
+
+    # Text inserted by <Tab> when completion does not apply (plain indent).
+    tab_insert_text = " " * 4
+
+    @kb.add("f2")
+    def _(event):
+        """Enable/Disable SmartCompletion Mode."""
+        _logger.debug("Detected F2 key.")
+        pgcli.completer.smart_completion = not pgcli.completer.smart_completion
+
+    @kb.add("f3")
+    def _(event):
+        """Enable/Disable Multiline Mode."""
+        _logger.debug("Detected F3 key.")
+        pgcli.multi_line = not pgcli.multi_line
+
+    @kb.add("f4")
+    def _(event):
+        """Toggle between Vi and Emacs mode."""
+        _logger.debug("Detected F4 key.")
+        pgcli.vi_mode = not pgcli.vi_mode
+        event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
+
+    @kb.add("tab")
+    def _(event):
+        """Force autocompletion at cursor on non-empty lines."""
+
+        _logger.debug("Detected <Tab> key.")
+
+        buff = event.app.current_buffer
+        doc = buff.document
+
+        if doc.on_first_line or doc.current_line.strip():
+            if buff.complete_state:
+                buff.complete_next()
+            else:
+                buff.start_completion(select_first=True)
+        else:
+            # Empty continuation line: insert plain indentation instead.
+            buff.insert_text(tab_insert_text, fire_event=False)
+
+    @kb.add("escape", filter=has_completions)
+    def _(event):
+        """Force closing of autocompletion."""
+        _logger.debug("Detected <Esc> key.")
+
+        event.current_buffer.complete_state = None
+        event.app.current_buffer.complete_state = None
+
+    @kb.add("c-space")
+    def _(event):
+        """
+        Initialize autocompletion at cursor.
+
+        If the autocompletion menu is not showing, display it with the
+        appropriate completions for the context.
+
+        If the menu is showing, select the next completion.
+        """
+        _logger.debug("Detected <C-Space> key.")
+
+        b = event.app.current_buffer
+        if b.complete_state:
+            b.complete_next()
+        else:
+            b.start_completion(select_first=False)
+
+    @kb.add("enter", filter=completion_is_selected)
+    def _(event):
+        """Makes the enter key work as the tab key only when showing the menu.
+
+        In other words, don't execute query when enter is pressed in
+        the completion dropdown menu, instead close the dropdown menu
+        (accept current selection).
+
+        """
+        _logger.debug("Detected enter key during completion selection.")
+
+        event.current_buffer.complete_state = None
+        event.app.current_buffer.complete_state = None
+
+    # When using multi_line input mode the buffer is not handled on Enter (a new line is
+    # inserted instead), so we force the handling if we're not in a completion or
+    # history search, and one of several conditions are True
+    @kb.add(
+        "enter",
+        filter=~(completion_is_selected | is_searching)
+        & buffer_should_be_handled(pgcli),
+    )
+    def _(event):
+        _logger.debug("Detected enter key.")
+        event.current_buffer.validate_and_handle()
+
+    @kb.add("escape", "enter", filter=~vi_mode)
+    def _(event):
+        """Introduces a line break regardless of multi-line mode or not."""
+        _logger.debug("Detected alt-enter key.")
+        event.app.current_buffer.insert_text("\n")
+
+    @kb.add("c-p", filter=~has_selection)
+    def _(event):
+        """Move up in history."""
+        event.current_buffer.history_backward(count=event.arg)
+
+    @kb.add("c-n", filter=~has_selection)
+    def _(event):
+        """Move down in history."""
+        event.current_buffer.history_forward(count=event.arg)
+
+    return kb
diff --git a/pgcli/magic.py b/pgcli/magic.py
new file mode 100644
index 0000000..f58f415
--- /dev/null
+++ b/pgcli/magic.py
@@ -0,0 +1,67 @@
+from .main import PGCli
+import sql.parse
+import sql.connection
+import logging
+
+_logger = logging.getLogger(__name__)
+
+
+def load_ipython_extension(ipython):
+    """This is called via the ipython command '%load_ext pgcli.magic'.
+
+    Ensures the ipython-sql extension is loaded, then registers the
+    %pgcli line magic.
+    """
+
+    # first, load the sql magic if it isn't already loaded
+    if not ipython.find_line_magic("sql"):
+        ipython.run_line_magic("load_ext", "sql")
+
+    # register our own magic
+    ipython.register_magic_function(pgcli_line_magic, "line", "pgcli")
+
+
+def pgcli_line_magic(line):
+    """Run %pgcli: open an interactive pgcli session on the connection
+    named by *line*, then replay the last successful, non-schema-changing
+    query through the ipython-sql cell magic.
+    """
+    _logger.debug("pgcli magic called: %r", line)
+    parsed = sql.parse.parse(line, {})
+    # "get" was renamed to "set" in ipython-sql:
+    # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43
+    if hasattr(sql.connection.Connection, "get"):
+        conn = sql.connection.Connection.get(parsed["connection"])
+    else:
+        conn = sql.connection.Connection.set(parsed["connection"])
+
+    try:
+        # A corresponding pgcli object already exists
+        pgcli = conn._pgcli
+        _logger.debug("Reusing existing pgcli")
+    except AttributeError:
+        # I can't figure out how to get the underlying psycopg2 connection
+        # from the sqlalchemy connection, so just grab the url and make a
+        # new connection
+        pgcli = PGCli()
+        u = conn.session.engine.url
+        _logger.debug("New pgcli: %r", str(u))
+
+        pgcli.connect(u.database, u.host, u.username, u.port, u.password)
+        conn._pgcli = pgcli
+
+    # For convenience, print the connection alias
+    print("Connected: {}".format(conn.name))
+
+    try:
+        pgcli.run_cli()
+    except SystemExit:
+        pass
+
+    if not pgcli.query_history:
+        return
+
+    q = pgcli.query_history[-1]
+
+    if not q.successful:
+        _logger.debug("Unsuccessful query - ignoring")
+        return
+
+    if q.meta_changed or q.db_changed or q.path_changed:
+        _logger.debug("Dangerous query detected -- ignoring")
+        return
+
+    # get_ipython is presumably injected into builtins by IPython at
+    # runtime — TODO confirm.
+    ipython = get_ipython()
+    return ipython.run_cell_magic("sql", line, q.query)
diff --git a/pgcli/main.py b/pgcli/main.py
new file mode 100644
index 0000000..b146898
--- /dev/null
+++ b/pgcli/main.py
@@ -0,0 +1,1516 @@
+import platform
+import warnings
+from os.path import expanduser
+
+from configobj import ConfigObj
+from pgspecial.namedqueries import NamedQueries
+
+warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
+
+import os
+import re
+import sys
+import traceback
+import logging
+import threading
+import shutil
+import functools
+import pendulum
+import datetime as dt
+import itertools
+import platform
+from time import time, sleep
+from codecs import open
+
+keyring = None # keyring will be loaded later
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+import click
+
+try:
+ import setproctitle
+except ImportError:
+ setproctitle = None
+from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.layout.processors import (
+ ConditionalProcessor,
+ HighlightMatchingBracketProcessor,
+ TabsProcessor,
+)
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from pygments.lexers.sql import PostgresLexer
+
+from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
+import pgspecial as special
+
+from .pgcompleter import PGCompleter
+from .pgtoolbar import create_toolbar_tokens_func
+from .pgstyle import style_factory, style_factory_output
+from .pgexecute import PGExecute
+from .completion_refresher import CompletionRefresher
+from .config import (
+ get_casing_file,
+ load_config,
+ config_location,
+ ensure_dir_exists,
+ get_config,
+)
+from .key_bindings import pgcli_bindings
+from .packages.prompt_utils import confirm_destructive_query
+from .__init__ import __version__
+
+click.disable_unicode_literals_warning = True
+
+try:
+ from urlparse import urlparse, unquote, parse_qs
+except ImportError:
+ from urllib.parse import urlparse, unquote, parse_qs
+
+from getpass import getuser
+from psycopg2 import OperationalError, InterfaceError
+import psycopg2
+
+from collections import namedtuple
+
+from textwrap import dedent
+
+# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
+COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
+
+# Query tuples are used for maintaining history
+MetaQuery = namedtuple(
+    "Query",
+    [
+        "query",  # The entire text of the command
+        "successful",  # True If all subqueries were successful
+        "total_time",  # Time elapsed executing the query and formatting results
+        "execution_time",  # Time elapsed executing the query
+        "meta_changed",  # True if any subquery executed create/alter/drop
+        "db_changed",  # True if any subquery changed the database
+        "path_changed",  # True if any subquery changed the search path
+        "mutated",  # True if any subquery executed insert/update/delete
+        "is_special",  # True if the query is a special command
+    ],
+)
+# NOTE(review): 9 fields above but only 8 defaults below. Defaults bind to
+# the rightmost fields, so "query" stays required and "successful" defaults
+# to "" (falsy) rather than False, shifting every remaining default one
+# field to the right of what the ordering suggests — confirm intentional.
+MetaQuery.__new__.__defaults__ = ("", False, 0, 0, False, False, False, False)
+
+# Rendering options passed through to the tabular output formatting layer.
+OutputSettings = namedtuple(
+    "OutputSettings",
+    "table_format dcmlfmt floatfmt missingval expanded max_width case_function style_output",
+)
+# Defaults for all eight fields, in field order.
+OutputSettings.__new__.__defaults__ = (
+    None,
+    None,
+    None,
+    "<null>",
+    False,
+    None,
+    lambda x: x,
+    None,
+)
+
+
+class PgCliQuitError(Exception):
+    """Raised by PGCli.quit() (bound to \\q / quit / exit) to end the session."""
+
+    pass
+
+
+class PGCli(object):
+    # Prompt template; presumably \u=user, \h=host, \d=database, expanded
+    # elsewhere — TODO confirm.
+    default_prompt = "\\u@\\h:\\d> "
+    # Maximum prompt length — TODO confirm where this limit is enforced.
+    max_len_prompt = 30
+
+    def set_default_pager(self, config):
+        """Export PAGER/LESS env vars based on config and environment.
+
+        Precedence: pgclirc "pager" setting, then an existing $PAGER,
+        otherwise the OS default pager is left in effect.
+        """
+        configured_pager = config["main"].get("pager")
+        os_environ_pager = os.environ.get("PAGER")
+
+        if configured_pager:
+            self.logger.info(
+                'Default pager found in config file: "%s"', configured_pager
+            )
+            os.environ["PAGER"] = configured_pager
+        elif os_environ_pager:
+            self.logger.info(
+                'Default pager found in PAGER environment variable: "%s"',
+                os_environ_pager,
+            )
+            os.environ["PAGER"] = os_environ_pager
+        else:
+            self.logger.info(
+                "No default pager found in environment. Using os default pager"
+            )
+
+        # Set default set of less recommended options, if they are not already set.
+        # They are ignored if pager is different than less.
+        if not os.environ.get("LESS"):
+            os.environ["LESS"] = "-SRXF"
+
+    def __init__(
+        self,
+        force_passwd_prompt=False,
+        never_passwd_prompt=False,
+        pgexecute=None,
+        pgclirc_file=None,
+        row_limit=None,
+        single_connection=False,
+        less_chatty=None,
+        prompt=None,
+        prompt_dsn=None,
+        auto_vertical_output=False,
+        warn=None,
+    ):
+        """Build the cli object: load config, then initialize logging,
+        pager, keyring and the auto-completer. Arguments left at their
+        defaults fall back to values from the loaded config file.
+        """
+
+        self.force_passwd_prompt = force_passwd_prompt
+        self.never_passwd_prompt = never_passwd_prompt
+        self.pgexecute = pgexecute
+        self.dsn_alias = None
+        self.watch_command = None
+
+        # Load config.
+        c = self.config = get_config(pgclirc_file)
+
+        NamedQueries.instance = NamedQueries.from_config(self.config)
+
+        self.logger = logging.getLogger(__name__)
+        self.initialize_logging()
+
+        self.set_default_pager(c)
+        self.output_file = None
+        self.pgspecial = PGSpecial()
+
+        self.multi_line = c["main"].as_bool("multi_line")
+        self.multiline_mode = c["main"].get("multi_line_mode", "psql")
+        self.vi_mode = c["main"].as_bool("vi")
+        self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
+        self.expanded_output = c["main"].as_bool("expand")
+        self.pgspecial.timing_enabled = c["main"].as_bool("timing")
+        # Explicit row_limit argument wins over the config value.
+        if row_limit is not None:
+            self.row_limit = row_limit
+        else:
+            self.row_limit = c["main"].as_int("row_limit")
+
+        self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
+        self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
+        self.table_format = c["main"]["table_format"]
+        self.syntax_style = c["main"]["syntax_style"]
+        self.cli_style = c["colors"]
+        self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
+        c_dest_warning = c["main"].as_bool("destructive_warning")
+        self.destructive_warning = c_dest_warning if warn is None else warn
+        self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
+        self.null_string = c["main"].get("null_string", "<null>")
+        self.prompt_format = (
+            prompt
+            if prompt is not None
+            else c["main"].get("prompt", self.default_prompt)
+        )
+        self.prompt_dsn_format = prompt_dsn
+        self.on_error = c["main"]["on_error"].upper()
+        self.decimal_format = c["data_formats"]["decimal"]
+        self.float_format = c["data_formats"]["float"]
+        self.initialize_keyring()
+        self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
+
+        self.pgspecial.pset_pager(
+            self.config["main"].as_bool("enable_pager") and "on" or "off"
+        )
+
+        self.style_output = style_factory_output(self.syntax_style, c["colors"])
+
+        self.now = dt.datetime.today()
+
+        self.completion_refresher = CompletionRefresher()
+
+        self.query_history = []
+
+        # Initialize completer
+        smart_completion = c["main"].as_bool("smart_completion")
+        keyword_casing = c["main"]["keyword_casing"]
+        self.settings = {
+            "casing_file": get_casing_file(c),
+            "generate_casing_file": c["main"].as_bool("generate_casing_file"),
+            "generate_aliases": c["main"].as_bool("generate_aliases"),
+            "asterisk_column_order": c["main"]["asterisk_column_order"],
+            "qualify_columns": c["main"]["qualify_columns"],
+            "case_column_headers": c["main"].as_bool("case_column_headers"),
+            "search_path_filter": c["main"].as_bool("search_path_filter"),
+            "single_connection": single_connection,
+            "less_chatty": less_chatty,
+            "keyword_casing": keyword_casing,
+        }
+
+        completer = PGCompleter(
+            smart_completion, pgspecial=self.pgspecial, settings=self.settings
+        )
+        self.completer = completer
+        self._completer_lock = threading.Lock()
+        self.register_special_commands()
+
+        self.prompt_app = None
+
+    def quit(self):
+        """Signal that pgcli should exit by raising PgCliQuitError."""
+        raise PgCliQuitError
+
+    def register_special_commands(self):
+        """Register pgcli-specific backslash commands with pgspecial."""
+
+        self.pgspecial.register(
+            self.change_db,
+            "\\c",
+            "\\c[onnect] database_name",
+            "Change to a new database.",
+            aliases=("use", "\\connect", "USE"),
+        )
+
+        refresh_callback = lambda: self.refresh_completions(persist_priorities="all")
+
+        self.pgspecial.register(
+            self.quit,
+            "\\q",
+            "\\q",
+            "Quit pgcli.",
+            arg_type=NO_QUERY,
+            case_sensitive=True,
+            aliases=(":q",),
+        )
+        self.pgspecial.register(
+            self.quit,
+            "quit",
+            "quit",
+            "Quit pgcli.",
+            arg_type=NO_QUERY,
+            case_sensitive=False,
+            aliases=("exit",),
+        )
+        self.pgspecial.register(
+            refresh_callback,
+            "\\#",
+            "\\#",
+            "Refresh auto-completions.",
+            arg_type=NO_QUERY,
+        )
+        self.pgspecial.register(
+            refresh_callback,
+            "\\refresh",
+            "\\refresh",
+            "Refresh auto-completions.",
+            arg_type=NO_QUERY,
+        )
+        self.pgspecial.register(
+            self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
+        )
+        self.pgspecial.register(
+            self.write_to_file,
+            "\\o",
+            "\\o [filename]",
+            "Send all query results to file.",
+        )
+        self.pgspecial.register(
+            self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
+        )
+        self.pgspecial.register(
+            self.change_table_format,
+            "\\T",
+            "\\T [format]",
+            "Change the table format used to output results",
+        )
+
+    def change_table_format(self, pattern, **_):
+        """Implement \\T: switch the output table format after validating it."""
+        try:
+            if pattern not in TabularOutputFormatter().supported_formats:
+                raise ValueError()
+            self.table_format = pattern
+            yield (None, None, None, "Changed table format to {}".format(pattern))
+        except ValueError:
+            # Unknown format: list all supported formats plus the current one.
+            msg = "Table format {} not recognized. Allowed formats:".format(pattern)
+            for table_type in TabularOutputFormatter().supported_formats:
+                msg += "\n\t{}".format(table_type)
+            msg += "\nCurrently set to: %s" % self.table_format
+            yield (None, None, None, msg)
+
+    def info_connection(self, **_):
+        """Implement \\conninfo: report database, user, host/socket and port."""
+        # A host path starting with "/" denotes a Unix-domain socket.
+        if self.pgexecute.host.startswith("/"):
+            host = 'socket "%s"' % self.pgexecute.host
+        else:
+            host = 'host "%s"' % self.pgexecute.host
+
+        yield (
+            None,
+            None,
+            None,
+            'You are connected to database "%s" as user '
+            '"%s" on %s at port "%s".'
+            % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
+        )
+
+ def change_db(self, pattern, **_):
+ if pattern:
+ # Get all the parameters in pattern, handling double quotes if any.
+ infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
+ # Now removing quotes.
+ list(map(lambda s: s.strip('"'), infos))
+
+ infos.extend([None] * (4 - len(infos)))
+ db, user, host, port = infos
+ try:
+ self.pgexecute.connect(
+ database=db,
+ user=user,
+ host=host,
+ port=port,
+ **self.pgexecute.extra_args,
+ )
+ except OperationalError as e:
+ click.secho(str(e), err=True, fg="red")
+ click.echo("Previous connection kept")
+ else:
+ self.pgexecute.connect()
+
+ yield (
+ None,
+ None,
+ None,
+ 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
+ )
+
+    def execute_from_file(self, pattern, **_):
+        """Implement \\i: run the SQL commands stored in the named file."""
+        if not pattern:
+            message = "\\i: missing required argument"
+            return [(None, None, None, message, "", False, True)]
+        try:
+            with open(os.path.expanduser(pattern), encoding="utf-8") as f:
+                query = f.read()
+        except IOError as e:
+            return [(None, None, None, str(e), "", False, True)]
+
+        # Same destructive-query confirmation as interactive input.
+        if self.destructive_warning and confirm_destructive_query(query) is False:
+            message = "Wise choice. Command execution stopped."
+            return [(None, None, None, message)]
+
+        on_error_resume = self.on_error == "RESUME"
+        return self.pgexecute.run(
+            query, self.pgspecial, on_error_resume=on_error_resume
+        )
+
+    def write_to_file(self, pattern, **_):
+        """Implement \\o: tee query output to a file; no argument disables it."""
+        if not pattern:
+            self.output_file = None
+            message = "File output disabled"
+            return [(None, None, None, message, "", True, True)]
+        filename = os.path.abspath(os.path.expanduser(pattern))
+        if not os.path.isfile(filename):
+            # Probe writability up front so failures surface immediately.
+            try:
+                open(filename, "w").close()
+            except IOError as e:
+                self.output_file = None
+                message = str(e) + "\nFile output disabled"
+                return [(None, None, None, message, "", False, True)]
+        self.output_file = filename
+        message = 'Writing to file "%s"' % self.output_file
+        return [(None, None, None, message, "", True, True)]
+
+    def initialize_logging(self):
+        """Configure file logging for the pgcli and pgspecial loggers
+        according to the log_file / log_level config values.
+        """
+
+        log_file = self.config["main"]["log_file"]
+        if log_file == "default":
+            log_file = config_location() + "log"
+        ensure_dir_exists(log_file)
+        log_level = self.config["main"]["log_level"]
+
+        # Disable logging if value is NONE by switching to a no-op handler.
+        # Set log level to a high value so it doesn't even waste cycles getting called.
+        if log_level.upper() == "NONE":
+            handler = logging.NullHandler()
+        else:
+            handler = logging.FileHandler(os.path.expanduser(log_file))
+
+        level_map = {
+            "CRITICAL": logging.CRITICAL,
+            "ERROR": logging.ERROR,
+            "WARNING": logging.WARNING,
+            "INFO": logging.INFO,
+            "DEBUG": logging.DEBUG,
+            "NONE": logging.CRITICAL,
+        }
+
+        log_level = level_map[log_level.upper()]
+
+        formatter = logging.Formatter(
+            "%(asctime)s (%(process)d/%(threadName)s) "
+            "%(name)s %(levelname)s - %(message)s"
+        )
+
+        handler.setFormatter(formatter)
+
+        root_logger = logging.getLogger("pgcli")
+        root_logger.addHandler(handler)
+        root_logger.setLevel(log_level)
+
+        root_logger.debug("Initializing pgcli logging.")
+        root_logger.debug("Log file %r.", log_file)
+
+        # pgspecial logs to the same destination at the same level.
+        pgspecial_logger = logging.getLogger("pgspecial")
+        pgspecial_logger.addHandler(handler)
+        pgspecial_logger.setLevel(log_level)
+
+    def initialize_keyring(self):
+        """Lazily import the keyring module when enabled in config.
+
+        Leaves the module-level ``keyring`` as None on failure so password
+        lookup/storage silently degrades (see issue #1041).
+        """
+        global keyring
+
+        keyring_enabled = self.config["main"].as_bool("keyring")
+        if keyring_enabled:
+            # Try best to load keyring (issue #1041).
+            import importlib
+
+            try:
+                keyring = importlib.import_module("keyring")
+            except Exception as e:  # ImportError for Python 2, ModuleNotFoundError for Python 3
+                self.logger.warning("import keyring failed: %r.", e)
+
+    def connect_dsn(self, dsn, **kwargs):
+        """Connect using a libpq DSN / connection string."""
+        self.connect(dsn=dsn, **kwargs)
+
+    def connect_service(self, service, user):
+        """Connect using a pg_service.conf service definition.
+
+        Exits the process when the named service cannot be found.
+        """
+        # parse_service_info is presumably defined elsewhere in this module;
+        # it appears to return (service_config, filename) — confirm.
+        service_config, file = parse_service_info(service)
+        if service_config is None:
+            click.secho(
+                "service '%s' was not found in %s" % (service, file), err=True, fg="red"
+            )
+            exit(1)
+        # An explicit user argument overrides the service file's user.
+        self.connect(
+            database=service_config.get("dbname"),
+            host=service_config.get("host"),
+            user=user or service_config.get("user"),
+            port=service_config.get("port"),
+            passwd=service_config.get("password"),
+        )
+
+    def connect_uri(self, uri):
+        """Connect using a postgres:// URI (parsed via libpq rules)."""
+        kwargs = psycopg2.extensions.parse_dsn(uri)
+        # parse_dsn yields libpq names; connect() expects database/passwd.
+        remap = {"dbname": "database", "password": "passwd"}
+        kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
+        self.connect(**kwargs)
+
+    def connect(
+        self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
+    ):
+        """Establish the database connection and store the PGExecute.
+
+        Fills in a missing user/database from the OS user, resolves the
+        password from $PGPASSWORD or the keyring, and prompts for one when
+        forced (-W) or when the first attempt fails for password reasons.
+        Exits the process when the connection cannot be established.
+        """
+        # Connect to the database.
+
+        if not user:
+            user = getuser()
+
+        if not database:
+            database = user
+
+        kwargs.setdefault("application_name", "pgcli")
+
+        # If password prompt is not forced but no password is provided, try
+        # getting it from environment variable.
+        if not self.force_passwd_prompt and not passwd:
+            passwd = os.environ.get("PGPASSWORD", "")
+
+        # Find password from store
+        key = "%s@%s" % (user, host)
+        keyring_error_message = dedent(
+            """\
+            {}
+            {}
+            To remove this message do one of the following:
+            - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
+            - uninstall keyring: pip uninstall keyring
+            - disable keyring in our configuration: add keyring = False to [main]"""
+        )
+        if not passwd and keyring:
+
+            try:
+                passwd = keyring.get_password("pgcli", key)
+            except (RuntimeError, keyring.errors.InitError) as e:
+                click.secho(
+                    keyring_error_message.format(
+                        "Load your password from keyring returned:", str(e)
+                    ),
+                    err=True,
+                    fg="red",
+                )
+
+        # Prompt for a password immediately if requested via the -W flag. This
+        # avoids wasting time trying to connect to the database and catching a
+        # no-password exception.
+        # If we successfully parsed a password from a URI, there's no need to
+        # prompt for it, even with the -W flag
+        if self.force_passwd_prompt and not passwd:
+            passwd = click.prompt(
+                "Password for %s" % user, hide_input=True, show_default=False, type=str
+            )
+
+        def should_ask_for_password(exc):
+            # Prompt for a password after 1st attempt to connect
+            # fails. Don't prompt if the -w flag is supplied
+            if self.never_passwd_prompt:
+                return False
+            error_msg = exc.args[0]
+            if "no password supplied" in error_msg:
+                return True
+            if "password authentication failed" in error_msg:
+                return True
+            return False
+
+        # Attempt to connect to the database.
+        # Note that passwd may be empty on the first attempt. If connection
+        # fails because of a missing or incorrect password, but we're allowed to
+        # prompt for a password (no -w flag), prompt for a passwd and try again.
+        try:
+            try:
+                pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
+            except (OperationalError, InterfaceError) as e:
+                if should_ask_for_password(e):
+                    passwd = click.prompt(
+                        "Password for %s" % user,
+                        hide_input=True,
+                        show_default=False,
+                        type=str,
+                    )
+                    pgexecute = PGExecute(
+                        database, user, passwd, host, port, dsn, **kwargs
+                    )
+                else:
+                    raise e
+            # Remember a working password for next time.
+            if passwd and keyring:
+                try:
+                    keyring.set_password("pgcli", key, passwd)
+                except (RuntimeError, keyring.errors.KeyringError) as e:
+                    click.secho(
+                        keyring_error_message.format(
+                            "Set password in keyring returned:", str(e)
+                        ),
+                        err=True,
+                        fg="red",
+                    )
+
+        except Exception as e:  # Connecting to a database could fail.
+            self.logger.debug("Database connection failed: %r.", e)
+            self.logger.error("traceback: %r", traceback.format_exc())
+            click.secho(str(e), err=True, fg="red")
+            exit(1)
+
+        self.pgexecute = pgexecute
+
def handle_editor_command(self, text):
    r"""
    Editor command is any query that is prefixed or suffixed
    by a '\e'. The reason for a while loop is because a user
    might edit a query multiple times.
    For eg:
    "select * from \e"<enter> to edit it in vim, then come
    back to the prompt with the edited query "select * from
    blah where q = 'abc'\e" to edit it again.
    :param text: Document
    :return: Document
    """
    editor_command = special.editor_command(text)
    while editor_command:
        if editor_command == "\\e":
            # Plain \e: edit the query embedded in the input (or, if
            # there is none, the most recently executed query).
            filename = special.get_filename(text)
            query = special.get_editor_query(text) or self.get_last_query()
        else:  # \ev or \ef
            # Edit a view (\ev) or function (\ef) definition fetched
            # from the server; no backing file in this case.
            filename = None
            spec = text.split()[1]
            if editor_command == "\\ev":
                query = self.pgexecute.view_definition(spec)
            elif editor_command == "\\ef":
                query = self.pgexecute.function_definition(spec)
        sql, message = special.open_external_editor(filename, sql=query)
        if message:
            # Something went wrong. Raise an exception and bail.
            raise RuntimeError(message)
        # Re-prompt with the edited text pre-filled; Ctrl-C clears it.
        while True:
            try:
                text = self.prompt_app.prompt(default=sql)
                break
            except KeyboardInterrupt:
                sql = ""

        editor_command = special.editor_command(text)
    return text
+
def execute_command(self, text):
    """Execute *text* against the database, display the results, and
    refresh completions as needed.

    :param text: the SQL (or backslash command) entered by the user
    :return: a MetaQuery describing the executed query
    """
    logger = self.logger

    # Default metaquery in case execution fails before _evaluate_command
    # can produce a real one.
    query = MetaQuery(query=text, successful=False)

    try:
        if self.destructive_warning:
            destroy = confirm = confirm_destructive_query(text)
            if destroy is False:
                click.secho("Wise choice!")
                raise KeyboardInterrupt
            elif destroy:
                click.secho("Your call!")
        output, query = self._evaluate_command(text)
    except KeyboardInterrupt:
        # Restart connection to the database
        self.pgexecute.connect()
        logger.debug("cancelled query, sql: %r", text)
        click.secho("cancelled query", err=True, fg="red")
    except NotImplementedError:
        click.secho("Not Yet Implemented.", fg="yellow")
    except OperationalError as e:
        logger.error("sql: %r, error: %r", text, e)
        logger.error("traceback: %r", traceback.format_exc())
        self._handle_server_closed_connection(text)
    except (PgCliQuitError, EOFError) as e:
        # Quit requests propagate up to run_cli's loop.
        raise
    except Exception as e:
        logger.error("sql: %r, error: %r", text, e)
        logger.error("traceback: %r", traceback.format_exc())
        click.secho(str(e), err=True, fg="red")
    else:
        try:
            # With \o set, append query + results to the output file
            # instead of paging them to the terminal.
            if self.output_file and not text.startswith(("\\o ", "\\? ")):
                try:
                    with open(self.output_file, "a", encoding="utf-8") as f:
                        click.echo(text, file=f)
                        click.echo("\n".join(output), file=f)
                        click.echo("", file=f)  # extra newline
                except IOError as e:
                    click.secho(str(e), err=True, fg="red")
            else:
                if output:
                    self.echo_via_pager("\n".join(output))
        except KeyboardInterrupt:
            pass

        if self.pgspecial.timing_enabled:
            # Only add humanized time display if > 1 second
            if query.total_time > 1:
                print(
                    "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
                    % (
                        query.total_time,
                        pendulum.Duration(seconds=query.total_time).in_words(),
                        query.execution_time,
                        pendulum.Duration(seconds=query.execution_time).in_words(),
                    )
                )
            else:
                print("Time: %0.03fs" % query.total_time)

        # Check if we need to update completions, in order of most
        # to least drastic changes
        if query.db_changed:
            with self._completer_lock:
                self.completer.reset_completions()
            self.refresh_completions(persist_priorities="keywords")
        elif query.meta_changed:
            self.refresh_completions(persist_priorities="all")
        elif query.path_changed:
            logger.debug("Refreshing search path")
            with self._completer_lock:
                self.completer.set_search_path(self.pgexecute.search_path())
            logger.debug("Search path: %r", self.completer.search_path)
    return query
+
def run_cli(self):
    """Run the interactive REPL: prompt for input, dispatch each entry
    to execute_command, and loop until the user quits (\\q or EOF)."""
    logger = self.logger

    history_file = self.config["main"]["history_file"]
    if history_file == "default":
        history_file = config_location() + "history"
    history = FileHistory(os.path.expanduser(history_file))
    # Seed the completer (in the background) before building the prompt.
    self.refresh_completions(history=history, persist_priorities="none")

    self.prompt_app = self._build_cli(history)

    if not self.less_chatty:
        print("Server: PostgreSQL", self.pgexecute.server_version)
        print("Version:", __version__)
        print("Chat: https://gitter.im/dbcli/pgcli")
        print("Home: http://pgcli.com")

    try:
        while True:
            try:
                text = self.prompt_app.prompt()
            except KeyboardInterrupt:
                # Ctrl-C at an empty prompt just re-prompts.
                continue

            try:
                text = self.handle_editor_command(text)
            except RuntimeError as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                click.secho(str(e), err=True, fg="red")
                continue

            # Initialize default metaquery in case execution fails
            self.watch_command, timing = special.get_watch_command(text)
            if self.watch_command:
                # \watch: repeat the query every *timing* seconds until
                # interrupted with Ctrl-C.
                while self.watch_command:
                    try:
                        query = self.execute_command(self.watch_command)
                        click.echo(
                            "Waiting for {0} seconds before repeating".format(
                                timing
                            )
                        )
                        sleep(timing)
                    except KeyboardInterrupt:
                        self.watch_command = None
            else:
                query = self.execute_command(text)

            self.now = dt.datetime.today()

            # Allow PGCompleter to learn user's preferred keywords, etc.
            with self._completer_lock:
                self.completer.extend_query_history(text)

            self.query_history.append(query)

    except (PgCliQuitError, EOFError):
        if not self.less_chatty:
            print("Goodbye!")
+
def _build_cli(self, history):
    """Construct and return the prompt_toolkit PromptSession that the
    REPL reads input from.

    :param history: FileHistory backing the up-arrow/search history
    """
    key_bindings = pgcli_bindings(self)

    def get_message():
        # Use the DSN-specific prompt format when connected via an alias.
        if self.dsn_alias and self.prompt_dsn_format is not None:
            prompt_format = self.prompt_dsn_format
        else:
            prompt_format = self.prompt_format

        prompt = self.get_prompt(prompt_format)

        # Fall back to a short "\d> " prompt if the default one exceeds
        # the configured maximum length.
        if (
            prompt_format == self.default_prompt
            and len(prompt) > self.max_len_prompt
        ):
            prompt = self.get_prompt("\\d> ")

        # Allow raw ANSI escapes in the prompt template.
        prompt = prompt.replace("\\x1b", "\x1b")
        return ANSI(prompt)

    def get_continuation(width, line_number, is_soft_wrap):
        # Pad continuation lines so they align under the first prompt.
        continuation = self.multiline_continuation_char * (width - 1) + " "
        return [("class:continuation", continuation)]

    get_toolbar_tokens = create_toolbar_tokens_func(self)

    if self.wider_completion_menu:
        complete_style = CompleteStyle.MULTI_COLUMN
    else:
        complete_style = CompleteStyle.COLUMN

    with self._completer_lock:
        prompt_app = PromptSession(
            lexer=PygmentsLexer(PostgresLexer),
            reserve_space_for_menu=self.min_num_menu_lines,
            message=get_message,
            prompt_continuation=get_continuation,
            bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
            complete_style=complete_style,
            input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                ),
                # Render \t as 4 spaces instead of "^I"
                TabsProcessor(char1=" ", char2=" "),
            ],
            auto_suggest=AutoSuggestFromHistory(),
            tempfile_suffix=".sql",
            # N.b. pgcli's multi-line mode controls submit-on-Enter (which
            # overrides the default behaviour of prompt_toolkit) and is
            # distinct from prompt_toolkit's multiline mode here, which
            # controls layout/display of the prompt/buffer
            multiline=True,
            history=history,
            completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
            complete_while_typing=True,
            style=style_factory(self.syntax_style, self.cli_style),
            include_default_pygments_style=False,
            key_bindings=key_bindings,
            enable_open_in_editor=True,
            enable_system_prompt=True,
            enable_suspend=True,
            editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
            search_ignore_case=True,
        )

    return prompt_app
+
def _should_limit_output(self, sql, cur):
    """returns True if the output should be truncated, False otherwise."""
    if not is_select(sql):
        return False

    # Truncate only when the query carries no LIMIT of its own, limiting
    # is enabled (row_limit != 0), and the cursor exceeds the threshold.
    return (
        not self._has_limit(sql)
        and self.row_limit != 0
        and cur
        and cur.rowcount > self.row_limit
    )
+
+ def _has_limit(self, sql):
+ if not sql:
+ return False
+ return "limit " in sql.lower()
+
def _limit_output(self, cur):
    """Truncate *cur* to at most self.row_limit rows.

    Returns (iterator over the limited rows, replacement status string).
    """
    limit = min(self.row_limit, cur.rowcount)
    new_cur = itertools.islice(cur, limit)
    # Synthesize a status line reflecting the truncated row count.
    new_status = "SELECT " + str(limit)
    click.secho("The result was limited to %s rows" % limit, fg="red")

    return new_cur, new_status
+
def _evaluate_command(self, text):
    """Used to run a command entered by the user during CLI operation
    (Puts the E in REPL)

    returns (results, MetaQuery)
    """
    logger = self.logger
    logger.debug("sql: %r", text)

    # Per-execution flags accumulated over all statements in *text*.
    all_success = True
    meta_changed = False  # CREATE, ALTER, DROP, etc
    mutated = False  # INSERT, DELETE, etc
    db_changed = False
    path_changed = False
    output = []
    total = 0
    execution = 0

    # Run the query.
    start = time()
    on_error_resume = self.on_error == "RESUME"
    res = self.pgexecute.run(
        text, self.pgspecial, exception_formatter, on_error_resume
    )

    is_special = None

    # pgexecute.run yields one result tuple per statement in *text*.
    for title, cur, headers, status, sql, success, is_special in res:
        logger.debug("headers: %r", headers)
        logger.debug("rows: %r", cur)
        logger.debug("status: %r", status)

        if self._should_limit_output(sql, cur):
            cur, status = self._limit_output(cur)

        if self.pgspecial.auto_expand or self.auto_expand:
            max_width = self.prompt_app.output.get_size().columns
        else:
            max_width = None

        expanded = self.pgspecial.expanded_output or self.expanded_output
        settings = OutputSettings(
            table_format=self.table_format,
            dcmlfmt=self.decimal_format,
            floatfmt=self.float_format,
            missingval=self.null_string,
            expanded=expanded,
            max_width=max_width,
            case_function=(
                self.completer.case
                if self.settings["case_column_headers"]
                else lambda x: x
            ),
            style_output=self.style_output,
        )
        # execution: time until the last result arrived;
        # total: time including formatting of all results so far.
        execution = time() - start
        formatted = format_output(title, cur, headers, status, settings)

        output.extend(formatted)
        total = time() - start

        # Keep track of whether any of the queries are mutating or changing
        # the database
        if success:
            mutated = mutated or is_mutating(status)
            db_changed = db_changed or has_change_db_cmd(sql)
            meta_changed = meta_changed or has_meta_cmd(sql)
            path_changed = path_changed or has_change_path_cmd(sql)
        else:
            all_success = False

    meta_query = MetaQuery(
        text,
        all_success,
        total,
        execution,
        meta_changed,
        db_changed,
        path_changed,
        mutated,
        is_special,
    )

    return output, meta_query
+
def _handle_server_closed_connection(self, text):
    """Used during CLI execution."""
    # The server dropped the connection mid-query: reconnect and retry
    # the same query once. A failed reconnect is reported, not raised.
    try:
        click.secho("Reconnecting...", fg="green")
        self.pgexecute.connect()
        click.secho("Reconnected!", fg="green")
        self.execute_command(text)
    except OperationalError as e:
        click.secho("Reconnect Failed", fg="red")
        click.secho(str(e), err=True, fg="red")
+
def refresh_completions(self, history=None, persist_priorities="all"):
    """Refresh outdated completions

    :param history: A prompt_toolkit.history.FileHistory object. Used to
        load keyword and identifier preferences

    :param persist_priorities: 'all' or 'keywords'
    """

    # The refresher runs on a background thread and invokes this callback
    # with the freshly built completer when done.
    callback = functools.partial(
        self._on_completions_refreshed, persist_priorities=persist_priorities
    )
    self.completion_refresher.refresh(
        self.pgexecute,
        self.pgspecial,
        callback,
        history=history,
        settings=self.settings,
    )
    # A single result tuple so \refresh can print a status message.
    return [
        (None, None, None, "Auto-completion refresh started in the background.")
    ]
+
def _on_completions_refreshed(self, new_completer, persist_priorities):
    """Callback invoked (on the refresher thread) when a background
    completion refresh has finished building *new_completer*."""
    self._swap_completer_objects(new_completer, persist_priorities)

    if self.prompt_app:
        # After refreshing, redraw the CLI to clear the statusbar
        # "Refreshing completions..." indicator
        self.prompt_app.app.invalidate()
+
+ def _swap_completer_objects(self, new_completer, persist_priorities):
+ """Swap the completer object with the newly created completer.
+
+ persist_priorities is a string specifying how the old completer's
+ learned prioritizer should be transferred to the new completer.
+
+ 'none' - The new prioritizer is left in a new/clean state
+
+ 'all' - The new prioritizer is updated to exactly reflect
+ the old one
+
+ 'keywords' - The new prioritizer is updated with old keyword
+ priorities, but not any other.
+
+ """
+ with self._completer_lock:
+ old_completer = self.completer
+ self.completer = new_completer
+
+ if persist_priorities == "all":
+ # Just swap over the entire prioritizer
+ new_completer.prioritizer = old_completer.prioritizer
+ elif persist_priorities == "keywords":
+ # Swap over the entire prioritizer, but clear name priorities,
+ # leaving learned keyword priorities alone
+ new_completer.prioritizer = old_completer.prioritizer
+ new_completer.prioritizer.clear_names()
+ elif persist_priorities == "none":
+ # Leave the new prioritizer as is
+ pass
+ self.completer = new_completer
+
def get_completions(self, text, cursor_positition):
    """Return completions for *text* at the given cursor position,
    holding the completer lock for thread safety.

    NOTE(review): "cursor_positition" is a typo, but it is part of the
    public signature; renaming would break keyword-argument callers.
    """
    with self._completer_lock:
        return self.completer.get_completions(
            Document(text=text, cursor_position=cursor_positition), None
        )
+
def get_prompt(self, string):
    """Expand psql-style escapes in the prompt template *string*.

    Supported escapes: \\dsn_alias, \\t (time), \\u (user), \\H (host),
    \\h (short host), \\d (dbname), \\p (port), \\i (backend pid),
    \\# (# for superuser, > otherwise), \\n (newline).
    """
    # should be before replacing \\d
    string = string.replace("\\dsn_alias", self.dsn_alias or "")
    string = string.replace("\\t", self.now.strftime("%x %X"))
    string = string.replace("\\u", self.pgexecute.user or "(none)")
    string = string.replace("\\H", self.pgexecute.host or "(none)")
    string = string.replace("\\h", self.pgexecute.short_host or "(none)")
    string = string.replace("\\d", self.pgexecute.dbname or "(none)")
    string = string.replace(
        "\\p",
        str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
    )
    # Bug fix: the original used `str(self.pgexecute.pid) or "(none)"`,
    # but str() of anything (including None) is truthy, so "(none)" was
    # unreachable and a missing pid rendered as the literal "None".
    string = string.replace(
        "\\i",
        str(self.pgexecute.pid) if self.pgexecute.pid is not None else "(none)",
    )
    string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
    string = string.replace("\\n", "\n")
    return string
+
def get_last_query(self):
    """Get the last query executed or None."""
    if not self.query_history:
        return None
    return self.query_history[-1][0]
+
def is_too_wide(self, line):
    """Will this line be too wide to fit into terminal?"""
    if not self.prompt_app:
        return False
    # Measure the visible width: strip ANSI color codes first.
    visible = COLOR_CODE_REGEX.sub("", line)
    terminal_width = self.prompt_app.output.get_size().columns
    return len(visible) > terminal_width
+
def is_too_tall(self, lines):
    """Are there too many lines to fit into terminal?"""
    if not self.prompt_app:
        return False
    # Reserve 4 rows for the completion menu and toolbar.
    available_rows = self.prompt_app.output.get_size().rows - 4
    return len(lines) >= available_rows
+
def echo_via_pager(self, text, color=None):
    """Print *text*, routing it through a pager according to the
    pager configuration (\\pset pager off/on/always)."""
    if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
        # Paging disabled, or \watch mode (which repaints continuously
        # and cannot share the terminal with a pager).
        click.echo(text, color=color)
    elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
        # pspg can render csv directly; always hand it the output.
        click.echo_via_pager(text, color)
    elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
        lines = text.split("\n")

        # The last 4 lines are reserved for the pgcli menu and padding
        if self.is_too_tall(lines) or any(self.is_too_wide(l) for l in lines):
            click.echo_via_pager(text, color=color)
        else:
            click.echo(text, color=color)
    else:
        click.echo_via_pager(text, color)
+
+
# Command-line entry point: parse options, construct a PGCli, connect to
# the requested database, then start the REPL (or handle one-shot flags
# such as --list / --list-dsn / --version). No docstring is used here on
# purpose: click would expose it as the --help text.
@click.command()
# Default host is '' so psycopg2 can default to either localhost or unix socket
@click.option(
    "-h",
    "--host",
    default="",
    envvar="PGHOST",
    help="Host address of the postgres database.",
)
@click.option(
    "-p",
    "--port",
    default=5432,
    help="Port number at which the " "postgres instance is listening.",
    envvar="PGPORT",
    type=click.INT,
)
@click.option(
    "-U",
    "--username",
    "username_opt",
    help="Username to connect to the postgres database.",
)
@click.option(
    "-u", "--user", "username_opt", help="Username to connect to the postgres database."
)
@click.option(
    "-W",
    "--password",
    "prompt_passwd",
    is_flag=True,
    default=False,
    help="Force password prompt.",
)
@click.option(
    "-w",
    "--no-password",
    "never_prompt",
    is_flag=True,
    default=False,
    help="Never prompt for password.",
)
@click.option(
    "--single-connection",
    "single_connection",
    is_flag=True,
    default=False,
    help="Do not use a separate connection for completions.",
)
@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
@click.option(
    "--pgclirc",
    default=config_location() + "config",
    envvar="PGCLIRC",
    help="Location of pgclirc file.",
    type=click.Path(dir_okay=False),
)
@click.option(
    "-D",
    "--dsn",
    default="",
    envvar="DSN",
    help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--list-dsn",
    "list_dsn",
    is_flag=True,
    help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--row-limit",
    default=None,
    envvar="PGROWLIMIT",
    type=click.INT,
    help="Set threshold for row limit prompt. Use 0 to disable prompt.",
)
@click.option(
    "--less-chatty",
    "less_chatty",
    is_flag=True,
    default=False,
    help="Skip intro on startup and goodbye on exit.",
)
@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
@click.option(
    "--prompt-dsn",
    help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
)
@click.option(
    "-l",
    "--list",
    "list_databases",
    is_flag=True,
    help="list " "available databases, then exit.",
)
@click.option(
    "--auto-vertical-output",
    is_flag=True,
    help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
    "--warn/--no-warn", default=None, help="Warn before running a destructive query."
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
    dbname,
    username_opt,
    host,
    port,
    prompt_passwd,
    never_prompt,
    single_connection,
    dbname_opt,
    username,
    version,
    pgclirc,
    dsn,
    row_limit,
    less_chatty,
    prompt,
    prompt_dsn,
    list_databases,
    auto_vertical_output,
    list_dsn,
    warn,
):
    if version:
        print("Version:", __version__)
        sys.exit(0)

    config_dir = os.path.dirname(config_location())
    if not os.path.exists(config_dir):
        os.makedirs(config_dir)

    # Migrate the config file from old location.
    config_full_path = config_location() + "config"
    if os.path.exists(os.path.expanduser("~/.pgclirc")):
        if not os.path.exists(config_full_path):
            shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
            print("Config file (~/.pgclirc) moved to new location", config_full_path)
        else:
            print("Config file is now located at", config_full_path)
            print(
                "Please move the existing config file ~/.pgclirc to",
                config_full_path,
            )
    if list_dsn:
        # One-shot mode: print configured DSN aliases and exit.
        try:
            cfg = load_config(pgclirc, config_full_path)
            for alias in cfg["alias_dsn"]:
                click.secho(alias + " : " + cfg["alias_dsn"][alias])
            sys.exit(0)
        except Exception as err:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)

    pgcli = PGCli(
        prompt_passwd,
        never_prompt,
        pgclirc_file=pgclirc,
        row_limit=row_limit,
        single_connection=single_connection,
        less_chatty=less_chatty,
        prompt=prompt,
        prompt_dsn=prompt_dsn,
        auto_vertical_output=auto_vertical_output,
        warn=warn,
    )

    # Choose which ever one has a valid value.
    if dbname_opt and dbname:
        # work as psql: when database is given as option and argument use the argument as user
        username = dbname
    database = dbname_opt or dbname or ""
    user = username_opt or username
    service = None
    # "service=<name>" (or PGSERVICE) selects a pg_service.conf entry.
    if database.startswith("service="):
        service = database[8:]
    elif os.getenv("PGSERVICE") is not None:
        service = os.getenv("PGSERVICE")
    # because option --list or -l are not supposed to have a db name
    if list_databases:
        database = "postgres"

    if dsn != "":
        # Resolve the DSN alias from the [alias_dsn] config section.
        try:
            cfg = load_config(pgclirc, config_full_path)
            dsn_config = cfg["alias_dsn"][dsn]
        except KeyError:
            click.secho(
                f"Could not find a DSN with alias {dsn}. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            exit(1)
        pgcli.connect_uri(dsn_config)
        pgcli.dsn_alias = dsn
    elif "://" in database:
        pgcli.connect_uri(database)
    elif "=" in database and service is None:
        pgcli.connect_dsn(database, user=user)
    elif service is not None:
        pgcli.connect_service(service, user)
    else:
        pgcli.connect(database, host, user, port)

    if list_databases:
        # One-shot mode: print the database list and exit.
        cur, headers, status = pgcli.pgexecute.full_databases()

        title = "List of databases"
        settings = OutputSettings(table_format="ascii", missingval="<null>")
        formatted = format_output(title, cur, headers, status, settings)
        pgcli.echo_via_pager("\n".join(formatted))

        sys.exit(0)

    pgcli.logger.debug(
        "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
        database,
        user,
        host,
        port,
    )

    if setproctitle:
        obfuscate_process_password()

    pgcli.run_cli()
+
+
def obfuscate_process_password():
    """Replace any password embedded in the process title (as shown by
    ps/top) with 'xxxx' so other users on the host cannot read it."""
    process_title = setproctitle.getproctitle()
    if "://" in process_title:
        # URI-style connection string: scheme://user:password@host/...
        process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title)
    elif "=" in process_title:
        # key=value DSN style: "... password=secret host=..." or
        # "password=secret" at end of string.
        process_title = re.sub(
            r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title
        )

    setproctitle.setproctitle(process_title)
+
+
def has_meta_cmd(query):
    """Determines if the completion needs a refresh by checking if the sql
    statement is an alter, create, drop, commit or rollback."""
    try:
        # Exceptions (None query, empty string) mean "no refresh needed".
        first = query.split()[0].lower()
        return first in ("alter", "create", "drop", "commit", "rollback")
    except Exception:
        return False
+
+
def has_change_db_cmd(query):
    """Determines if the statement is a database switch such as 'use' or '\\c'"""
    try:
        # Exceptions (None query, empty string) mean "no switch".
        first = query.split()[0].lower()
        return first in ("use", "\\c", "\\connect")
    except Exception:
        return False
+
+
def has_change_path_cmd(sql):
    """Determines if the search_path should be refreshed by checking if the
    sql has 'set search_path'."""
    lowered = sql.lower()
    return "set search_path" in lowered
+
+
def is_mutating(status):
    """Determines if the statement is mutating based on the status."""
    if not status:
        return False

    # Status strings look like "INSERT 0 1" / "UPDATE 3" / "DELETE 2".
    first_word = status.split(None, 1)[0].lower()
    return first_word in {"insert", "update", "delete"}
+
+
def is_select(status):
    """Returns true if the first word in status is 'select'."""
    if not status:
        return False
    first_word = status.split(None, 1)[0]
    return first_word.lower() == "select"
+
+
def exception_formatter(e):
    """Render an exception as a red-colored string for terminal display."""
    return click.style(str(e), fg="red")
+
+
def format_output(title, cur, headers, status, settings):
    """Format one query result for display.

    :param title: optional title line printed before the table
    :param cur: cursor or iterable of rows (may be falsy for no result set)
    :param headers: sequence of column names
    :param status: server status string (e.g. "SELECT 5")
    :param settings: OutputSettings namedtuple controlling formatting
    :return: iterable of output lines
    """
    output = []
    expanded = settings.expanded or settings.table_format == "vertical"
    table_format = "vertical" if settings.expanded else settings.table_format
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        # Render a (possibly nested) postgres array as {a,b,{c,d}}.
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return "{" + ",".join(str(format_array(e)) for e in val) + "}"

    def format_arrays(data, headers, **_):
        # Preprocessor: stringify array-valued cells in every row.
        data = list(data)
        for row in data:
            row[:] = [
                format_array(val) if isinstance(val, list) else val for val in row
            ]

        return data, headers

    output_kwargs = {
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": settings.missingval,
        "integer_format": settings.dcmlfmt,
        "float_format": settings.floatfmt,
        "preprocessors": (format_numbers, format_arrays),
        "disable_numparse": True,
        "preserve_whitespace": True,
        "style": settings.style_output,
    }
    if not settings.floatfmt:
        output_kwargs["preprocessors"] = (align_decimals,)

    if table_format == "csv":
        # The default CSV dialect is "excel" which is not handling newline values correctly
        # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
        # as the line terminator
        # https://github.com/dbcli/pgcli/issues/1102
        dialect = "excel" if platform.system() == "Windows" else "unix"
        output_kwargs["dialect"] = dialect

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(x) for x in headers]
        if max_width is not None:
            cur = list(cur)
        # Map each column's type OID to a Python type.
        # NOTE(review): column_types is computed but not handed to the
        # formatter below — possibly vestigial; left as-is.
        column_types = None
        if hasattr(cur, "description"):
            column_types = []
            for d in cur.description:
                # Bug fix: the original used two independent `if`s here, so
                # a decimal/float column also fell into the final `else` and
                # was tagged str a second time; it also compared
                # d[1] == INTEGER.values (an OID vs. a sequence — never
                # true). Use elif and membership tests instead.
                if (
                    d[1] in psycopg2.extensions.DECIMAL.values
                    or d[1] in psycopg2.extensions.FLOAT.values
                ):
                    column_types.append(float)
                elif (
                    d[1] in psycopg2.extensions.INTEGER.values
                    or d[1] in psycopg2.extensions.LONGINTEGER.values
                ):
                    column_types.append(int)
                else:
                    column_types.append(str)

        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, str):
            formatted = iter(formatted.splitlines())
        # Peek at the first line to decide whether the table is too wide,
        # then stitch it back onto the iterator.
        first_line = next(formatted)
        formatted = itertools.chain([first_line], formatted)
        if not expanded and max_width and len(first_line) > max_width and headers:
            # Too wide for the terminal: re-render vertically.
            formatted = formatter.format_output(
                cur, headers, format_name="vertical", column_types=None, **output_kwargs
            )
            if isinstance(formatted, str):
                formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    # Only print the status if it's not None and we are not producing CSV
    if status and table_format != "csv":
        output = itertools.chain(output, [status])

    return output
+
+
def parse_service_info(service):
    """Resolve a libpq connection-service name.

    :param service: service name; falls back to the PGSERVICE env var
    :return: (service_config, service_file) where service_config is the
        matching section of the service file (or None if no service was
        requested / found) and service_file is the path consulted
    """
    service = service or os.getenv("PGSERVICE")
    service_file = os.getenv("PGSERVICEFILE")
    if not service_file:
        sysconf_dir = os.getenv("PGSYSCONFDIR")
        if platform.system() == "Windows" and sysconf_dir:
            # Bug fix: the original concatenated os.getenv("PGSYSCONFDIR")
            # unconditionally on Windows, raising TypeError when the
            # variable is unset; fall through to ~/.pg_service.conf instead.
            service_file = sysconf_dir + "\\pg_service.conf"
        elif sysconf_dir:
            service_file = os.path.join(sysconf_dir, ".pg_service.conf")
        else:
            # try ~/.pg_service.conf (if that exists)
            service_file = expanduser("~/.pg_service.conf")
    if not service:
        # nothing to do
        return None, service_file
    service_file_config = ConfigObj(service_file)
    if service not in service_file_config:
        return None, service_file
    service_conf = service_file_config.get(service)
    return service_conf, service_file
+
+
# Allow running this module directly (e.g. `python -m pgcli.main`).
if __name__ == "__main__":
    cli()
diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/__init__.py
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 0000000..a11e7bf
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
@@ -0,0 +1,22 @@
+import sqlparse
+
+
def query_starts_with(query, prefixes):
    """Check if the query starts with any item from *prefixes*."""
    lowered_prefixes = [prefix.lower() for prefix in prefixes]
    # Strip comments so "/* ... */ DROP ..." is still recognized.
    stripped = sqlparse.format(query.lower(), strip_comments=True).strip()
    if not stripped:
        return False
    return stripped.split()[0] in lowered_prefixes
+
+
def queries_start_with(queries, prefixes):
    """Check if any queries start with any item from *prefixes*."""
    return any(
        query and query_starts_with(query, prefixes) is True
        for query in sqlparse.split(queries)
    )
+
+
def is_destructive(queries):
    """Returns if any of the queries in *queries* is destructive."""
    destructive_keywords = ("drop", "shutdown", "delete", "truncate", "alter")
    return queries_start_with(queries, destructive_keywords)
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 0000000..e1f9088
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,141 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
# TableExpression is a namedtuple representing a CTE, used internally
# name: cte alias assigned in the query
# columns: list of column names
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
# (start/stop are character offsets, suitable for slicing the sql text)
TableExpression = namedtuple("TableExpression", "name columns start stop")
+
+
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects

    Returns (full_text, text_before_cursor, cte_metadata) where the texts
    are narrowed to the region currently being edited and cte_metadata is
    a sequence of TableMetadata for the CTEs visible at the cursor.
    """

    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, tuple()

    ctes, remainder = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start : current_position]
            full_text = full_text[cte.start : cte.stop]
            # Only CTEs defined *before* this one are in scope (meta so far).
            return full_text, text_before_cursor, meta

        # Append this cte to the list of available table metadata
        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    full_text = full_text[ctes[-1].stop :]
    text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]

    return full_text, text_before_cursor, tuple(meta)
+
+
def extract_ctes(sql):
    """Extract common table expressions from a query

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped.
    """

    p = parse(sql)[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next (meaningful) token, which should be the first CTE
    idx, tok = p.token_next(idx)
    if not tok:
        return ([], "")
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes
        for t in tok.get_identifiers():
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
            cte = get_cte_from_token(t, start_pos + cte_start_offset)
            if not cte:
                continue
            ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    idx = p.token_index(tok) + 1

    # Collapse everything after the ctes into a remainder query
    remainder = "".join(str(tok) for tok in p.tokens[idx:])

    return ctes, remainder
+
+
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from an sqlparse Identifier token.

    :param tok: the Identifier token for one CTE ("name AS (body)")
    :param pos0: character offset of *tok* in the original sql string
    :return: TableExpression, or None if the token has no name or no
        parenthesized body
    """
    cte_name = tok.get_real_name()
    if not cte_name:
        return None

    # Find the start position of the opening parens enclosing the cte body
    idx, parens = tok.token_next_by(Parenthesis)
    if not parens:
        return None

    start_pos = pos0 + token_start_pos(tok.tokens, idx)
    cte_len = len(str(parens))  # includes parens
    stop_pos = start_pos + cte_len

    column_names = extract_column_names(parens)

    return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
def extract_column_names(parsed):
    """Return the tuple of column names a CTE body exposes.

    For SELECT bodies these are the selected columns; for
    INSERT/UPDATE/DELETE they come from the RETURNING clause.
    """
    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause where the list of column names is
        # NOTE(review): sqlparse's token_next_by takes (i, m) positionally,
        # so idx is passed as the `i` (containing-instances) argument here
        # — confirm this is the intended call shape.
        idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
    elif not tok_val == "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
+
+
def token_start_pos(tokens, idx):
    """Return the character offset of tokens[idx] within the joined text."""
    preceding = tokens[:idx]
    return sum(len(str(tok)) for tok in preceding)
+
+
def _identifiers(tok):
    """Yield the Identifier tokens contained in *tok*.

    *tok* may be a single Identifier or an IdentifierList; anything else
    yields nothing.
    """
    if isinstance(tok, IdentifierList):
        for t in tok.get_identifiers():
            # NB: IdentifierList.get_identifiers() can return non-identifiers!
            if isinstance(t, Identifier):
                yield t
    elif isinstance(tok, Identifier):
        yield tok
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
new file mode 100644
index 0000000..108c01a
--- /dev/null
+++ b/pgcli/packages/parseutils/meta.py
@@ -0,0 +1,170 @@
+from collections import namedtuple
+
# Underlying record type; create instances through the ColumnMetadata
# factory below so `foreignkeys` always defaults to a fresh list.
_ColumnMetadata = namedtuple(
    "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)


def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
    """Build a ColumnMetadata record, substituting [] for a falsy ``foreignkeys``."""
    return _ColumnMetadata(
        name=name,
        datatype=datatype,
        foreignkeys=foreignkeys or [],
        default=default,
        has_default=has_default,
    )


# One foreign-key relationship between a parent column and a child column.
ForeignKey = namedtuple(
    "ForeignKey",
    "parentschema parenttable parentcolumn childschema childtable childcolumn",
)

# A table together with the list of ColumnMetadata records describing it.
TableMetadata = namedtuple("TableMetadata", "name columns")
+
+
def parse_defaults(defaults_string):
    """Yield individual default-value expressions for a function.

    *defaults_string* is the comma-separated list produced by
    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0).  Commas inside a
    single- or double-quoted section do not split expressions.
    """
    if not defaults_string:
        return
    expr = ""
    quote_char = None  # quote character we are currently inside, if any
    for ch in defaults_string:
        if expr == "" and ch == " ":
            # Skip the space that follows the comma separating expressions
            continue
        if ch in ("'", '"'):
            if quote_char == ch:
                # Closing quote
                quote_char = None
            elif quote_char is None:
                # Opening quote
                quote_char = ch
        elif ch == "," and quote_char is None:
            # End of one default expression
            yield expr
            expr = ""
            continue
        expr += ch
    yield expr
+
+
class FunctionMetadata(object):
    """Describes a postgresql function: schema, name, argument signature,
    defaults, return type, and flags (aggregate/window/set-returning/extension).
    """

    def __init__(
        self,
        schema_name,
        func_name,
        arg_names,
        arg_types,
        arg_modes,
        return_type,
        is_aggregate,
        is_window,
        is_set_returning,
        is_extension,
        arg_defaults,
    ):
        """Class for describing a postgresql function"""

        self.schema_name = schema_name
        self.func_name = func_name

        # Normalize to tuples (hashable); empty/None sequences become None
        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.)
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        # arg_defaults arrives as the pg_get_expr(...) string; split it here
        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning
        self.is_extension = bool(is_extension)
        self.is_public = self.schema_name and self.schema_name == "public"

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def _signature(self):
        # Tuple of all identifying attributes; shared by __hash__ and __repr__
        return (
            self.schema_name,
            self.func_name,
            self.arg_names,
            self.arg_types,
            self.arg_modes,
            self.return_type,
            self.is_aggregate,
            self.is_window,
            self.is_set_returning,
            self.is_extension,
            self.arg_defaults,
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        return (
            "%s(schema_name=%r, func_name=%r, arg_names=%r, "
            "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
            "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
        ) % ((self.__class__.__name__,) + self._signature())

    def has_variadic(self):
        # 'v' is the pg_proc.proargmodes code for VARIADIC
        return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples."""
        if not self.arg_names:
            return []
        # Without explicit modes, every argument is a plain IN parameter
        modes = self.arg_modes or ["i"] * len(self.arg_names)
        args = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ("i", "b", "v")  # IN, INOUT, VARIADIC
        ]

        def arg(name, typ, num):
            # Defaults apply to the trailing arguments, so argument `num`
            # has one iff it falls within the last `num_defaults` positions
            num_args = len(args)
            num_defaults = len(self.arg_defaults)
            has_default = num + num_defaults >= num_args
            default = (
                self.arg_defaults[num - num_args + num_defaults]
                if has_default
                else None
            )
            return ColumnMetadata(name, typ, [], default, has_default)

        return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples"""

        if self.return_type.lower() == "void":
            return []
        elif not self.arg_modes:
            # For functions without output parameters, the function name
            # is used as the name of the output column.
            # E.g. 'SELECT unnest FROM unnest(...);'
            return [ColumnMetadata(self.func_name, self.return_type, [])]

        return [
            ColumnMetadata(name, typ, [])
            for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
            if mode in ("o", "b", "t")
        ]  # OUT, INOUT, TABLE
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
new file mode 100644
index 0000000..0ec3e69
--- /dev/null
+++ b/pgcli/packages/parseutils/tables.py
@@ -0,0 +1,170 @@
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
# A reference to a table/view/function appearing in a query: optional schema
# qualifier, object name, optional alias, and whether it is a function call.
TableReference = namedtuple("TableReference", "schema name alias is_function")


def _table_ref(self):
    """Name used to refer to this table elsewhere in the query: the alias
    when present, otherwise the name -- double-quoted unless it is already
    quoted or all-lowercase (postgres folds unquoted names to lowercase)."""
    if self.alias:
        return self.alias
    if self.name.islower() or self.name.startswith('"'):
        return self.name
    return '"' + self.name + '"'


TableReference.ref = property(_table_ref)
+
+
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
    """Return True when *parsed* is a token group containing a DML keyword."""
    if not parsed.is_group:
        return False
    dml_values = ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE")
    return any(
        tok.ttype is DML and tok.value.upper() in dml_values for tok in parsed.tokens
    )
+
+
def _identifier_is_function(identifier):
    """Return True if the identifier wraps a Function token, e.g. `foo(x) f`."""
    return any(isinstance(t, Function) for t in identifier.tokens)
+
+
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens making up the FROM/INTO/UPDATE/JOIN portion(s) of
    *parsed*, recursing into sub-selects.

    :param parsed: a sqlparse statement
    :param stop_at_punctuation: once a table-prefix keyword has been seen,
        stop at the first punctuation token (needed for INSERT statements,
        whose parenthesized column list would otherwise be yielded)
    """
    # True once a keyword introducing table references has been encountered
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # `return`. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif (
                item.ttype is Keyword
                and (not item.value.upper() == "FROM")
                and (not item.value.upper().endswith("JOIN"))
            ):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (
                item_val
                in (
                    "COPY",
                    "FROM",
                    "INTO",
                    "UPDATE",
                    "TABLE",
                )
                or item_val.endswith("JOIN")
            ):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
                    tbl_prefix_seen = True
                    break
+
+
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples

    :param token_stream: iterable of sqlparse tokens (e.g. from
        extract_from_part)
    :param allow_functions: when False, never mark a reference as a function
        (used for INSERT statements -- see extract_tables)
    """

    # We need to do some massaging of the names because postgres is case-
    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
    def parse_identifier(item):
        # Returns (schema_name, name, alias), lowercasing unquoted names and
        # preserving the original case of quoted ones via an alias
        name = item.get_real_name()
        schema_name = item.get_parent_name()
        alias = item.get_alias()
        if not name:
            schema_name = None
            name = item.get_name()
            alias = alias or name
        schema_quoted = schema_name and item.value[0] == '"'
        if schema_name and not schema_quoted:
            schema_name = schema_name.lower()
        quote_count = item.value.count('"')
        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
        alias_quoted = alias and item.value[-1] == '"'
        if alias_quoted or name_quoted and not alias and name.islower():
            alias = '"' + (alias or name) + '"'
        if name and not name_quoted and not name.islower():
            if not alias:
                alias = name
            name = name.lower()
        return schema_name, name, alias

    try:
        for item in token_stream:
            if isinstance(item, IdentifierList):
                for identifier in item.get_identifiers():
                    # Sometimes Keywords (such as FROM ) are classified as
                    # identifiers which don't have the get_real_name() method.
                    try:
                        schema_name = identifier.get_parent_name()
                        real_name = identifier.get_real_name()
                        is_function = allow_functions and _identifier_is_function(
                            identifier
                        )
                    except AttributeError:
                        continue
                    if real_name:
                        yield TableReference(
                            schema_name, real_name, identifier.get_alias(), is_function
                        )
            elif isinstance(item, Identifier):
                schema_name, real_name, alias = parse_identifier(item)
                is_function = allow_functions and _identifier_is_function(item)

                yield TableReference(schema_name, real_name, alias, is_function)
            elif isinstance(item, Function):
                schema_name, real_name, alias = parse_identifier(item)
                yield TableReference(None, real_name, alias, allow_functions)
    except StopIteration:
        # NOTE(review): presumably guards against a StopIteration escaping
        # from token_stream mid-iteration; confirm whether this is still
        # reachable under PEP 479 semantics.
        return
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statment.
+
+ Returns a list of TableReference namedtuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return ()
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == "insert"
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+
+ # Kludge: sqlparse mistakenly identifies insert statements as
+ # function calls due to the parenthesized column list, e.g. interprets
+ # "insert into foo (bar, baz)" as a function call to foo with arguments
+ # (bar, baz). So don't allow any identifiers in insert statements
+ # to have is_function=True
+ identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
+ # In the case 'sche.<cursor>', we get an empty TableReference; remove that
+ return tuple(i for i in identifiers if i.name)
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 0000000..034c96e
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,140 @@
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
# Patterns anchored at end-of-string; each captures the trailing "word"
# under a progressively looser definition of what a word may contain.
cleanup_regex = {
    # Only alphanumerics and underscores.
    "alphanum_underscore": re.compile(r"(\w+)$"),
    # Everything except spaces, parens, colon, and comma.
    "many_punctuations": re.compile(r"([^():,\s]+)$"),
    # Everything except spaces, parens, colon, comma, and period.
    "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
    # Everything except whitespace.
    "all_punctuations": re.compile(r"([^\s]+)$"),
}


def last_word(text, include="alphanum_underscore"):
    r"""
    Find the last word in a sentence.

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \def', include='most_punctuations')
    '\\\\def'
    >>> last_word('bac \def;', include='most_punctuations')
    '\\\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    # Empty input or trailing whitespace means there is no last word
    if not text or text[-1].isspace():
        return ""
    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ""
+
+
def find_prev_keyword(sql, n_skip=0):
    """Find the last sql keyword in an SQL statement

    Returns the value of the last keyword, and the text of the query with
    everything after the last keyword stripped

    :param sql: the SQL text to scan
    :param n_skip: number of trailing flattened tokens to ignore before
        searching backwards
    """
    if not sql.strip():
        return None, ""

    parsed = sqlparse.parse(sql)[0]
    flattened = list(parsed.flatten())
    flattened = flattened[: len(flattened) - n_skip]

    # Logical operators don't delimit a clause, so skip over them
    logical_operators = ("AND", "OR", "NOT", "BETWEEN")

    for t in reversed(flattened):
        if t.value == "(" or (
            t.is_keyword and (t.value.upper() not in logical_operators)
        ):
            # Find the location of token t in the original parsed statement
            # We can't use parsed.token_index(t) because t may be a child token
            # inside a TokenList, in which case token_index throws an error
            # Minimal example:
            # p = sqlparse.parse('select * from foo where bar')
            # t = list(p.flatten())[-3] # The "Where" token
            # p.token_index(t) # Throws ValueError: not in list
            idx = flattened.index(t)

            # Combine the string values of all tokens in the original list
            # up to and including the target keyword token t, to produce a
            # query string with everything after the keyword token removed
            text = "".join(tok.value for tok in flattened[: idx + 1])
            return t, text

    return None, ""
+
+
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")


def is_open_quote(sql):
    """Returns true if the query contains an unclosed quote

    :param sql: SQL text, possibly several semicolon-separated statements
    """

    # parsed can contain one or more semi-colon separated commands
    parsed = sqlparse.parse(sql)
    return any(_parsed_is_open_quote(p) for p in parsed)


def _parsed_is_open_quote(parsed):
    # Look for unmatched single quotes, or unmatched dollar sign quotes.
    # sqlparse emits an Error token for the stray quote character itself.
    return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
+
+
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """

    p = sqlparse.parse(word)[0]
    n_tok = len(p.tokens)
    if n_tok == 1 and isinstance(p.tokens[0], Identifier):
        return p.tokens[0]
    elif p.token_next_by(m=(Error, '"'))[1]:
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse (recurses at most once,
        # since the appended quote balances the open one)
        return parse_partial_identifier(word + '"')
    else:
        return None
diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/pgliterals/__init__.py
diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py
new file mode 100644
index 0000000..5c39296
--- /dev/null
+++ b/pgcli/packages/pgliterals/main.py
@@ -0,0 +1,15 @@
+import os
+import json
+
# Directory containing this module; the literals file ships alongside it.
root = os.path.dirname(__file__)
literal_file = os.path.join(root, "pgliterals.json")

# Loaded once at import time: maps a literal category name to its
# collection of literal strings.
with open(literal_file) as f:
    literals = json.load(f)


def get_literals(literal_type, type_=tuple):
    # Where `literal_type` is one of 'keywords', 'functions', 'datatypes',
    # returns a tuple of literal values of that type.

    return type_(literals[literal_type])
diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json
new file mode 100644
index 0000000..c7b74b5
--- /dev/null
+++ b/pgcli/packages/pgliterals/pgliterals.json
@@ -0,0 +1,629 @@
+{
+ "keywords": {
+ "ACCESS": [],
+ "ADD": [],
+ "ALL": [],
+ "ALTER": [
+ "AGGREGATE",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DEFAULT",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "LARGE OBJECT",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "SYSTEM",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "VIEW"
+ ],
+ "AND": [],
+ "ANY": [],
+ "AS": [],
+ "ASC": [],
+ "AUDIT": [],
+ "BEGIN": [],
+ "BETWEEN": [],
+ "BY": [],
+ "CASE": [],
+ "CHAR": [],
+ "CHECK": [],
+ "CLUSTER": [],
+ "COLUMN": [],
+ "COMMENT": [],
+ "COMMIT": [],
+ "COMPRESS": [],
+ "CONCURRENTLY": [],
+ "CONNECT": [],
+ "COPY": [],
+ "CREATE": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN EXTENSION",
+ "FUNCTION",
+ "GLOBAL",
+ "GROUP",
+ "IF NOT EXISTS",
+ "INDEX",
+ "LANGUAGE",
+ "LOCAL",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OR REPLACE",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEMPORARY",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "UNIQUE",
+ "UNLOGGED",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "CURRENT": [],
+ "DATABASE": [],
+ "DATE": [],
+ "DECIMAL": [],
+ "DEFAULT": [],
+ "DELETE FROM": [],
+ "DELIMITER": [],
+ "DESC": [],
+ "DESCRIBE": [],
+ "DISTINCT": [],
+ "DROP": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN TABLE",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OWNED",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRANSFORM",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "EXPLAIN": [],
+ "ELSE": [],
+ "ENCODING": [],
+ "ESCAPE": [],
+ "EXCLUSIVE": [],
+ "EXISTS": [],
+ "EXTENSION": [],
+ "FILE": [],
+ "FLOAT": [],
+ "FOR": [],
+ "FORMAT": [],
+ "FORCE_QUOTE": [],
+ "FORCE_NOT_NULL": [],
+ "FREEZE": [],
+ "FROM": [],
+ "FULL": [],
+ "FUNCTION": [],
+ "GRANT": [],
+ "GROUP BY": [],
+ "HAVING": [],
+ "HEADER": [],
+ "IDENTIFIED": [],
+ "IMMEDIATE": [],
+ "IN": [],
+ "INCREMENT": [],
+ "INDEX": [],
+ "INITIAL": [],
+ "INSERT INTO": [],
+ "INTEGER": [],
+ "INTERSECT": [],
+ "INTERVAL": [],
+ "INTO": [],
+ "IS": [],
+ "JOIN": [],
+ "LANGUAGE": [],
+ "LEFT": [],
+ "LEVEL": [],
+ "LIKE": [],
+ "LIMIT": [],
+ "LOCK": [],
+ "LONG": [],
+ "MATERIALIZED VIEW": [],
+ "MAXEXTENTS": [],
+ "MINUS": [],
+ "MLSLABEL": [],
+ "MODE": [],
+ "MODIFY": [],
+ "NOT": [],
+ "NOAUDIT": [],
+ "NOTICE": [],
+ "NOCOMPRESS": [],
+ "NOWAIT": [],
+ "NULL": [],
+ "NUMBER": [],
+ "OIDS": [],
+ "OF": [],
+ "OFFLINE": [],
+ "ON": [],
+ "ONLINE": [],
+ "OPTION": [],
+ "OR": [],
+ "ORDER BY": [],
+ "OUTER": [],
+ "OWNER": [],
+ "PCTFREE": [],
+ "PRIMARY": [],
+ "PRIOR": [],
+ "PRIVILEGES": [],
+ "QUOTE": [],
+ "RAISE": [],
+ "RENAME": [],
+ "REPLACE": [],
+ "RESET": ["ALL"],
+ "RAW": [],
+ "REFRESH MATERIALIZED VIEW": [],
+ "RESOURCE": [],
+ "RETURNS": [],
+ "REVOKE": [],
+ "RIGHT": [],
+ "ROLLBACK": [],
+ "ROW": [],
+ "ROWID": [],
+ "ROWNUM": [],
+ "ROWS": [],
+ "SELECT": [],
+ "SESSION": [],
+ "SET": [],
+ "SHARE": [],
+ "SHOW": [],
+ "SIZE": [],
+ "SMALLINT": [],
+ "START": [],
+ "SUCCESSFUL": [],
+ "SYNONYM": [],
+ "SYSDATE": [],
+ "TABLE": [],
+ "TEMPLATE": [],
+ "THEN": [],
+ "TO": [],
+ "TRIGGER": [],
+ "TRUNCATE": [],
+ "UID": [],
+ "UNION": [],
+ "UNIQUE": [],
+ "UPDATE": [],
+ "USE": [],
+ "USER": [],
+ "USING": [],
+ "VALIDATE": [],
+ "VALUES": [],
+ "VARCHAR": [],
+ "VARCHAR2": [],
+ "VIEW": [],
+ "WHEN": [],
+ "WHENEVER": [],
+ "WHERE": [],
+ "WITH": []
+ },
+ "functions": [
+ "ABBREV",
+ "ABS",
+ "AGE",
+ "AREA",
+ "ARRAY_AGG",
+ "ARRAY_APPEND",
+ "ARRAY_CAT",
+ "ARRAY_DIMS",
+ "ARRAY_FILL",
+ "ARRAY_LENGTH",
+ "ARRAY_LOWER",
+ "ARRAY_NDIMS",
+ "ARRAY_POSITION",
+ "ARRAY_POSITIONS",
+ "ARRAY_PREPEND",
+ "ARRAY_REMOVE",
+ "ARRAY_REPLACE",
+ "ARRAY_TO_STRING",
+ "ARRAY_UPPER",
+ "ASCII",
+ "AVG",
+ "BIT_AND",
+ "BIT_LENGTH",
+ "BIT_OR",
+ "BOOL_AND",
+ "BOOL_OR",
+ "BOUND_BOX",
+ "BOX",
+ "BROADCAST",
+ "BTRIM",
+ "CARDINALITY",
+ "CBRT",
+ "CEIL",
+ "CEILING",
+ "CENTER",
+ "CHAR_LENGTH",
+ "CHR",
+ "CIRCLE",
+ "CLOCK_TIMESTAMP",
+ "CONCAT",
+ "CONCAT_WS",
+ "CONVERT",
+ "CONVERT_FROM",
+ "CONVERT_TO",
+ "COUNT",
+ "CUME_DIST",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATE_PART",
+ "DATE_TRUNC",
+ "DECODE",
+ "DEGREES",
+ "DENSE_RANK",
+ "DIAMETER",
+ "DIV",
+ "ENCODE",
+ "ENUM_FIRST",
+ "ENUM_LAST",
+ "ENUM_RANGE",
+ "EVERY",
+ "EXP",
+ "EXTRACT",
+ "FAMILY",
+ "FIRST_VALUE",
+ "FLOOR",
+ "FORMAT",
+ "GET_BIT",
+ "GET_BYTE",
+ "HEIGHT",
+ "HOST",
+ "HOSTMASK",
+ "INET_MERGE",
+ "INET_SAME_FAMILY",
+ "INITCAP",
+ "ISCLOSED",
+ "ISFINITE",
+ "ISOPEN",
+ "JUSTIFY_DAYS",
+ "JUSTIFY_HOURS",
+ "JUSTIFY_INTERVAL",
+ "LAG",
+ "LAST_VALUE",
+ "LEAD",
+ "LEFT",
+ "LENGTH",
+ "LINE",
+ "LN",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "LOG",
+ "LOG10",
+ "LOWER",
+ "LPAD",
+ "LSEG",
+ "LTRIM",
+ "MAKE_DATE",
+ "MAKE_INTERVAL",
+ "MAKE_TIME",
+ "MAKE_TIMESTAMP",
+ "MAKE_TIMESTAMPTZ",
+ "MASKLEN",
+ "MAX",
+ "MD5",
+ "MIN",
+ "MOD",
+ "NETMASK",
+ "NETWORK",
+ "NOW",
+ "NPOINTS",
+ "NTH_VALUE",
+ "NTILE",
+ "NUM_NONNULLS",
+ "NUM_NULLS",
+ "OCTET_LENGTH",
+ "OVERLAY",
+ "PARSE_IDENT",
+ "PATH",
+ "PCLOSE",
+ "PERCENT_RANK",
+ "PG_CLIENT_ENCODING",
+ "PI",
+ "POINT",
+ "POLYGON",
+ "POPEN",
+ "POSITION",
+ "POWER",
+ "QUOTE_IDENT",
+ "QUOTE_LITERAL",
+ "QUOTE_NULLABLE",
+ "RADIANS",
+ "RADIUS",
+ "RANK",
+ "REGEXP_MATCH",
+ "REGEXP_MATCHES",
+ "REGEXP_REPLACE",
+ "REGEXP_SPLIT_TO_ARRAY",
+ "REGEXP_SPLIT_TO_TABLE",
+ "REPEAT",
+ "REPLACE",
+ "REVERSE",
+ "RIGHT",
+ "ROUND",
+ "ROW_NUMBER",
+ "RPAD",
+ "RTRIM",
+ "SCALE",
+ "SET_BIT",
+ "SET_BYTE",
+ "SET_MASKLEN",
+ "SHA224",
+ "SHA256",
+ "SHA384",
+ "SHA512",
+ "SIGN",
+ "SPLIT_PART",
+ "SQRT",
+ "STARTS_WITH",
+ "STATEMENT_TIMESTAMP",
+ "STRING_TO_ARRAY",
+ "STRPOS",
+ "SUBSTR",
+ "SUBSTRING",
+ "SUM",
+ "TEXT",
+ "TIMEOFDAY",
+ "TO_ASCII",
+ "TO_CHAR",
+ "TO_DATE",
+ "TO_HEX",
+ "TO_NUMBER",
+ "TO_TIMESTAMP",
+ "TRANSACTION_TIMESTAMP",
+ "TRANSLATE",
+ "TRIM",
+ "TRUNC",
+ "UNNEST",
+ "UPPER",
+ "WIDTH",
+ "WIDTH_BUCKET",
+ "XMLAGG"
+ ],
+ "datatypes": [
+ "ANY",
+ "ANYARRAY",
+ "ANYELEMENT",
+ "ANYENUM",
+ "ANYNONARRAY",
+ "ANYRANGE",
+ "BIGINT",
+ "BIGSERIAL",
+ "BIT",
+ "BIT VARYING",
+ "BOOL",
+ "BOOLEAN",
+ "BOX",
+ "BYTEA",
+ "CHAR",
+ "CHARACTER",
+ "CHARACTER VARYING",
+ "CIDR",
+ "CIRCLE",
+ "CSTRING",
+ "DATE",
+ "DECIMAL",
+ "DOUBLE PRECISION",
+ "EVENT_TRIGGER",
+ "FDW_HANDLER",
+ "FLOAT4",
+ "FLOAT8",
+ "INET",
+ "INT",
+ "INT2",
+ "INT4",
+ "INT8",
+ "INTEGER",
+ "INTERNAL",
+ "INTERVAL",
+ "JSON",
+ "JSONB",
+ "LANGUAGE_HANDLER",
+ "LINE",
+ "LSEG",
+ "MACADDR",
+ "MACADDR8",
+ "MONEY",
+ "NUMERIC",
+ "OID",
+ "OPAQUE",
+ "PATH",
+ "PG_LSN",
+ "POINT",
+ "POLYGON",
+ "REAL",
+ "RECORD",
+ "REGCLASS",
+ "REGCONFIG",
+ "REGDICTIONARY",
+ "REGNAMESPACE",
+ "REGOPER",
+ "REGOPERATOR",
+ "REGPROC",
+ "REGPROCEDURE",
+ "REGROLE",
+ "REGTYPE",
+ "SERIAL",
+ "SERIAL2",
+ "SERIAL4",
+ "SERIAL8",
+ "SMALLINT",
+ "SMALLSERIAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TRIGGER",
+ "TSQUERY",
+ "TSVECTOR",
+ "TXID_SNAPSHOT",
+ "UUID",
+ "VARBIT",
+ "VARCHAR",
+ "VOID",
+ "XML"
+ ],
+ "reserved": [
+ "ALL",
+ "ANALYSE",
+ "ANALYZE",
+ "AND",
+ "ANY",
+ "ARRAY",
+ "AS",
+ "ASC",
+ "ASYMMETRIC",
+ "BOTH",
+ "CASE",
+ "CAST",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "CONSTRAINT",
+ "CREATE",
+ "CURRENT_CATALOG",
+ "CURRENT_DATE",
+ "CURRENT_ROLE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "CURRENT_USER",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DESC",
+ "DISTINCT",
+ "DO",
+ "ELSE",
+ "END",
+ "EXCEPT",
+ "FALSE",
+ "FETCH",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "GRANT",
+ "GROUP",
+ "HAVING",
+ "IN",
+ "INITIALLY",
+ "INTERSECT",
+ "INTO",
+ "LATERAL",
+ "LEADING",
+ "LIMIT",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "NOT",
+ "NULL",
+ "OFFSET",
+ "ON",
+ "ONLY",
+ "OR",
+ "ORDER",
+ "PLACING",
+ "PRIMARY",
+ "REFERENCES",
+ "RETURNING",
+ "SELECT",
+ "SESSION_USER",
+ "SOME",
+ "SYMMETRIC",
+ "TABLE",
+ "THEN",
+ "TO",
+ "TRAILING",
+ "TRUE",
+ "UNION",
+ "UNIQUE",
+ "USER",
+ "USING",
+ "VARIADIC",
+ "WHEN",
+ "WHERE",
+ "WINDOW",
+ "WITH",
+ "AUTHORIZATION",
+ "BINARY",
+ "COLLATION",
+ "CONCURRENTLY",
+ "CROSS",
+ "CURRENT_SCHEMA",
+ "FREEZE",
+ "FULL",
+ "ILIKE",
+ "INNER",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "LEFT",
+ "LIKE",
+ "NATURAL",
+ "NOTNULL",
+ "OUTER",
+ "OVERLAPS",
+ "RIGHT",
+ "SIMILAR",
+ "TABLESAMPLE",
+ "VERBOSE"
+ ]
+}
diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py
new file mode 100644
index 0000000..e92dcbb
--- /dev/null
+++ b/pgcli/packages/prioritization.py
@@ -0,0 +1,51 @@
+import re
+import sqlparse
+from sqlparse.tokens import Name
+from collections import defaultdict
+from .pgliterals.main import get_literals
+
+
+white_space_regex = re.compile("\\s+", re.MULTILINE)
+
+
+def _compile_regex(keyword):
+ # Surround the keyword with word boundaries and replace interior whitespace
+ # with whitespace wildcards
+ pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
+ return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
+
+
# All SQL keywords known to pgcli, each paired with a compiled
# flexible-whitespace regex used to count occurrences in query text.
keywords = get_literals("keywords")
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
+
+
class PrevalenceCounter(object):
    """Tracks how often SQL keywords and identifier names occur in the
    text the user has run, for prioritizing completion suggestions."""

    def __init__(self):
        self.keyword_counts = defaultdict(int)
        self.name_counts = defaultdict(int)

    def update(self, text):
        """Count both keywords and names occurring in *text*."""
        self.update_keywords(text)
        self.update_names(text)

    def update_names(self, text):
        """Tally every Name token sqlparse finds in *text*."""
        for statement in sqlparse.parse(text):
            for token in statement.flatten():
                if token.ttype in Name:
                    self.name_counts[token.value] += 1

    def clear_names(self):
        """Discard all accumulated name counts."""
        self.name_counts = defaultdict(int)

    def update_keywords(self, text):
        # Count keywords ourselves; sqlparse is database-agnostic and does
        # not know the full postgres keyword list.
        for keyword, regex in keyword_regexs.items():
            matches = regex.findall(text)
            if matches:
                self.keyword_counts[keyword] += len(matches)

    def keyword_count(self, keyword):
        return self.keyword_counts[keyword]

    def name_count(self, name):
        return self.name_counts[name]
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
new file mode 100644
index 0000000..3c58490
--- /dev/null
+++ b/pgcli/packages/prompt_utils.py
@@ -0,0 +1,35 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
def confirm_destructive_query(queries):
    """Check if the query is destructive and prompts the user to confirm.

    Returns:
    * None if the query is non-destructive or we can't prompt the user.
    * True if the query is destructive and the user wants to proceed.
    * False if the query is destructive and the user doesn't want to proceed.

    """
    prompt_text = (
        "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
    )
    # Only prompt when stdin is a tty; otherwise fall through, returning None
    if is_destructive(queries) and sys.stdin.isatty():
        return prompt(prompt_text, type=bool)
+
+
def confirm(*args, **kwargs):
    """Prompt for confirmation (yes/no) and handle any abort exceptions.

    Arguments are forwarded to click.confirm; an aborted prompt yields False.
    """
    try:
        return click.confirm(*args, **kwargs)
    except click.Abort:
        return False
+
+
def prompt(*args, **kwargs):
    """Prompt the user for input and handle any abort exceptions.

    Arguments are forwarded to click.prompt; an aborted prompt yields False.
    """
    try:
        return click.prompt(*args, **kwargs)
    except click.Abort:
        return False
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
new file mode 100644
index 0000000..6ef8859
--- /dev/null
+++ b/pgcli/packages/sqlcompletion.py
@@ -0,0 +1,608 @@
+import sys
+import re
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
+from pgspecial.main import parse_special_command
+
+
# Suggestion types returned by suggest_type()/suggest_special(). Each is a
# lightweight namedtuple naming *what kind* of completion fits at the cursor,
# plus any scoping information (schema, tables already in scope, ...).
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple("Join", ["table_refs", "schema"])

Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
# Make all scoping fields optional so bare Table()/View()/... can be built.
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())

Column = namedtuple(
    "Column",
    ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)

Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])

Path = namedtuple("Path", [])
+
+
class SqlStatement(object):
    """Represents the SQL statement under the cursor.

    Strips named-query prefixes and CTEs, isolates the statement containing
    the cursor from a multi-statement buffer, and exposes helpers used by
    the completion engine (tables in scope, identifier schema, etc.).
    """

    def __init__(self, full_text, text_before_cursor):
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include="many_punctuations"
        )
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
            full_text, text_before_cursor
        )

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        if self.word_before_cursor:
            if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
                # Open paren / backslash command: keep the partial word.
                parsed = sqlparse.parse(text_before_cursor)
            else:
                text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
                parsed = sqlparse.parse(text_before_cursor)
                self.identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = _split_multiple_statements(
            full_text, text_before_cursor, parsed
        )

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # and/or idiom: fall back to "" when there is no parsed statement
        # or no previous token.
        self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ""

    def is_insert(self):
        # True when the statement's first token is INSERT (case-insensitive).
        return self.parsed.token_first().value.lower() == "insert"

    def get_tables(self, scope="full"):
        """Gets the tables available in the statement.
        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == "full" else self.text_before_cursor
        )
        if scope == "insert":
            tables = tables[:1]
        elif self.is_insert():
            # Skip the INSERT target table itself.
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Return the token immediately preceding *token* in the parse tree."""
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Return the schema qualifying the partially-typed identifier, or None."""
        schema = (self.identifier and self.identifier.get_parent_name()) or None
        # If schema name is unquoted, lower-case it
        if schema and self.identifier.value[0] != '"':
            schema = schema.lower()

        return schema

    def reduce_to_prev_keyword(self, n_skip=0):
        """Rewind text_before_cursor to the previous keyword (skipping
        *n_skip* tokens) and return that keyword token."""
        prev_keyword, self.text_before_cursor = find_prev_keyword(
            self.text_before_cursor, n_skip=n_skip
        )
        return prev_keyword
+
+
def suggest_type(full_text, text_before_cursor):
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.
    """
    # "\i <file>" expects a filesystem path argument.
    if full_text.startswith("\\i "):
        return (Path(),)

    # This is a temporary hack; the exception handling
    # here should be removed once sqlparse has been fixed
    try:
        stmt = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    # Special (backslash) commands get their own completion logic. Be careful
    # here because trivial whitespace is parsed as a statement, but the
    # statement won't have a first token.
    if stmt.parsed:
        first_token = stmt.parsed.token_first()
        if first_token and first_token.value.startswith("\\"):
            return suggest_special(stmt.text_before_cursor + stmt.word_before_cursor)

    return suggest_based_on_last_token(stmt.last_token, stmt)
+
+
+named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
+
+
+def _strip_named_query(txt):
+ """
+ This will strip "save named query" command in the beginning of the line:
+ '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ """
+
+ if named_query_regex.match(txt):
+ txt = named_query_regex.sub("", txt)
+ return txt
+
+
+function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
+
+
+def _find_function_body(text):
+ split = function_body_pattern.search(text)
+ return (split.start(2), split.end(2)) if split else (None, None)
+
+
def _statement_from_function(full_text, text_before_cursor, statement):
    """If the cursor sits inside a dollar-quoted function body, re-parse just
    that body so completion reflects the SQL inside it; otherwise return the
    inputs unchanged."""
    cursor_pos = len(text_before_cursor)
    body_start, body_end = _find_function_body(full_text)
    # No body found, or the cursor lies outside of it: nothing to do.
    if body_start is None or not (body_start <= cursor_pos < body_end):
        return full_text, text_before_cursor, statement
    body = full_text[body_start:body_end]
    body_before_cursor = text_before_cursor[body_start:]
    return _split_multiple_statements(
        body, body_before_cursor, sqlparse.parse(body_before_cursor)
    )
+
+
def _split_multiple_statements(full_text, text_before_cursor, parsed):
    # Narrow a (possibly multi-statement) buffer down to the single statement
    # containing the cursor; returns (full_text, text_before_cursor,
    # statement) relative to that statement.
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        return full_text, text_before_cursor, None

    # When editing CREATE [OR REPLACE] FUNCTION, descend into the
    # dollar-quoted body (mutually recursive with _statement_from_function).
    token2 = None
    if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
        token1 = statement.token_first()
        if token1:
            token1_idx = statement.token_index(token1)
            token2 = statement.token_next(token1_idx)[1]
    if token2 and token2.value.upper() == "FUNCTION":
        full_text, text_before_cursor, statement = _statement_from_function(
            full_text, text_before_cursor, statement
        )
    return full_text, text_before_cursor, statement
+
+
# Maps the relation-type suffix of a backslash command (e.g. the "dt" of
# "\dt") to the suggestion type it should produce.
SPECIALS_SUGGESTION = {
    "dT": Datatype,
    "df": Function,
    "dt": Table,
    "dv": View,
    "sf": Function,
}
+
+
def suggest_special(text):
    """Return completion suggestions for a backslash (special) command."""
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return (Special(),)

    if cmd in ("\\c", "\\connect"):
        return (Database(),)

    if cmd == "\\T":
        return (TableFormat(),)

    if cmd == "\\dn":
        return (Schema(),)

    if arg:
        # Try to distinguish "\d name" from "\d schema.name"
        # Note that this will fail to obtain a schema name if wildcards are
        # used, e.g. "\d schema???.name"
        parsed = sqlparse.parse(arg)[0].tokens[0]
        try:
            schema = parsed.get_parent_name()
        except AttributeError:
            schema = None
    else:
        schema = None

    if cmd[1:] == "d":
        # \d can describe tables or views
        if schema:
            return (Table(schema=schema), View(schema=schema))
        else:
            return (Schema(), Table(schema=None), View(schema=None))
    elif cmd[1:] in SPECIALS_SUGGESTION:
        rel_type = SPECIALS_SUGGESTION[cmd[1:]]
        if schema:
            if rel_type == Function:
                # "special" usage filters how function candidates are rendered
                return (Function(schema=schema, usage="special"),)
            return (rel_type(schema=schema),)
        else:
            if rel_type == Function:
                return (Schema(), Function(schema=None, usage="special"))
            return (Schema(), rel_type(schema=None))

    if cmd in ["\\n", "\\ns", "\\nd"]:
        # Named-query management commands complete named queries.
        return (NamedQuery(),)

    return (Keyword(), Special())
+
+
def suggest_based_on_last_token(token, stmt):
    """Dispatch on the last significant token before the cursor and return a
    tuple of suggestion namedtuples (Column, Table, Keyword, ...). Recurses
    via stmt.reduce_to_prev_keyword() for tokens with no special handling."""

    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword = stmt.reduce_to_prev_keyword()
        return suggest_based_on_last_token(prev_keyword, stmt)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g.:
        # CREATE TABLE foo (Identifier <CURSOR>
        # CREATE FUNCTION foo (Identifier <CURSOR>
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        # SELECT Identifier <CURSOR>
        # SELECT foo FROM Identifier <CURSOR>
        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
        if prev_keyword and prev_keyword.value == "(":
            # Suggest datatypes
            return suggest_based_on_last_token("type", stmt)
        else:
            return (Keyword(),)
    else:
        token_v = token.value.lower()

    # NOTE(review): this empty-check runs *after* token.value was accessed in
    # the else-branch above, so it can only catch the empty-string case; a
    # None token would already have raised there. Kept as-is.
    if not token:
        return (Keyword(), Special())
    elif token_v.endswith("("):
        p = sqlparse.parse(stmt.text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            # 1 - Parenthesized clause like "WHERE foo AND ("
            # Suggest columns/functions
            # 2 - Function call like "WHERE foo("
            # Suggest columns/functions
            # 3 - Subquery expression like "WHERE EXISTS ("
            # Suggest keywords, in order to do a subquery
            # 4 - Subquery OR array comparison like "WHERE foo = ANY("
            # Suggest columns/functions AND keywords. (If we wanted to be
            # really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token("where", stmt)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)[1]

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == "exists":
                return (Keyword(),)
            else:
                return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)[1]

        if (
            prev_tok
            and prev_tok.value
            and prev_tok.value.lower().split(" ")[-1] == "using"
        ):
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = stmt.get_tables("before")

            # suggest columns that are present in more than one table
            return (
                Column(
                    table_refs=tables,
                    require_last_table=True,
                    local_tables=stmt.local_tables,
                ),
            )

        elif p.token_first().value.lower() == "select":
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
                return (Keyword(),)
        prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
        if prev_prev_tok and prev_prev_tok.normalized == "INTO":
            # "INSERT INTO foo (" -- suggest the target table's columns
            return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
        # We're probably in a function argument list
        return _suggest_expression(token_v, stmt)
    elif token_v == "set":
        return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
    elif token_v in ("select", "where", "having", "order by", "distinct"):
        return _suggest_expression(token_v, stmt)
    elif token_v == "as":
        # Don't suggest anything for aliases
        return ()
    elif (token_v.endswith("join") and token.is_keyword) or (
        token_v in ("copy", "from", "update", "into", "describe", "truncate")
    ):

        schema = stmt.get_identifier_schema()
        tables = extract_tables(stmt.text_before_cursor)
        is_join = token_v.endswith("join") and token.is_keyword

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = []

        if not schema:
            # Suggest schemas
            suggest.insert(0, Schema())

        if token_v == "from" or is_join:
            suggest.append(
                FromClauseItem(
                    schema=schema, table_refs=tables, local_tables=stmt.local_tables
                )
            )
        elif token_v == "truncate":
            suggest.append(Table(schema))
        else:
            suggest.extend((Table(schema), View(schema)))

        if is_join and _allow_join(stmt.parsed):
            tables = stmt.get_tables("before")
            suggest.append(Join(table_refs=tables, schema=schema))

        return tuple(suggest)

    elif token_v == "function":
        schema = stmt.get_identifier_schema()

        # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
        try:
            prev = stmt.get_previous_token(token).value.lower()
            if prev in ("drop", "alter", "create", "create or replace"):

                # Suggest functions from either the currently-selected schema or the
                # public schema if no schema has been specified
                suggest = []

                if not schema:
                    # Suggest schemas
                    suggest.insert(0, Schema())

                suggest.append(Function(schema=schema, usage="signature"))
                return tuple(suggest)

        except ValueError:
            pass
        return tuple()

    elif token_v in ("table", "view"):
        # E.g. 'ALTER TABLE <tablename>'
        rel_type = {"table": Table, "view": View, "function": Function}[token_v]
        schema = stmt.get_identifier_schema()
        if schema:
            return (rel_type(schema=schema),)
        else:
            return (Schema(), rel_type(schema=schema))

    elif token_v == "column":
        # E.g. 'ALTER TABLE foo ALTER COLUMN bar'
        return (Column(table_refs=stmt.get_tables()),)

    elif token_v == "on":
        tables = stmt.get_tables("before")
        parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            filteredtables = tuple(t for t in tables if identifies(parent, t))
            sugs = [
                Column(table_refs=filteredtables, local_tables=stmt.local_tables),
                Table(schema=parent),
                View(schema=parent),
                Function(schema=parent),
            ]
            if filteredtables and _allow_join_condition(stmt.parsed):
                sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
            return tuple(sugs)
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = tuple(t.ref for t in tables)
            if _allow_join_condition(stmt.parsed):
                return (
                    Alias(aliases=aliases),
                    JoinCondition(table_refs=tables, parent=None),
                )
            else:
                return (Alias(aliases=aliases),)

    elif token_v in ("c", "use", "database", "template"):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return (Database(),)
    elif token_v == "schema":
        # DROP SCHEMA schema_name, SET SCHEMA schema name
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
        quoted = prev_keyword and prev_keyword.value.lower() == "set"
        return (Schema(quoted),)
    elif token_v.endswith(",") or token_v in ("=", "and", "or"):
        # Continuation of a list/expression: defer to the governing keyword.
        prev_keyword = stmt.reduce_to_prev_keyword()
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return ()
    elif token_v in ("type", "::"):
        # ALTER TABLE foo SET DATA TYPE bar
        # SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        schema = stmt.get_identifier_schema()
        suggestions = [Datatype(schema=schema), Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    elif token_v in {"alter", "create", "drop"}:
        return (Keyword(token_v.upper()),)
    elif token.is_keyword:
        # token is a keyword we haven't implemented any special handling for
        # go backwards in the query until we find one we do recognize
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return (Keyword(token_v.upper()),)
    else:
        return (Keyword(),)
+
+
def _suggest_expression(token_v, stmt):
    """
    Return suggestions for an expression, taking account of any partially-typed
    identifier's parent, which may be a table alias or schema name.
    """
    parent = stmt.identifier.get_parent_name() if stmt.identifier else []
    tables = stmt.get_tables()

    if not parent:
        # Unqualified expression: columns (qualifiable), functions, keywords.
        return (
            Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True),
            Function(schema=None),
            Keyword(token_v.upper()),
        )

    # Qualified: restrict to the tables the parent identifies, and offer the
    # parent's relations/functions in case it is a schema name.
    scoped_tables = tuple(t for t in tables if identifies(parent, t))
    return (
        Column(table_refs=scoped_tables, local_tables=stmt.local_tables),
        Table(schema=parent),
        View(schema=parent),
        Function(schema=parent),
    )
+
+
def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`"""
    # Direct match on alias or table name.
    matches_name = id == ref.alias or id == ref.name
    # Otherwise try the schema-qualified form "schema.name".
    qualified = ref.schema and (id == ref.schema + "." + ref.name)
    return matches_name or qualified
+
+
+def _allow_join_condition(statement):
+ """
+ Tests if a join condition should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b on a.id = <cursor>
+ So check that the preceding token is a ON, AND, or OR keyword, instead of
+ e.g. an equals sign.
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower() in ("on", "and", "or")
+
+
+def _allow_join(statement):
+ """
+ Tests if a join should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b <cursor>
+ So check that the preceding token is a JOIN keyword
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
+ "cross join",
+ "natural join",
+ )
diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py
new file mode 100644
index 0000000..706ed25
--- /dev/null
+++ b/pgcli/pgbuffer.py
@@ -0,0 +1,50 @@
+import logging
+
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+from .packages.parseutils.utils import is_open_quote
+
+_logger = logging.getLogger(__name__)
+
+
def _is_complete(sql):
    # A complete command is an sql statement that ends with a semicolon, unless
    # there's an open quote surrounding it, as is common when writing a
    # CREATE FUNCTION command
    if not sql.endswith(";"):
        return False
    return not is_open_quote(sql)
+
+
def buffer_should_be_handled(pgcli):
    """Return a prompt_toolkit Condition that is True when the buffer
    contents should be handled (i.e. the query/command executed)
    immediately.

    This is necessary as we use prompt_toolkit in multiline mode, which by
    default will insert new lines on Enter.

    (The text above previously sat as a no-op bare string *before* this
    function; it is now the function's actual docstring. The inline comments
    on the \\e/\\G checks below were also misattached and have been fixed.)
    """

    @Condition
    def cond():
        if not pgcli.multi_line:
            _logger.debug("Not in multi-line mode. Handle the buffer.")
            return True

        if pgcli.multiline_mode == "safe":
            _logger.debug("Multi-line mode is set to 'safe'. Do NOT handle the buffer.")
            return False

        doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
        text = doc.text.strip()

        return (
            text.startswith("\\")  # Special Command
            or text.endswith(r"\e")  # Ended with \e which should launch the editor
            or text.endswith(r"\G")  # Ended with \G (force vertical output)
            or _is_complete(text)  # A complete SQL command
            or (text == "exit")  # Exit doesn't need semi-colon
            or (text == "quit")  # Quit doesn't need semi-colon
            or (text == ":q")  # To all the vim fans out there
            or (text == "")  # Just a plain enter without any text
        )

    return cond
diff --git a/pgcli/pgclirc b/pgcli/pgclirc
new file mode 100644
index 0000000..e97afda
--- /dev/null
+++ b/pgcli/pgclirc
@@ -0,0 +1,195 @@
+# vi: ft=dosini
+[main]
+
+# Enables context sensitive auto-completion. If this is disabled, all
+# possible completions will be listed.
+smart_completion = True
+
+# Display the completions in several columns. (More completions will be
+# visible.)
+wider_completion_menu = False
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# If multi_line_mode is set to "psql", in multi-line mode, [Enter] will execute
+# the current input if the input ends in a semicolon.
+# If multi_line_mode is set to "safe", in multi-line mode, [Enter] will always
+# insert a newline, and [Esc] [Enter] or [Alt]-[Enter] must be used to execute
+# a command.
+multi_line_mode = psql
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# Enables expand mode, which is similar to `\x` in psql.
+expand = False
+
+# Enables auto expand mode, which is similar to `\x auto` in psql.
+auto_expand = False
+
+# If set to True, table suggestions will include a table alias
+generate_aliases = False
+
+# log_file location.
+# In Unix/Linux: ~/.config/pgcli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# keyword casing preference. Possible values: "lower", "upper", "auto"
+keyword_casing = auto
+
+# casing_file location.
+# In Unix/Linux: ~/.config/pgcli/casing
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\casing
+# %USERPROFILE% is typically C:\Users\{username}
+casing_file = default
+
+# If generate_casing_file is set to True and there is no file in the above
+# location, one will be generated based on usage in SQL/PLPGSQL functions.
+generate_casing_file = False
+
+# Casing of column headers based on the casing_file described above
+case_column_headers = True
+
+# history_file location.
+# In Unix/Linux: ~/.config/pgcli/history
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\history
+# %USERPROFILE% is typically C:\Users\{username}
+history_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Order of columns when expanding * to column list
+# Possible values: "table_order" and "alphabetic"
+asterisk_column_order = table_order
+
+# Whether to qualify with table alias/name when suggesting columns
+# Possible values: "always", "never" and "if_more_than_one_table"
+qualify_columns = if_more_than_one_table
+
+# When no schema is entered, only suggest objects in search_path
+search_path_filter = False
+
+# Default pager.
+# By default 'PAGER' environment variable is used
+# pager = less -SRXF
+
+# Timing of sql statements and table rendering.
+timing = True
+
+# Show/hide the informational toolbar with function keymap at the footer.
+show_bottom_toolbar = True
+
+# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe,
+# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs,
+# textile, moinmoin, jira, vertical, tsv, csv.
+# Recommended: psql, fancy_grid and grid.
+table_format = psql
+
+# Syntax Style. Possible values: manni, igor, xcode, vim, autumn, vs, rrt,
+# native, perldoc, borland, tango, emacs, friendly, monokai, paraiso-dark,
+# colorful, murphy, bw, pastie, paraiso-light, trac, default, fruity
+syntax_style = default
+
+# Keybindings:
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+# When Vi mode is disabled emacs keybindings such as Ctrl-A for home and Ctrl-E
+# for end are available in the REPL.
+vi = False
+
+# Error handling
+# When one of multiple SQL statements causes an error, choose to either
+# continue executing the remaining statements or to stop execution.
+# Possible values: "STOP" or "RESUME"
+on_error = STOP
+
+# Set threshold for row limit. Use 0 to disable limiting.
+row_limit = 1000
+
+# Skip intro on startup and goodbye on exit
+less_chatty = False
+
+# Postgres prompt
+# \t - Current date and time
+# \u - Username
+# \h - Short hostname of the server (up to first '.')
+# \H - Hostname of the server
+# \d - Database name
+# \p - Database port
+# \i - Postgres PID
+# \# - "@" sign if logged in as superuser, '>' in other case
+# \n - Newline
+# \dsn_alias - name of dsn alias if -D option is used (empty otherwise)
+# \x1b[...m - insert ANSI escape sequence
+# eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>'
+prompt = '\u@\h:\d> '
+
+# Number of lines to reserve for the suggestion menu
+min_num_menu_lines = 4
+
+# Character used to left pad multi-line queries to match the prompt size.
+multiline_continuation_char = ''
+
+# The string used in place of a null value.
+null_string = '<null>'
+
+# manage pager on startup
+enable_pager = True
+
+# Use keyring to automatically save and load password in a secure manner
+keyring = True
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+literal.string = '#ba2121'
+literal.number = '#666666'
+keyword = 'bold #008000'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+output.null = "#808080"
+
+# Named queries are queries you can execute by name.
+[named queries]
+
+# DSN to call by -D option
+[alias_dsn]
+# example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname]
+
+# Format for number representation
+# for decimal "d" - 12345678, ",d" - 12,345,678
+# for float "g" - 123456.78, ",g" - 123,456.78
+[data_formats]
+decimal = ""
+float = ""
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
new file mode 100644
index 0000000..9c95a01
--- /dev/null
+++ b/pgcli/pgcompleter.py
@@ -0,0 +1,1046 @@
+import logging
+import re
+from itertools import count, repeat, chain
+import operator
+from collections import namedtuple, defaultdict, OrderedDict
+from cli_helpers.tabular_output import TabularOutputFormatter
+from pgspecial.namedqueries import NamedQueries
+from prompt_toolkit.completion import Completer, Completion, PathCompleter
+from prompt_toolkit.document import Document
+from .packages.sqlcompletion import (
+ FromClauseItem,
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ TableFormat,
+ Function,
+ Column,
+ View,
+ Keyword,
+ NamedQuery,
+ Datatype,
+ Alias,
+ Path,
+ JoinCondition,
+ Join,
+)
+from .packages.parseutils.meta import ColumnMetadata, ForeignKey
+from .packages.parseutils.utils import last_word
+from .packages.parseutils.tables import TableReference
+from .packages.pgliterals.main import get_literals
+from .packages.prioritization import PrevalenceCounter
+from .config import load_config, config_location
+
_logger = logging.getLogger(__name__)

# A completion paired with the sort key used to rank it (see find_matches).
Match = namedtuple("Match", ["completion", "priority"])
+
# Lightweight record for a table/view/function/datatype discovered in the DB.
_SchemaObject = namedtuple("SchemaObject", "name schema meta")


def SchemaObject(name, schema=None, meta=None):
    """Build a SchemaObject record, defaulting schema and meta to None."""
    return _SchemaObject(name=name, schema=schema, meta=meta)
+
+
# A completion candidate plus the ranking data used to order suggestions.
_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")


def Candidate(
    completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
):
    """Build a Candidate; synonyms and display default to the completion text."""
    if not synonyms:
        synonyms = [completion]
    if not display:
        display = completion
    return _Candidate(completion, prio, meta, synonyms, prio2, display)
+
+
# Used to strip trailing '::some_type' casts from default-value expressions.
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")


def normalize_ref(ref):
    """Normalize a table reference for comparison.

    Already-quoted refs are kept as-is; anything else is lower-cased and
    wrapped in double quotes. (Converted from a lambda assignment, which
    PEP 8 advises against.)
    """
    return ref if ref[0] == '"' else '"' + ref.lower() + '"'
+
+
def generate_alias(tbl):
    """Generate a table alias for the unescaped table name *tbl*.

    Uses all upper-case letters of the name if it has any; otherwise the
    first letter plus every letter that follows an underscore.
    """
    uppercase = [ch for ch in tbl if ch.isupper()]
    if uppercase:
        return "".join(uppercase)
    # Pair each character with its predecessor (a '_' sentinel stands in
    # front of the first char) and keep non-underscore chars after a '_'.
    return "".join(
        ch for ch, prev in zip(tbl, "_" + tbl) if prev == "_" and ch != "_"
    )
+
+
class PGCompleter(Completer):
    """prompt_toolkit Completer producing context-aware SQL completions.

    Holds metadata about the connected database (schemas, relations,
    columns, functions, datatypes, foreign keys) and combines it with
    sqlcompletion.suggest_type() to produce ranked completion candidates.
    """

    # keywords_tree: A dict mapping keywords to well known following keywords.
    # e.g. 'CREATE': ['TABLE', 'USER', ...],
    keywords_tree = get_literals("keywords", type_=dict)
    # Flat tuple of every keyword appearing anywhere in the tree.
    keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
    functions = get_literals("functions")
    datatypes = get_literals("datatypes")
    reserved_words = set(get_literals("reserved"))

    def __init__(self, smart_completion=True, pgspecial=None, settings=None):
        """
        :param smart_completion: if False, fall back to plain prefix matching
        :param pgspecial: PGSpecial instance used for backslash commands
        :param settings: dict of user preferences (from the config file)
        """
        super(PGCompleter, self).__init__()
        self.smart_completion = smart_completion
        self.pgspecial = pgspecial
        self.prioritizer = PrevalenceCounter()
        settings = settings or {}
        self.signature_arg_style = settings.get(
            "signature_arg_style", "{arg_name} {arg_type}"
        )
        self.call_arg_style = settings.get(
            "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
        )
        self.call_arg_display_style = settings.get(
            "call_arg_display_style", "{arg_name}"
        )
        self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
        self.search_path_filter = settings.get("search_path_filter")
        self.generate_aliases = settings.get("generate_aliases")
        self.casing_file = settings.get("casing_file")
        # Columns whose default matches one of these patterns (e.g. now(),
        # nextval(...)) are skipped when expanding * in INSERT statements.
        self.insert_col_skip_patterns = [
            re.compile(pattern)
            for pattern in settings.get(
                "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
            )
        ]
        self.generate_casing_file = settings.get("generate_casing_file")
        self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
        self.asterisk_column_order = settings.get(
            "asterisk_column_order", "table_order"
        )

        # Normalize keyword_casing to one of the three supported values.
        keyword_casing = settings.get("keyword_casing", "upper").lower()
        if keyword_casing not in ("upper", "lower", "auto"):
            keyword_casing = "upper"
        self.keyword_casing = keyword_casing
        # Identifiers NOT matching this pattern must be double-quoted.
        self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

        self.databases = []
        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
        self.search_path = []
        # {lowercase_name: PreferredCasingName}, filled by extend_casing.
        self.casing = {}

        self.all_completions = set(self.keywords + self.functions)
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = '"%s"' % name
+
+ return name
+
+ def escape_schema(self, name):
+ return "'{}'".format(self.unescape_name(name))
+
+ def unescape_name(self, name):
+ """ Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schemata):
+
+ # schemata is a list of schema names
+ schemata = self.escaped_names(schemata)
+ metadata = self.dbmetadata["tables"]
+ for schema in schemata:
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ for schema in schemata:
+ metadata[schema] = {}
+
+ self.all_completions.update(schemata)
+
+ def extend_casing(self, words):
+ """extend casing data
+
+ :return:
+ """
+ # casing should be a dict {lowercasename:PreferredCasingName}
+ self.casing = dict((word.lower(), word) for word in words)
+
    def extend_relations(self, data, kind):
        """extend metadata for tables or views.

        :param data: list of (schema_name, rel_name) tuples
        :param kind: either 'tables' or 'views'

        :return:

        """

        data = [self.escaped_names(d) for d in data]

        # dbmetadata['tables']['schema_name']['table_name'] should be an
        # OrderedDict {column_name:ColumnMetaData}.
        metadata = self.dbmetadata[kind]
        for schema, relname in data:
            try:
                metadata[schema][relname] = OrderedDict()
            except KeyError:
                # Unknown schema: log and carry on rather than aborting the
                # whole refresh.
                _logger.error(
                    "%r %r listed in unrecognized schema %r", kind, relname, schema
                )
            # Note: the relation name is added to all_completions even when
            # its schema was unrecognized above.
            self.all_completions.add(relname)
+
+ def extend_columns(self, column_data, kind):
+ """extend column metadata.
+
+ :param column_data: list of (schema_name, rel_name, column_name,
+ column_type, has_default, default) tuples
+ :param kind: either 'tables' or 'views'
+
+ :return:
+
+ """
+ metadata = self.dbmetadata[kind]
+ for schema, relname, colname, datatype, has_default, default in column_data:
+ (schema, relname, colname) = self.escaped_names([schema, relname, colname])
+ column = ColumnMetadata(
+ name=colname,
+ datatype=datatype,
+ has_default=has_default,
+ default=default,
+ )
+ metadata[schema][relname][colname] = column
+ self.all_completions.add(colname)
+
+ def extend_functions(self, func_data):
+
+ # func_data is a list of function metadata namedtuples
+
+ # dbmetadata['schema_name']['functions']['function_name'] should return
+ # the function metadata namedtuple for the corresponding function
+ metadata = self.dbmetadata["functions"]
+
+ for f in func_data:
+ schema, func = self.escaped_names([f.schema_name, f.func_name])
+
+ if func in metadata[schema]:
+ metadata[schema][func].append(f)
+ else:
+ metadata[schema][func] = [f]
+
+ self.all_completions.add(func)
+
+ self._refresh_arg_list_cache()
+
    def _refresh_arg_list_cache(self):
        """Rebuild the cached argument-list strings for all known functions.

        We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
        This is used when suggesting functions, to avoid the latency that would result
        if we'd recalculate the arg lists each time we suggest functions (in large DBs)
        """
        self._arg_list_cache = {
            usage: {
                meta: self._arg_list(meta, usage)
                for sch, funcs in self.dbmetadata["functions"].items()
                for func, metas in funcs.items()
                for meta in metas
            }
            for usage in ("call", "call_display", "signature")
        }
+
    def extend_foreignkeys(self, fk_data):
        """Record foreign-key relationships on the column metadata.

        fk_data is a list of ForeignKey namedtuples, with fields
        parentschema, childschema, parenttable, childtable,
        parentcolumns, childcolumns.

        These are added as a list of ForeignKey namedtuples to the
        ColumnMetadata namedtuple for both the child and parent.

        NOTE(review): assumes both endpoints were already registered via
        extend_relations/extend_columns; an unknown table or column here
        would raise KeyError — confirm callers guarantee that ordering.
        """
        meta = self.dbmetadata["tables"]

        for fk in fk_data:
            e = self.escaped_names
            parentschema, childschema = e([fk.parentschema, fk.childschema])
            parenttable, childtable = e([fk.parenttable, fk.childtable])
            childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
            childcolmeta = meta[childschema][childtable][childcol]
            parcolmeta = meta[parentschema][parenttable][parcol]
            # Rebuild the FK with escaped names so it matches dbmetadata keys.
            fk = ForeignKey(
                parentschema, parenttable, parcol, childschema, childtable, childcol
            )
            childcolmeta.foreignkeys.append((fk))
            parcolmeta.foreignkeys.append((fk))
+
+ def extend_datatypes(self, type_data):
+
+ # dbmetadata['datatypes'][schema_name][type_name] should store type
+ # metadata, such as composite type field names. Currently, we're not
+ # storing any metadata beyond typename, so just store None
+ meta = self.dbmetadata["datatypes"]
+
+ for t in type_data:
+ schema, type_name = self.escaped_names(t)
+ meta[schema][type_name] = None
+ self.all_completions.add(type_name)
+
+ def extend_query_history(self, text, is_init=False):
+ if is_init:
+ # During completer initialization, only load keyword preferences,
+ # not names
+ self.prioritizer.update_keywords(text)
+ else:
+ self.prioritizer.update(text)
+
+ def set_search_path(self, search_path):
+ self.search_path = self.escaped_names(search_path)
+
+ def reset_completions(self):
+ self.databases = []
+ self.special_commands = []
+ self.search_path = []
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
    def find_matches(self, text, collection, mode="fuzzy", meta=None):
        """Find completion matches for the given text.

        Given the user's input text and a collection of available
        completions, find completions matching the last word of the
        text.

        `collection` can be either a list of strings or a list of Candidate
        namedtuples.
        `mode` can be either 'fuzzy', or 'strict'
            'fuzzy': fuzzy matching, ties broken by name prevalence
            'strict': start-only matching, ties broken by keyword prevalence

        Returns a list of Match namedtuples (a prompt_toolkit Completion
        plus the sort priority) for any matches found in the collection of
        available completions.

        """
        if not collection:
            return []
        # Completion kinds from least to most important; the index is used
        # as a tie-breaking priority between different kinds of matches.
        prio_order = [
            "keyword",
            "function",
            "view",
            "table",
            "datatype",
            "database",
            "schema",
            "column",
            "table alias",
            "join",
            "name join",
            "fk join",
            "table format",
        ]
        type_priority = prio_order.index(meta) if meta in prio_order else -1
        text = last_word(text, include="most_punctuations").lower()
        text_len = len(text)

        if text and text[0] == '"':
            # text starts with double quote; user is manually escaping a name
            # Match on everything that follows the double-quote. Note that
            # text_len is calculated before removing the quote, so the
            # Completion.position value is correct
            text = text[1:]

        if mode == "fuzzy":
            fuzzy = True
            priority_func = self.prioritizer.name_count
        else:
            fuzzy = False
            priority_func = self.prioritizer.keyword_count

        # Construct a `_match` function for either fuzzy or non-fuzzy matching
        # The match function returns a 2-tuple used for sorting the matches,
        # or None if the item doesn't match
        # Note: higher priority values mean more important, so use negative
        # signs to flip the direction of the tuple
        if fuzzy:
            # Characters of `text` in order, with anything in between
            # (non-greedy), e.g. "sel" -> "s.*?e.*?l".
            regex = ".*?".join(map(re.escape, text))
            pat = re.compile("(%s)" % regex)

            def _match(item):
                if item.lower()[: len(text) + 1] in (text, text + " "):
                    # Exact match of first word in suggestion
                    # This is to get exact alias matches to the top
                    # E.g. for input `e`, 'Entries E' should be on top
                    # (before e.g. `EndUsers EU`)
                    return float("Infinity"), -1
                r = pat.search(self.unescape_name(item.lower()))
                if r:
                    # Prefer shorter match spans, then earlier start positions.
                    return -len(r.group()), -r.start()

        else:
            match_end_limit = len(text)

            def _match(item):
                match_point = item.lower().find(text, 0, match_end_limit)
                if match_point >= 0:
                    # Use negative infinity to force keywords to sort after all
                    # fuzzy matches
                    return -float("Infinity"), -match_point

        matches = []
        for cand in collection:
            if isinstance(cand, _Candidate):
                item, prio, display_meta, synonyms, prio2, display = cand
                if display_meta is None:
                    display_meta = meta
                syn_matches = (_match(x) for x in synonyms)
                # Nones need to be removed to avoid max() crashing in Python 3
                syn_matches = [m for m in syn_matches if m]
                sort_key = max(syn_matches) if syn_matches else None
            else:
                # Plain string candidate: default the ranking fields.
                item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
                sort_key = _match(cand)

            if sort_key:
                if display_meta and len(display_meta) > 50:
                    # Truncate meta-text to 50 characters, if necessary
                    display_meta = display_meta[:47] + "..."

                # Lexical order of items in the collection, used for
                # tiebreaking items with the same match group length and start
                # position. Since we use *higher* priority to mean "more
                # important," we use -ord(c) to prioritize "aa" > "ab" and end
                # with 1 to prioritize shorter strings (ie "user" > "users").
                # We first do a case-insensitive sort and then a
                # case-sensitive one as a tie breaker.
                # We also use the unescape_name to make sure quoted names have
                # the same priority as unquoted names.
                lexical_priority = (
                    tuple(
                        0 if c in (" _") else -ord(c)
                        for c in self.unescape_name(item.lower())
                    )
                    + (1,)
                    + tuple(c for c in item)
                )

                item = self.case(item)
                display = self.case(display)
                priority = (
                    sort_key,
                    type_priority,
                    prio,
                    priority_func(item),
                    prio2,
                    lexical_priority,
                )
                matches.append(
                    Match(
                        completion=Completion(
                            text=item,
                            start_position=-text_len,
                            display_meta=display_meta,
                            display=display,
                        ),
                        priority=priority,
                    )
                )
        return matches
+
+ def case(self, word):
+ return self.casing.get(word, word)
+
    def get_completions(self, document, complete_event, smart_completion=None):
        """prompt_toolkit entry point: return completions for *document*.

        :param smart_completion: optional override for self.smart_completion
        """
        word_before_cursor = document.get_word_before_cursor(WORD=True)

        if smart_completion is None:
            smart_completion = self.smart_completion

        # If smart_completion is off then match any word that starts with
        # 'word_before_cursor'.
        if not smart_completion:
            matches = self.find_matches(
                word_before_cursor, self.all_completions, mode="strict"
            )
            completions = [m.completion for m in matches]
            return sorted(completions, key=operator.attrgetter("text"))

        matches = []
        # Ask the SQL parser which kinds of object make sense at the cursor.
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:
            suggestion_type = type(suggestion)
            _logger.debug("Suggestion type: %r", suggestion_type)

            # Map suggestion type to method
            # e.g. 'table' -> self.get_table_matches
            matcher = self.suggestion_matchers[suggestion_type]
            matches.extend(matcher(self, suggestion, word_before_cursor))

        # Sort matches so highest priorities are first
        matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)

        return [m.completion for m in matches]
+
    def get_column_matches(self, suggestion, word_before_cursor):
        """Suggest column names for the tables in scope.

        Qualifies columns (tbl.col) according to the qualify_columns
        setting, and expands a typed '*' into an explicit column list.
        """
        tables = suggestion.table_refs
        do_qualify = suggestion.qualifiable and {
            "always": True,
            "never": False,
            "if_more_than_one_table": len(tables) > 1,
        }[self.qualify_columns]
        qualify = lambda col, tbl: (
            (tbl + "." + self.case(col)) if do_qualify else self.case(col)
        )
        _logger.debug("Completion column scope: %r", tables)
        scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)

        def make_cand(name, ref):
            # The generated alias is a synonym so "oi" can match "order_id".
            synonyms = (name, generate_alias(self.case(name)))
            return Candidate(qualify(name, ref), 0, "column", synonyms)

        def flat_cols():
            return [
                make_cand(c.name, t.ref)
                for t, cols in scoped_cols.items()
                for c in cols
            ]

        if suggestion.require_last_table:
            # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
            # suggest only columns that appear in the last table and one more
            ltbl = tables[-1].ref
            other_tbl_cols = set(
                c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
            )
            scoped_cols = {
                t: [col for col in cols if col.name in other_tbl_cols]
                for t, cols in scoped_cols.items()
                if t.ref == ltbl
            }
        lastword = last_word(word_before_cursor, include="most_punctuations")
        if lastword == "*":
            if suggestion.context == "insert":

                def filter(col):
                    # Skip columns whose default matches one of the
                    # insert_col_skip_patterns (e.g. now(), nextval(...)).
                    if not col.has_default:
                        return True
                    return not any(
                        p.match(col.default) for p in self.insert_col_skip_patterns
                    )

                scoped_cols = {
                    t: [col for col in cols if filter(col)]
                    for t, cols in scoped_cols.items()
                }
            if self.asterisk_column_order == "alphabetic":
                for cols in scoped_cols.values():
                    cols.sort(key=operator.attrgetter("name"))
            if (
                lastword != word_before_cursor
                and len(tables) == 1
                and word_before_cursor[-len(lastword) - 1] == "."
            ):
                # User typed x.*; replicate "x." for all columns except the
                # first, which gets the original (as we only replace the "*"")
                sep = ", " + word_before_cursor[:-1]
                collist = sep.join(self.case(c.completion) for c in flat_cols())
            else:
                collist = ", ".join(
                    qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
                )

            return [
                Match(
                    completion=Completion(
                        collist, -1, display_meta="columns", display="*"
                    ),
                    priority=(1, 1, 1),
                )
            ]

        return self.find_matches(word_before_cursor, flat_cols(), meta="column")
+
    def alias(self, tbl, tbls):
        """Generate a unique table alias
        tbl - name of the table to alias, quoted if it needs to be
        tbls - TableReference iterable of tables already in query
        """
        tbl = self.case(tbl)
        tbls = set(normalize_ref(t.ref) for t in tbls)
        if self.generate_aliases:
            tbl = generate_alias(self.unescape_name(tbl))
        if normalize_ref(tbl) not in tbls:
            return tbl
        elif tbl[0] == '"':
            # Quoted name: put the counter inside the quotes ("name2", ...).
            aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
        else:
            aliases = (tbl + str(i) for i in count(2))
        # First numbered variant not already referenced in the query.
        return next(a for a in aliases if normalize_ref(a) not in tbls)
+
    def get_join_matches(self, suggestion, word_before_cursor):
        """Suggest complete JOIN clauses, derived from the foreign keys of
        the tables already present in the query."""
        tbls = suggestion.table_refs
        cols = self.populate_scoped_cols(tbls)
        # Set up some data structures for efficient access
        qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
        ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
        refs = set(normalize_ref(t.ref) for t in tbls)
        other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
        joins = []
        # Iterate over FKs in existing tables to find potential joins
        fks = (
            (fk, rtbl, rcol)
            for rtbl, rcols in cols.items()
            for rcol in rcols
            for fk in rcol.foreignkeys
        )
        col = namedtuple("col", "schema tbl col")
        for fk, rtbl, rcol in fks:
            right = col(rtbl.schema, rtbl.name, rcol.name)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            # Join against whichever end of the FK is NOT the table already
            # in the query.
            left = child if parent == right else parent
            if suggestion.schema and left.schema != suggestion.schema:
                continue
            c = self.case
            if self.generate_aliases or normalize_ref(left.tbl) in refs:
                lref = self.alias(left.tbl, suggestion.table_refs)
                join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
                )
            else:
                join = "{0} ON {0}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col)
                )
            alias = generate_alias(self.case(left.tbl))
            synonyms = [
                join,
                "{0} ON {0}.{1} = {2}.{3}".format(
                    alias, c(left.col), rtbl.ref, c(right.col)
                ),
            ]
            # Schema-qualify if (1) new table in same schema as old, and old
            # is schema-qualified, or (2) new in other schema, except public
            if not suggestion.schema and (
                qualified[normalize_ref(rtbl.ref)]
                and left.schema == right.schema
                or left.schema not in (right.schema, "public")
            ):
                join = left.schema + "." + join
            prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
                0 if (left.schema, left.tbl) in other_tbls else 1
            )
            joins.append(Candidate(join, prio, "join", synonyms=synonyms))

        return self.find_matches(word_before_cursor, joins, meta="join")
+
    def get_join_condition_matches(self, suggestion, word_before_cursor):
        """Suggest the condition of a JOIN (after ON): foreign-key based
        conditions first, then name/type matches between columns."""
        col = namedtuple("col", "schema tbl col")
        tbls = self.populate_scoped_cols(suggestion.table_refs).items
        cols = [(t, c) for t, cs in tbls() for c in cs]
        try:
            # The "left" table is the qualifier typed by the user, or else
            # the table closest to the cursor.
            lref = (suggestion.parent or suggestion.table_refs[-1]).ref
            ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
        except IndexError:  # The user typed an incorrect table qualifier
            return []
        conds, found_conds = [], set()

        def add_cond(lcol, rcol, rref, prio, meta):
            prefix = "" if suggestion.parent else ltbl.ref + "."
            case = self.case
            cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
            if cond not in found_conds:
                found_conds.add(cond)
                conds.append(Candidate(cond, prio + ref_prio[rref], meta))

        def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
            d = defaultdict(list)
            for pair in pairs:
                d[pair[0]].append(pair[1])
            return d

        # Tables that are closer to the cursor get higher prio
        ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
        # Map (schema, table, col) to tables
        coldict = list_dict(
            ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
        )
        # For each fk from the left table, generate a join condition if
        # the other table is also in the scope
        fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
        for fk, lcol in fks:
            left = col(ltbl.schema, ltbl.name, lcol)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left, right = (child, par) if left == child else (par, child)
            for rtbl in coldict[right]:
                add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
        # For name matching, use a {(colname, coltype): TableReference} dict
        coltyp = namedtuple("coltyp", "name datatype")
        col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
        # Find all name-match join conditions
        for c in (coltyp(c.name, c.datatype) for c in lcols):
            for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
                # Integer-typed columns sharing a name are the likeliest keys.
                prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
                add_cond(c.name, c.name, rtbl.ref, prio, "name join")

        return self.find_matches(word_before_cursor, conds, meta="join")
+
    def get_function_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest function names, with argument lists where configured.

        :param alias: whether to append a generated alias (only honored for
            FROM-clause usage; forced to False otherwise)
        """

        if suggestion.usage == "from":
            # Only suggest functions allowed in FROM clause

            def filt(f):
                return (
                    not f.is_aggregate
                    and not f.is_window
                    and not f.is_extension
                    and (f.is_public or f.schema_name == suggestion.schema)
                )

        else:
            alias = False

            def filt(f):
                return not f.is_extension and (
                    f.is_public or f.schema_name == suggestion.schema
                )

        arg_mode = {"signature": "signature", "special": None}.get(
            suggestion.usage, "call"
        )

        # Function overloading means we may have multiple functions of the same
        # name at this point, so keep unique names only
        all_functions = self.populate_functions(suggestion.schema, filt)
        funcs = set(
            self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
        )

        matches = self.find_matches(word_before_cursor, funcs, meta="function")

        if not suggestion.schema and not suggestion.usage:
            # also suggest hardcoded functions using startswith matching
            predefined_funcs = self.find_matches(
                word_before_cursor, self.functions, mode="strict", meta="function"
            )
            matches.extend(predefined_funcs)

        return matches
+
+ def get_schema_matches(self, suggestion, word_before_cursor):
+ schema_names = self.dbmetadata["tables"].keys()
+
+ # Unless we're sure the user really wants them, hide schema names
+ # starting with pg_, which are mostly temporary schemas
+ if not word_before_cursor.startswith("pg_"):
+ schema_names = [s for s in schema_names if not s.startswith("pg_")]
+
+ if suggestion.quoted:
+ schema_names = [self.escape_schema(s) for s in schema_names]
+
+ return self.find_matches(word_before_cursor, schema_names, meta="schema")
+
+ def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+ alias = self.generate_aliases
+ s = suggestion
+ t_sug = Table(s.schema, s.table_refs, s.local_tables)
+ v_sug = View(s.schema, s.table_refs)
+ f_sug = Function(s.schema, s.table_refs, usage="from")
+ return (
+ self.get_table_matches(t_sug, word_before_cursor, alias)
+ + self.get_view_matches(v_sug, word_before_cursor, alias)
+ + self.get_function_matches(f_sug, word_before_cursor, alias)
+ )
+
    def _arg_list(self, func, usage):
        """Return an arg list string, e.g. `(_foo:=23)` for a func.

        :param func is a FunctionMetadata object
        :param usage is 'call', 'call_display' or 'signature'

        """
        template = {
            "call": self.call_arg_style,
            "call_display": self.call_arg_display_style,
            "signature": self.signature_arg_style,
        }[usage]
        args = func.args()
        if not template:
            return "()"
        elif usage == "call" and len(args) < 2:
            # Zero or one argument: no need to spell the argument out.
            return "()"
        elif usage == "call" and func.has_variadic():
            return "()"
        # Break the call onto multiple lines once it has many arguments.
        multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
        max_arg_len = max(len(a.name) for a in args) if multiline else 0
        args = (
            self._format_arg(template, arg, arg_num + 1, max_arg_len)
            for arg_num, arg in enumerate(args)
        )
        if multiline:
            return "(" + ",".join("\n " + a for a in args if a) + "\n)"
        else:
            return "(" + ", ".join(a for a in args if a) + ")"
+
+ def _format_arg(self, template, arg, arg_num, max_arg_len):
+ if not template:
+ return None
+ if arg.has_default:
+ arg_default = "NULL" if arg.default is None else arg.default
+ # Remove trailing ::(schema.)type
+ arg_default = arg_default_type_strip_regex.sub("", arg_default)
+ else:
+ arg_default = ""
+ return template.format(
+ max_arg_len=max_arg_len,
+ arg_name=self.case(arg.name),
+ arg_num=arg_num,
+ arg_type=arg.datatype,
+ arg_default=arg_default,
+ )
+
    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
        """Returns a Candidate namedtuple.

        :param tbl is a SchemaObject
        :param do_alias whether to append a generated table alias
        :param arg_mode determines what type of arg list to suffix for functions.
            Possible values: call, signature

        """
        cased_tbl = self.case(tbl.name)
        if do_alias:
            alias = self.alias(cased_tbl, suggestion.table_refs)
        synonyms = (cased_tbl, generate_alias(cased_tbl))
        maybe_alias = (" " + alias) if do_alias else ""
        maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
        # Arg lists come from the cache built by _refresh_arg_list_cache.
        suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
        if arg_mode == "call":
            display_suffix = self._arg_list_cache["call_display"][tbl.meta]
        elif arg_mode == "signature":
            display_suffix = self._arg_list_cache["signature"][tbl.meta]
        else:
            display_suffix = ""
        item = maybe_schema + cased_tbl + suffix + maybe_alias
        display = maybe_schema + cased_tbl + display_suffix + maybe_alias
        # Unqualified names rank above schema-qualified ones.
        prio2 = 0 if tbl.schema else 1
        return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
+
+ def get_table_matches(self, suggestion, word_before_cursor, alias=False):
+ tables = self.populate_schema_objects(suggestion.schema, "tables")
+ tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)
+
+ # Unless we're sure the user really wants them, don't suggest the
+ # pg_catalog tables that are implicitly on the search path
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ tables = [t for t in tables if not t.name.startswith("pg_")]
+ tables = [self._make_cand(t, alias, suggestion) for t in tables]
+ return self.find_matches(word_before_cursor, tables, meta="table")
+
+ def get_table_formats(self, _, word_before_cursor):
+ formats = TabularOutputFormatter().supported_formats
+ return self.find_matches(word_before_cursor, formats, meta="table format")
+
+ def get_view_matches(self, suggestion, word_before_cursor, alias=False):
+ views = self.populate_schema_objects(suggestion.schema, "views")
+
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ views = [v for v in views if not v.name.startswith("pg_")]
+ views = [self._make_cand(v, alias, suggestion) for v in views]
+ return self.find_matches(word_before_cursor, views, meta="view")
+
+ def get_alias_matches(self, suggestion, word_before_cursor):
+ aliases = suggestion.aliases
+ return self.find_matches(word_before_cursor, aliases, meta="table alias")
+
+ def get_database_matches(self, _, word_before_cursor):
+ return self.find_matches(word_before_cursor, self.databases, meta="database")
+
+ def get_keyword_matches(self, suggestion, word_before_cursor):
+ keywords = self.keywords_tree.keys()
+ # Get well known following keywords for the last token. If any, narrow
+ # candidates to this list.
+ next_keywords = self.keywords_tree.get(suggestion.last_token, [])
+ if next_keywords:
+ keywords = next_keywords
+
+ casing = self.keyword_casing
+ if casing == "auto":
+ if word_before_cursor and word_before_cursor[-1].islower():
+ casing = "lower"
+ else:
+ casing = "upper"
+
+ if casing == "upper":
+ keywords = [k.upper() for k in keywords]
+ else:
+ keywords = [k.lower() for k in keywords]
+
+ return self.find_matches(
+ word_before_cursor, keywords, mode="strict", meta="keyword"
+ )
+
+ def get_path_matches(self, _, word_before_cursor):
+ completer = PathCompleter(expanduser=True)
+ document = Document(
+ text=word_before_cursor, cursor_position=len(word_before_cursor)
+ )
+ for c in completer.get_completions(document, None):
+ yield Match(completion=c, priority=(0,))
+
+ def get_special_matches(self, _, word_before_cursor):
+ if not self.pgspecial:
+ return []
+
+ commands = self.pgspecial.commands
+ cmds = commands.keys()
+ cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
+ return self.find_matches(word_before_cursor, cmds, mode="strict")
+
+ def get_datatype_matches(self, suggestion, word_before_cursor):
+ # suggest custom datatypes
+ types = self.populate_schema_objects(suggestion.schema, "datatypes")
+ types = [self._make_cand(t, False, suggestion) for t in types]
+ matches = self.find_matches(word_before_cursor, types, meta="datatype")
+
+ if not suggestion.schema:
+ # Also suggest hardcoded types
+ matches.extend(
+ self.find_matches(
+ word_before_cursor, self.datatypes, mode="strict", meta="datatype"
+ )
+ )
+
+ return matches
+
+ def get_namedquery_matches(self, _, word_before_cursor):
+ return self.find_matches(
+ word_before_cursor, NamedQueries.instance.list(), meta="named query"
+ )
+
    # Dispatch table mapping each suggestion type (from sqlcompletion) to the
    # matcher that produces completions for it. The functions are stored
    # unbound, hence the matcher(self, ...) call in get_completions.
    suggestion_matchers = {
        FromClauseItem: get_from_clause_item_matches,
        JoinCondition: get_join_condition_matches,
        Join: get_join_matches,
        Column: get_column_matches,
        Function: get_function_matches,
        Schema: get_schema_matches,
        Table: get_table_matches,
        TableFormat: get_table_formats,
        View: get_view_matches,
        Alias: get_alias_matches,
        Database: get_database_matches,
        Keyword: get_keyword_matches,
        Special: get_special_matches,
        Datatype: get_datatype_matches,
        NamedQuery: get_namedquery_matches,
        Path: get_path_matches,
    }
+
+ def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
+ """Find all columns in a set of scoped_tables.
+
+ :param scoped_tbls: list of TableReference namedtuples
+ :param local_tbls: tuple(TableMetadata)
+ :return: {TableReference:{colname:ColumnMetaData}}
+
+ """
+ ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
+ columns = OrderedDict()
+ meta = self.dbmetadata
+
+ def addcols(schema, rel, alias, reltype, cols):
+ tbl = TableReference(schema, rel, alias, reltype == "functions")
+ if tbl not in columns:
+ columns[tbl] = []
+ columns[tbl].extend(cols)
+
+ for tbl in scoped_tbls:
+ # Local tables should shadow database tables
+ if tbl.schema is None and normalize_ref(tbl.name) in ctes:
+ cols = ctes[normalize_ref(tbl.name)]
+ addcols(None, tbl.name, "CTE", tbl.alias, cols)
+ continue
+ schemas = [tbl.schema] if tbl.schema else self.search_path
+ for schema in schemas:
+ relname = self.escape_name(tbl.name)
+ schema = self.escape_name(schema)
+ if tbl.is_function:
+ # Return column names from a set-returning function
+ # Get an array of FunctionMetadata objects
+ functions = meta["functions"].get(schema, {}).get(relname)
+ for func in functions or []:
+ # func is a FunctionMetadata object
+ cols = func.fields()
+ addcols(schema, relname, tbl.alias, "functions", cols)
+ else:
+ for reltype in ("tables", "views"):
+ cols = meta[reltype].get(schema, {}).get(relname)
+ if cols:
+ cols = cols.values()
+ addcols(schema, relname, tbl.alias, reltype, cols)
+ break
+
+ return columns
+
+ def _get_schemas(self, obj_typ, schema):
+ """Returns a list of schemas from which to suggest objects.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+ metadata = self.dbmetadata[obj_typ]
+ if schema:
+ schema = self.escape_name(schema)
+ return [schema] if schema in metadata else []
+ return self.search_path if self.search_path_filter else metadata.keys()
+
+ def _maybe_schema(self, schema, parent):
+ return None if parent or schema in self.search_path else schema
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns a list of SchemaObjects representing tables or views.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+
+ return [
+ SchemaObject(
+ name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
+ )
+ for sch in self._get_schemas(obj_type, schema)
+ for obj in self.dbmetadata[obj_type][sch].keys()
+ ]
+
+ def populate_functions(self, schema, filter_func):
+ """Returns a list of function SchemaObjects.
+
+ :param filter_func is a function that accepts a FunctionMetadata
+ namedtuple and returns a boolean indicating whether that
+ function should be kept or discarded
+
+ """
+
+ # Because of multiple dispatch, we can have multiple functions
+ # with the same name, which is why `for meta in metas` is necessary
+ # in the comprehensions below
+ return [
+ SchemaObject(
+ name=func,
+ schema=(self._maybe_schema(schema=sch, parent=schema)),
+ meta=meta,
+ )
+ for sch in self._get_schemas("functions", schema)
+ for (func, metas) in self.dbmetadata["functions"][sch].items()
+ for meta in metas
+ if filter_func(meta)
+ ]
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
new file mode 100644
index 0000000..d34bf26
--- /dev/null
+++ b/pgcli/pgexecute.py
@@ -0,0 +1,857 @@
+import traceback
+import logging
+import psycopg2
+import psycopg2.extras
+import psycopg2.errorcodes
+import psycopg2.extensions as ext
+import sqlparse
+import pgspecial as special
+import select
+from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
+from .packages.parseutils.meta import FunctionMetadata, ForeignKey
+
# Module-level logger for query/connection diagnostics.
_logger = logging.getLogger(__name__)

# Cast all database input to unicode automatically.
# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
ext.register_type(ext.UNICODE)
ext.register_type(ext.UNICODEARRAY)
# OID 705 is the pseudo-type 'unknown' (e.g. bare string literals).
ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))
# See https://github.com/dbcli/pgcli/issues/426 for more details.
# This registers a unicode type caster for datatype 'RECORD'.
ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))

# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))

# TODO: Get default timeout from pgclirc?
# Timeout (seconds) used by _wait_select between poll() calls.
_WAIT_SELECT_TIMEOUT = 1
+
+
def _wait_select(conn):
    """
    copy-pasted from psycopg2.extras.wait_select
    the default implementation doesn't define a timeout in the select calls
    """
    while 1:
        try:
            state = conn.poll()
            if state == POLL_OK:
                break
            elif state == POLL_READ:
                select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
            elif state == POLL_WRITE:
                select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
            else:
                raise conn.OperationalError("bad state from poll: %s" % state)
        except KeyboardInterrupt:
            # Ctrl+C while waiting: ask the backend to cancel the query.
            conn.cancel()
            # the loop will be broken by a server error
            continue
        except select.error as e:
            # errno 4 is EINTR (interrupted system call): retry the select.
            errno = e.args[0]
            if errno != 4:
                raise
+
+
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
# Installing any wait callback switches psycopg2 into "green" (non-blocking)
# mode; ours additionally bounds each select() with a timeout.
ext.set_wait_callback(_wait_select)
+
+
def register_date_typecasters(connection):
    """
    Casts date and timestamp values to string, resolves issues with out of
    range dates (e.g. BC) which psycopg2 can't handle
    """

    def cast_date(value, cursor):
        # Identity cast: keep the server's textual representation as-is.
        return value

    # Look up the OIDs for date/timestamp types on this connection, then
    # register the pass-through caster globally for those OIDs.
    cursor = connection.cursor()
    cursor.execute("SELECT NULL::date")
    date_oid = cursor.description[0][1]
    cursor.execute("SELECT NULL::timestamp")
    timestamp_oid = cursor.description[0][1]
    cursor.execute("SELECT NULL::timestamp with time zone")
    timestamptz_oid = cursor.description[0][1]
    oids = (date_oid, timestamp_oid, timestamptz_oid)
    new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
    psycopg2.extensions.register_type(new_type)
+
+
def register_json_typecasters(conn, loads_fn):
    """Set the function for converting JSON data for a connection.

    Use the supplied function to decode JSON data returned from the database
    via the given connection. The function should accept a single argument of
    the data as a string encoded in the database's character encoding.
    psycopg2's default handler for JSON data is json.loads.
    http://initd.org/psycopg/docs/extras.html#json-adaptation

    This function attempts to register the typecaster for both JSON and JSONB
    types.

    Returns a set that is a subset of {'json', 'jsonb'} indicating which types
    (if any) were successfully registered.
    """
    available = set()

    for name in ["json", "jsonb"]:
        try:
            psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
            available.add(name)
        except psycopg2.ProgrammingError:
            # Type not present on this server (e.g. jsonb pre-9.4); skip it.
            pass

    return available
+
+
def register_hstore_typecaster(conn):
    """
    Instead of using register_hstore() which converts hstore into a python
    dict, we query the 'oid' of hstore which will be different for each
    database and register a type caster that converts it to unicode.
    http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore
    """
    with conn.cursor() as cur:
        try:
            cur.execute(
                "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
            )
            oid = cur.fetchone()[0]
            ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
        except Exception:
            # Best-effort: hstore may simply not be installed (fetchone()
            # returns None -> TypeError). NOTE(review): broad except also
            # hides unrelated errors — presumably intentional, verify.
            pass
+
+
class PGExecute(object):
    """Runs SQL against a psycopg2 connection and exposes catalog
    introspection queries used for completion metadata."""

    # The boolean argument to the current_schemas function indicates whether
    # implicit schemas, e.g. pg_catalog, should be included in the result.
    search_path_query = """
        SELECT * FROM unnest(current_schemas(true))"""

    schemata_query = """
        SELECT  nspname
        FROM    pg_catalog.pg_namespace
        ORDER BY 1 """

    # Parameterized by a list of relkind codes (see _relations).
    tables_query = """
        SELECT  n.nspname schema_name,
                c.relname table_name
        FROM    pg_catalog.pg_class c
                LEFT JOIN pg_catalog.pg_namespace n
                    ON n.oid = c.relnamespace
        WHERE   c.relkind = ANY(%s)
        ORDER BY 1,2;"""

    databases_query = """
        SELECT d.datname
        FROM pg_catalog.pg_database d
        ORDER BY 1"""

    # Richer listing used by \l-style output (owner, encoding, ACLs).
    full_databases_query = """
        SELECT d.datname as "Name",
            pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
            pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
            d.datcollate as "Collate",
            d.datctype as "Ctype",
            pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
        FROM pg_catalog.pg_database d
        ORDER BY 1"""

    socket_directory_query = """
        SELECT setting
        FROM pg_settings
        WHERE name = 'unix_socket_directories'
        """

    view_definition_query = """
        WITH v AS (SELECT %s::pg_catalog.regclass::pg_catalog.oid AS v_oid)
        SELECT nspname, relname, relkind,
            pg_catalog.pg_get_viewdef(c.oid, true),
            array_remove(array_remove(c.reloptions,'check_option=local'),
                        'check_option=cascaded') AS reloptions,
            CASE
                WHEN 'check_option=local' = ANY (c.reloptions) THEN 'LOCAL'::text
                WHEN 'check_option=cascaded' = ANY (c.reloptions) THEN 'CASCADED'::text
                ELSE NULL
            END AS checkoption
        FROM pg_catalog.pg_class c
        LEFT JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
        JOIN v ON (c.oid = v.v_oid)"""

    function_definition_query = """
        WITH f AS
            (SELECT %s::pg_catalog.regproc::pg_catalog.oid AS f_oid)
        SELECT pg_catalog.pg_get_functiondef(f.f_oid)
        FROM f"""

    version_query = "SELECT version();"
+
    def __init__(
        self,
        database=None,
        user=None,
        password=None,
        host=None,
        port=None,
        dsn=None,
        **kwargs,
    ):
        """Open a connection immediately; all attributes describing the
        connection (dbname, user, host, port, ...) are filled in by connect().

        :param dsn: optional libpq connection string; overrides the
            individual parameters except password (see connect()).
        """
        self._conn_params = {}
        self.conn = None
        self.dbname = None
        self.user = None
        self.password = None
        self.host = None
        self.port = None
        self.server_version = None
        self.extra_args = None
        self.connect(database, user, password, host, port, dsn, **kwargs)
        # Set by run() when a trailing \G temporarily enables expanded output.
        self.reset_expanded = None
+
    def copy(self):
        """Returns a clone of the current executor."""
        # Opens a brand-new connection with the same parameters.
        return self.__class__(**self._conn_params)
+
    def connect(
        self,
        database=None,
        user=None,
        password=None,
        host=None,
        port=None,
        dsn=None,
        **kwargs,
    ):
        """(Re)connect, replacing self.conn and remembering the merged
        connection parameters so copy() can clone this executor.

        A non-empty ``dsn`` takes precedence over the individual parameters;
        a supplied password is folded into the DSN via make_dsn().
        """

        # Start from the previous connection's parameters so a partial
        # reconnect (e.g. just a new database) keeps the rest.
        conn_params = self._conn_params.copy()

        new_params = {
            "database": database,
            "user": user,
            "password": password,
            "host": host,
            "port": port,
            "dsn": dsn,
        }
        new_params.update(kwargs)

        if new_params["dsn"]:
            new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}

            if new_params["password"]:
                new_params["dsn"] = make_dsn(
                    new_params["dsn"], password=new_params.pop("password")
                )

        # Only overwrite with truthy values; falsy ones keep the old setting.
        conn_params.update({k: v for k, v in new_params.items() if v})

        conn = psycopg2.connect(**conn_params)
        cursor = conn.cursor()
        conn.set_client_encoding("utf8")

        self._conn_params = conn_params
        if self.conn:
            self.conn.close()
        self.conn = conn
        self.conn.autocommit = True

        # When we connect using a DSN, we don't really know what db,
        # user, etc. we connected to. Let's read it.
        # Note: moved this after setting autocommit because of #664.
        libpq_version = psycopg2.__libpq_version__
        dsn_parameters = {}
        if libpq_version >= 93000:
            # use actual connection info from psycopg2.extensions.Connection.info
            # as libpq_version > 9.3 is available and required dependency
            dsn_parameters = conn.info.dsn_parameters
        else:
            try:
                dsn_parameters = conn.get_dsn_parameters()
            except Exception as x:
                # https://github.com/dbcli/pgcli/issues/1110
                # PQconninfo not available in libpq < 9.3
                _logger.info("Exception in get_dsn_parameters: %r", x)

        if dsn_parameters:
            self.dbname = dsn_parameters.get("dbname")
            self.user = dsn_parameters.get("user")
            self.host = dsn_parameters.get("host")
            self.port = dsn_parameters.get("port")
        else:
            # Fall back to whatever was passed in explicitly.
            self.dbname = conn_params.get("database")
            self.user = conn_params.get("user")
            self.host = conn_params.get("host")
            self.port = conn_params.get("port")

        self.password = password
        self.extra_args = kwargs

        if not self.host:
            # Local socket connection: report the socket directory as "host".
            self.host = self.get_socket_directory()

        pid = self._select_one(cursor, "select pg_backend_pid()")[0]
        self.pid = pid
        self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
        self.server_version = conn.get_parameter_status("server_version")

        register_date_typecasters(conn)
        register_json_typecasters(self.conn, self._json_typecaster)
        register_hstore_typecaster(self.conn)
+
    @property
    def short_host(self):
        """First label of the first host (multi-host DSNs are comma-separated;
        domain parts after the first dot are dropped)."""
        if "," in self.host:
            host, _, _ = self.host.partition(",")
        else:
            host = self.host
        short_host, _, _ = host.partition(".")
        return short_host
+
    def _select_one(self, cur, sql):
        """
        Helper method to run a select and retrieve a single field value
        :param cur: cursor
        :param sql: string
        :return: tuple (the first result row; callers index into it)
        """
        cur.execute(sql)
        return cur.fetchone()
+
    def _json_typecaster(self, json_data):
        """Interpret incoming JSON data as a string.

        The raw data is decoded using the connection's encoding, which defaults
        to the database's encoding.

        See http://initd.org/psycopg/docs/connection.html#connection.encoding
        """

        # Deliberately skip json.loads so values display as raw JSON text.
        return json_data
+
    def failed_transaction(self):
        """True if the current transaction is in an aborted (error) state."""
        status = self.conn.get_transaction_status()
        return status == ext.TRANSACTION_STATUS_INERROR
+
    def valid_transaction(self):
        """True if a transaction is open and healthy (active or idle-in-tx)."""
        status = self.conn.get_transaction_status()
        return (
            status == ext.TRANSACTION_STATUS_ACTIVE
            or status == ext.TRANSACTION_STATUS_INTRANS
        )
+
    def run(
        self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
    ):
        """Execute the sql in the database and return the results.

        :param statement: A string containing one or more sql statements
        :param pgspecial: PGSpecial object
        :param exception_formatter: A callable that accepts an Exception and
               returns a formatted (title, rows, headers, status) tuple that can
               act as a query result. If an exception_formatter is not supplied,
               psycopg2 exceptions are always raised.
        :param on_error_resume: Bool. If true, queries following an exception
               (assuming exception_formatter has been supplied) continue to
               execute.

        :return: Generator yielding tuples containing
                 (title, rows, headers, status, query, success, is_special)
        """

        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None, statement, False, False)

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove spaces, eol and semi-colons.
            sql = sql.rstrip(";")
            sql = sqlparse.format(sql, strip_comments=True).strip()
            if not sql:
                continue
            try:
                if pgspecial:
                    # \G is treated specially since we have to set the expanded output.
                    if sql.endswith("\\G"):
                        if not pgspecial.expanded_output:
                            pgspecial.expanded_output = True
                            # Remember to restore the setting in the finally
                            # block once this statement is done.
                            self.reset_expanded = True
                        sql = sql[:-2].strip()

                    # First try to run each query as special
                    _logger.debug("Trying a pgspecial command. sql: %r", sql)
                    try:
                        cur = self.conn.cursor()
                    except psycopg2.InterfaceError:
                        # edge case when connection is already closed, but we
                        # don't need cursor for special_cmd.arg_type == NO_QUERY.
                        # See https://github.com/dbcli/pgcli/issues/1014.
                        cur = None
                    try:
                        for result in pgspecial.execute(cur, sql):
                            # e.g. execute_from_file already appends these
                            if len(result) < 7:
                                yield result + (sql, True, True)
                            else:
                                yield result
                        continue
                    except special.CommandNotFound:
                        pass

                # Not a special command, so execute as normal sql
                yield self.execute_normal_sql(sql) + (sql, True, False)
            except psycopg2.DatabaseError as e:
                _logger.error("sql: %r, error: %r", sql, e)
                _logger.error("traceback: %r", traceback.format_exc())

                if self._must_raise(e) or not exception_formatter:
                    raise

                yield None, None, None, exception_formatter(e), sql, False, False

                if not on_error_resume:
                    break
            finally:
                if self.reset_expanded:
                    # \G enabled expanded output for this statement only.
                    pgspecial.expanded_output = False
                    self.reset_expanded = None
+
    def _must_raise(self, e):
        """Return true if e is an error that should not be caught in ``run``.

        An uncaught error will prompt the user to reconnect; as long as we
        detect that the connection is still open, we catch the error, as
        reconnecting won't solve that problem.

        :param e: DatabaseError. An exception raised while executing a query.

        :return: Bool. True if ``run`` must raise this exception.

        """
        return self.conn.closed != 0
+
    def execute_normal_sql(self, split_sql):
        """Returns tuple (title, rows, headers, status)"""
        _logger.debug("Regular sql statement. sql: %r", split_sql)
        cur = self.conn.cursor()
        cur.execute(split_sql)

        # conn.notices persist between queries, we use pop to clear out the list
        title = ""
        while len(self.conn.notices) > 0:
            title = self.conn.notices.pop() + title

        # cur.description will be None for operations that do not return
        # rows.
        if cur.description:
            headers = [x[0] for x in cur.description]
            # The open cursor itself is returned as the row iterator.
            return title, cur, headers, cur.statusmessage
        else:
            _logger.debug("No rows in result.")
            return title, None, None, cur.statusmessage
+
    def search_path(self):
        """Returns the current search path as a list of schema names"""

        try:
            with self.conn.cursor() as cur:
                _logger.debug("Search path query. sql: %r", self.search_path_query)
                cur.execute(self.search_path_query)
                return [x[0] for x in cur.fetchall()]
        except psycopg2.ProgrammingError:
            # unnest() rejected — presumably an older server; fall back to
            # fetching the array in a single row.
            fallback = "SELECT * FROM current_schemas(true)"
            with self.conn.cursor() as cur:
                _logger.debug("Search path query. sql: %r", fallback)
                cur.execute(fallback)
                return cur.fetchone()[0]
+
    def view_definition(self, spec):
        """Returns the SQL defining views described by `spec`"""

        # Format indices into the query's result row (plus appended view_type):
        # 0: nspname, 1: relname, 3: view body; 6: view_type appended below.
        template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
        # 2: relkind, v or m (materialized)
        # 4: reloptions, null
        # 5: checkoption: local or cascaded
        with self.conn.cursor() as cur:
            sql = self.view_definition_query
            _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
            try:
                cur.execute(sql, (spec,))
            except psycopg2.ProgrammingError:
                raise RuntimeError("View {} does not exist.".format(spec))
            result = cur.fetchone()
            view_type = "MATERIALIZED" if result[2] == "m" else ""
            return template.format(*result + (view_type,))
+
    def function_definition(self, spec):
        """Returns the SQL defining functions described by `spec`"""

        with self.conn.cursor() as cur:
            sql = self.function_definition_query
            _logger.debug("Function Definition Query. sql: %r\nspec: %r", sql, spec)
            try:
                cur.execute(sql, (spec,))
                result = cur.fetchone()
                return result[0]
            except psycopg2.ProgrammingError:
                # regproc cast failed: no such function (or ambiguous name).
                raise RuntimeError("Function {} does not exist.".format(spec))
+
    def schemata(self):
        """Returns a list of schema names in the database"""

        with self.conn.cursor() as cur:
            _logger.debug("Schemata Query. sql: %r", self.schemata_query)
            cur.execute(self.schemata_query)
            return [x[0] for x in cur.fetchall()]
+
    def _relations(self, kinds=("r", "p", "f", "v", "m")):
        """Get table or view name metadata

        :param kinds: list of postgres relkind filters:
                'r' - table
                'p' - partitioned table
                'f' - foreign table
                'v' - view
                'm' - materialized view
        :return: (schema_name, rel_name) tuples
        """

        with self.conn.cursor() as cur:
            # mogrify only to log the fully-interpolated SQL.
            sql = cur.mogrify(self.tables_query, [kinds])
            _logger.debug("Tables Query. sql: %r", sql)
            cur.execute(sql)
            for row in cur:
                yield row
+
    def tables(self):
        """Yields (schema_name, table_name) tuples"""
        # Plain, partitioned, and foreign tables.
        for row in self._relations(kinds=["r", "p", "f"]):
            yield row
+
    def views(self):
        """Yields (schema_name, view_name) tuples.

        Includes both views and materialized views
        """
        for row in self._relations(kinds=["v", "m"]):
            yield row
+
    def _columns(self, kinds=("r", "p", "f", "v", "m")):
        """Get column metadata for tables and views

        :param kinds: kinds: list of postgres relkind filters:
                'r' - table
                'p' - partitioned table
                'f' - foreign table
                'v' - view
                'm' - materialized view
        :return: list of (schema_name, relation_name, column_name, column_type) tuples
        """

        # 80400 = server 8.4: pg_get_expr/default-expression introspection is
        # used only on newer servers; older ones get NULL default info.
        if self.conn.server_version >= 80400:
            columns_query = """
                SELECT  nsp.nspname schema_name,
                        cls.relname table_name,
                        att.attname column_name,
                        att.atttypid::regtype::text type_name,
                        att.atthasdef AS has_default,
                        pg_catalog.pg_get_expr(def.adbin, def.adrelid, true) as default
                FROM    pg_catalog.pg_attribute att
                        INNER JOIN pg_catalog.pg_class cls
                            ON att.attrelid = cls.oid
                        INNER JOIN pg_catalog.pg_namespace nsp
                            ON cls.relnamespace = nsp.oid
                        LEFT OUTER JOIN pg_attrdef def
                            ON def.adrelid = att.attrelid
                            AND def.adnum = att.attnum
                WHERE   cls.relkind = ANY(%s)
                        AND NOT att.attisdropped
                        AND att.attnum  > 0
                ORDER BY 1, 2, att.attnum"""
        else:
            columns_query = """
                SELECT  nsp.nspname schema_name,
                        cls.relname table_name,
                        att.attname column_name,
                        typ.typname type_name,
                        NULL AS has_default,
                        NULL AS default
                FROM    pg_catalog.pg_attribute att
                        INNER JOIN pg_catalog.pg_class cls
                            ON att.attrelid = cls.oid
                        INNER JOIN pg_catalog.pg_namespace nsp
                            ON cls.relnamespace = nsp.oid
                        INNER JOIN pg_catalog.pg_type typ
                            ON typ.oid = att.atttypid
                WHERE   cls.relkind = ANY(%s)
                        AND NOT att.attisdropped
                        AND att.attnum  > 0
                ORDER BY 1, 2, att.attnum"""

        with self.conn.cursor() as cur:
            sql = cur.mogrify(columns_query, [kinds])
            _logger.debug("Columns Query. sql: %r", sql)
            cur.execute(sql)
            for row in cur:
                yield row
+
    def table_columns(self):
        """Yields column metadata rows for tables (plain/partitioned/foreign)."""
        for row in self._columns(kinds=["r", "p", "f"]):
            yield row
+
    def view_columns(self):
        """Yields column metadata rows for views and materialized views."""
        for row in self._columns(kinds=["v", "m"]):
            yield row
+
    def databases(self):
        """Returns a list of database names on this server."""
        with self.conn.cursor() as cur:
            _logger.debug("Databases Query. sql: %r", self.databases_query)
            cur.execute(self.databases_query)
            return [x[0] for x in cur.fetchall()]
+
    def full_databases(self):
        """Returns (rows, headers, status) with owner/encoding/ACL details
        for every database (\\l-style output)."""
        with self.conn.cursor() as cur:
            _logger.debug("Databases Query. sql: %r", self.full_databases_query)
            cur.execute(self.full_databases_query)
            headers = [x[0] for x in cur.description]
            return cur.fetchall(), headers, cur.statusmessage
+
    def get_socket_directory(self):
        """Returns the server's unix_socket_directories setting, or '' if the
        setting is absent."""
        with self.conn.cursor() as cur:
            _logger.debug(
                "Socket directory Query. sql: %r", self.socket_directory_query
            )
            cur.execute(self.socket_directory_query)
            result = cur.fetchone()
            return result[0] if result else ""
+
    def foreignkeys(self):
        """Yields ForeignKey named tuples"""

        # generate_subscripts (used below) needs server >= 9.0.
        if self.conn.server_version < 90000:
            return

        with self.conn.cursor() as cur:
            # Unnests conkey/confkey in parallel so composite foreign keys
            # yield one row per referencing/referenced column pair.
            query = """
                SELECT s_p.nspname AS parentschema,
                       t_p.relname AS parenttable,
                       unnest((
                        select
                            array_agg(attname ORDER BY i)
                        from
                            (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
                            JOIN pg_catalog.pg_attribute c USING(attnum)
                            WHERE c.attrelid = fk.confrelid
                        )) AS parentcolumn,
                       s_c.nspname AS childschema,
                       t_c.relname AS childtable,
                       unnest((
                        select
                            array_agg(attname ORDER BY i)
                        from
                            (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
                            JOIN pg_catalog.pg_attribute c USING(attnum)
                            WHERE c.attrelid = fk.conrelid
                        )) AS childcolumn
                FROM pg_catalog.pg_constraint fk
                JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
                JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
                JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
                JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
                WHERE fk.contype = 'f';
                """
            # NOTE(review): debug label says "Functions Query" but this is the
            # foreign-key query — label looks copy-pasted; confirm upstream.
            _logger.debug("Functions Query. sql: %r", query)
            cur.execute(query)
            for row in cur:
                yield ForeignKey(*row)
+
    def functions(self):
        """Yields FunctionMetadata named tuples"""

        # Four query variants matched to the server's catalog layout:
        # 110000 = PG 11+: prokind column replaces proisagg/proiswindow.
        if self.conn.server_version >= 110000:
            query = """
                SELECT n.nspname schema_name,
                        p.proname func_name,
                        p.proargnames,
                        COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
                        p.proargmodes,
                        prorettype::regtype::text return_type,
                        p.prokind = 'a' is_aggregate,
                        p.prokind = 'w' is_window,
                        p.proretset is_set_returning,
                        d.deptype = 'e' is_extension,
                        pg_get_expr(proargdefaults, 0) AS arg_defaults
                FROM pg_catalog.pg_proc p
                        INNER JOIN pg_catalog.pg_namespace n
                            ON n.oid = p.pronamespace
                LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
                WHERE p.prorettype::regtype != 'trigger'::regtype
                ORDER BY 1, 2
                """
        # 9.0 < version < 11: proisagg/proiswindow booleans still exist.
        elif self.conn.server_version > 90000:
            query = """
                SELECT n.nspname schema_name,
                        p.proname func_name,
                        p.proargnames,
                        COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
                        p.proargmodes,
                        prorettype::regtype::text return_type,
                        p.proisagg is_aggregate,
                        p.proiswindow is_window,
                        p.proretset is_set_returning,
                        d.deptype = 'e' is_extension,
                        pg_get_expr(proargdefaults, 0) AS arg_defaults
                FROM pg_catalog.pg_proc p
                        INNER JOIN pg_catalog.pg_namespace n
                            ON n.oid = p.pronamespace
                LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
                WHERE p.prorettype::regtype != 'trigger'::regtype
                ORDER BY 1, 2
                """
        # 8.4+: no window functions or argument defaults in the catalog.
        elif self.conn.server_version >= 80400:
            query = """
                SELECT n.nspname schema_name,
                        p.proname func_name,
                        p.proargnames,
                        COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
                        p.proargmodes,
                        prorettype::regtype::text,
                        p.proisagg is_aggregate,
                        false is_window,
                        p.proretset is_set_returning,
                        d.deptype = 'e' is_extension,
                        NULL AS arg_defaults
                FROM pg_catalog.pg_proc p
                INNER JOIN pg_catalog.pg_namespace n
                    ON n.oid = p.pronamespace
                LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
                WHERE p.prorettype::regtype != 'trigger'::regtype
                ORDER BY 1, 2
                """
        # Anything older: argument type/mode introspection unavailable.
        else:
            query = """
                SELECT n.nspname schema_name,
                        p.proname func_name,
                        p.proargnames,
                        NULL arg_types,
                        NULL arg_modes,
                        '' ret_type,
                        p.proisagg is_aggregate,
                        false is_window,
                        p.proretset is_set_returning,
                        d.deptype = 'e' is_extension,
                        NULL AS arg_defaults
                FROM pg_catalog.pg_proc p
                INNER JOIN pg_catalog.pg_namespace n
                    ON n.oid = p.pronamespace
                LEFT JOIN pg_depend d ON d.objid = p.oid and d.deptype = 'e'
                WHERE p.prorettype::regtype != 'trigger'::regtype
                ORDER BY 1, 2
                """

        with self.conn.cursor() as cur:
            _logger.debug("Functions Query. sql: %r", query)
            cur.execute(query)
            for row in cur:
                yield FunctionMetadata(*row)
+
    def datatypes(self):
        """Yields tuples of (schema_name, type_name)"""

        with self.conn.cursor() as cur:
            if self.conn.server_version > 90000:
                # Excludes table row-types and array types; skips the system
                # schemas.
                query = """
                    SELECT n.nspname schema_name,
                           t.typname type_name
                    FROM   pg_catalog.pg_type t
                           INNER JOIN pg_catalog.pg_namespace n
                              ON n.oid = t.typnamespace
                    WHERE ( t.typrelid = 0  -- non-composite types
                            OR (  -- composite type, but not a table
                                  SELECT c.relkind = 'c'
                                  FROM pg_catalog.pg_class c
                                  WHERE c.oid = t.typrelid
                                )
                          )
                          AND NOT EXISTS( -- ignore array types
                                SELECT  1
                                FROM    pg_catalog.pg_type el
                                WHERE   el.oid = t.typelem AND el.typarray = t.oid
                              )
                          AND n.nspname <> 'pg_catalog'
                          AND n.nspname <> 'information_schema'
                    ORDER BY 1, 2;
                    """
            else:
                query = """
                    SELECT n.nspname schema_name,
                      pg_catalog.format_type(t.oid, NULL) type_name
                    FROM pg_catalog.pg_type t
                         LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
                    WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
                      AND t.typname !~ '^_'
                      AND n.nspname <> 'pg_catalog'
                      AND n.nspname <> 'information_schema'
                      AND pg_catalog.pg_type_is_visible(t.oid)
                    ORDER BY 1, 2;
                    """
            _logger.debug("Datatypes Query. sql: %r", query)
            cur.execute(query)
            for row in cur:
                yield row
+
    def casing(self):
        """Yields the most common casing for names used in db functions"""
        with self.conn.cursor() as cur:
            # Tokenizes sql/plpgsql function bodies, ranks each spelling of a
            # word by frequency, and keeps the top spelling for words that
            # match a known object name (column/table/function/type/schema).
            query = r"""
                WITH Words AS (
                    SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
                    FROM pg_catalog.pg_proc P
                    JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace
                    JOIN pg_catalog.pg_language L ON L.oid = P.prolang
                    WHERE L.lanname IN ('sql', 'plpgsql')
                    AND N.nspname NOT IN ('pg_catalog', 'information_schema')
                    GROUP BY Word
                ),
                OrderWords AS (
                    SELECT Word,
                        ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC)
                    FROM Words
                    WHERE Word ~* '.*[a-z].*'
                ),
                Names AS (
                    --Column names
                    SELECT attname AS Name
                    FROM pg_catalog.pg_attribute
                    UNION -- Table/view names
                    SELECT relname
                    FROM pg_catalog.pg_class
                    UNION -- Function names
                    SELECT proname
                    FROM pg_catalog.pg_proc
                    UNION -- Type names
                    SELECT typname
                    FROM pg_catalog.pg_type
                    UNION -- Schema names
                    SELECT nspname
                    FROM pg_catalog.pg_namespace
                    UNION -- Parameter names
                    SELECT unnest(proargnames)
                    FROM pg_proc
                )
                SELECT Word
                FROM OrderWords
                WHERE LOWER(Word) IN (SELECT Name FROM Names)
                AND Row_Number = 1;
                """
            _logger.debug("Casing Query. sql: %r", query)
            cur.execute(query)
            for row in cur:
                yield row[0]
diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py
new file mode 100644
index 0000000..8229037
--- /dev/null
+++ b/pgcli/pgstyle.py
@@ -0,0 +1,116 @@
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
logger = logging.getLogger(__name__)

# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
TOKEN_TO_PROMPT_STYLE = {
    Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
    Token.Menu.Completions.Completion: "completion-menu.completion",
    Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
    Token.Menu.Completions.Meta: "completion-menu.meta.completion",
    Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
    Token.Menu.Completions.ProgressButton: "scrollbar.arrow",  # best guess
    Token.Menu.Completions.ProgressBar: "scrollbar",  # best guess
    Token.SelectedText: "selected",
    Token.SearchMatch: "search",
    Token.SearchMatch.Current: "search.current",
    Token.Toolbar: "bottom-toolbar",
    Token.Toolbar.Off: "bottom-toolbar.off",
    Token.Toolbar.On: "bottom-toolbar.on",
    Token.Toolbar.Search: "search-toolbar",
    Token.Toolbar.Search.Text: "search-toolbar.text",
    Token.Toolbar.System: "system-toolbar",
    Token.Toolbar.Arg: "arg-toolbar",
    Token.Toolbar.Arg.Text: "arg-toolbar.text",
    Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
    Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
    Token.Output.Header: "output.header",
    Token.Output.OddRow: "output.odd-row",
    Token.Output.EvenRow: "output.even-row",
    Token.Output.Null: "output.null",
    Token.Literal.String: "literal.string",
    Token.Literal.Number: "literal.number",
    Token.Keyword: "keyword",
    Token.Prompt: "prompt",
    Token.Continuation: "continuation",
}

# reverse dict for cli_helpers, because they still expect Pygments tokens.
PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+
+
def parse_pygments_style(token_name, style_object, style_dict):
    """Parse token type and style string.

    :param token_name: str name of Pygments token. Example: "Token.String"
    :param style_object: pygments.style.Style instance to use as base
    :param style_dict: dict of token names and their styles, customized to this cli

    """
    token_type = string_to_tokentype(token_name)
    try:
        # If the configured value names another token, resolve it to that
        # token's style in the base style object.
        other_token_type = string_to_tokentype(style_dict[token_name])
        return token_type, style_object.styles[other_token_type]
    except AttributeError:
        # style_object has no .styles (callers may pass a plain styles dict);
        # fall back to the configured value as a literal style string.
        return token_type, style_dict[token_name]
+
+
def style_factory(name, cli_style):
    """Build the prompt_toolkit Style for the CLI from a Pygments style name
    plus per-user overrides in ``cli_style``."""
    try:
        style = pygments.styles.get_style_by_name(name)
    except ClassNotFound:
        # Unknown style name: fall back to a known built-in.
        style = pygments.styles.get_style_by_name("native")

    prompt_styles = []
    # prompt-toolkit used pygments tokens for styling before, switched to style
    # names in 2.0. Convert old token types to new style names, for backwards compatibility.
    for token in cli_style:
        if token.startswith("Token."):
            # treat as pygments token (1.0)
            token_type, style_value = parse_pygments_style(token, style, cli_style)
            if token_type in TOKEN_TO_PROMPT_STYLE:
                prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
                prompt_styles.append((prompt_style, style_value))
            else:
                # we don't want to support tokens anymore
                logger.error("Unhandled style / class name: %s", token)
        else:
            # treat as prompt style name (2.0). See default style names here:
            # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
            prompt_styles.append((token, cli_style[token]))

    override_style = Style([("bottom-toolbar", "noreverse")])
    # Later styles win: user styles override the base Pygments theme.
    return merge_styles(
        [style_from_pygments_cls(style), override_style, Style(prompt_styles)]
    )
+
+
def style_factory_output(name, cli_style):
    """Build a Pygments Style class for table/output rendering (cli_helpers
    still consumes Pygments tokens, not prompt_toolkit class names)."""
    try:
        style = pygments.styles.get_style_by_name(name).styles
    except ClassNotFound:
        style = pygments.styles.get_style_by_name("native").styles

    for token in cli_style:
        if token.startswith("Token."):
            token_type, style_value = parse_pygments_style(token, style, cli_style)
            style.update({token_type: style_value})
        elif token in PROMPT_STYLE_TO_TOKEN:
            # Map a ptk-2.0 class name back to its Pygments token.
            token_type = PROMPT_STYLE_TO_TOKEN[token]
            style.update({token_type: cli_style[token]})
        else:
            # TODO: cli helpers will have to switch to ptk.Style
            logger.error("Unhandled style / class name: %s", token)

    class OutputStyle(PygmentsStyle):
        default_style = ""
        styles = style

    return OutputStyle
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
new file mode 100644
index 0000000..f4289a1
--- /dev/null
+++ b/pgcli/pgtoolbar.py
@@ -0,0 +1,62 @@
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.application import get_app
+
+
def _get_vi_mode():
    """Return a one-letter label for the current prompt_toolkit vi input
    mode, for display in the toolbar."""
    return {
        InputMode.INSERT: "I",
        InputMode.NAVIGATION: "N",
        InputMode.REPLACE: "R",
        InputMode.REPLACE_SINGLE: "R",
        InputMode.INSERT_MULTIPLE: "M",
    }[get_app().vi_state.input_mode]
+
+
def create_toolbar_tokens_func(pgcli):
    """Return a function that generates the toolbar tokens."""

    def get_toolbar_tokens():
        # Each entry is a (style-class, text) pair consumed by prompt_toolkit.
        result = []
        result.append(("class:bottom-toolbar", " "))

        if pgcli.completer.smart_completion:
            result.append(("class:bottom-toolbar.on", "[F2] Smart Completion: ON  "))
        else:
            result.append(("class:bottom-toolbar.off", "[F2] Smart Completion: OFF  "))

        if pgcli.multi_line:
            result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON  "))
        else:
            result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF  "))

        if pgcli.multi_line:
            if pgcli.multiline_mode == "safe":
                result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) "))
            else:
                result.append(
                    ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
                )

        if pgcli.vi_mode:
            result.append(
                ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")")
            )
        else:
            result.append(("class:bottom-toolbar", "[F4] Emacs-mode"))

        # Transaction state indicators from the live connection.
        if pgcli.pgexecute.failed_transaction():
            result.append(
                ("class:bottom-toolbar.transaction.failed", "     Failed transaction")
            )

        if pgcli.pgexecute.valid_transaction():
            result.append(
                ("class:bottom-toolbar.transaction.valid", "     Transaction")
            )

        if pgcli.completion_refresher.is_refreshing():
            result.append(("class:bottom-toolbar", "     Refreshing completions..."))

        return result

    return get_toolbar_tokens