summaryrefslogtreecommitdiffstats
path: root/pgcli/pgcompleter.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/pgcompleter.py')
-rw-r--r--pgcli/pgcompleter.py1046
1 files changed, 1046 insertions, 0 deletions
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
new file mode 100644
index 0000000..9c95a01
--- /dev/null
+++ b/pgcli/pgcompleter.py
@@ -0,0 +1,1046 @@
+import logging
+import re
+from itertools import count, repeat, chain
+import operator
+from collections import namedtuple, defaultdict, OrderedDict
+from cli_helpers.tabular_output import TabularOutputFormatter
+from pgspecial.namedqueries import NamedQueries
+from prompt_toolkit.completion import Completer, Completion, PathCompleter
+from prompt_toolkit.document import Document
+from .packages.sqlcompletion import (
+ FromClauseItem,
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ TableFormat,
+ Function,
+ Column,
+ View,
+ Keyword,
+ NamedQuery,
+ Datatype,
+ Alias,
+ Path,
+ JoinCondition,
+ Join,
+)
+from .packages.parseutils.meta import ColumnMetadata, ForeignKey
+from .packages.parseutils.utils import last_word
+from .packages.parseutils.tables import TableReference
+from .packages.pgliterals.main import get_literals
+from .packages.prioritization import PrevalenceCounter
+from .config import load_config, config_location
+
+_logger = logging.getLogger(__name__)
+
+Match = namedtuple("Match", ["completion", "priority"])
+
+_SchemaObject = namedtuple("SchemaObject", "name schema meta")
+
+
+def SchemaObject(name, schema=None, meta=None):
+ return _SchemaObject(name, schema, meta)
+
+
+_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")
+
+
+def Candidate(
+ completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
+):
+ return _Candidate(
+ completion, prio, meta, synonyms or [completion], prio2, display or completion
+ )
+
+
+# Used to strip trailing '::some_type' from default-value expressions
+arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
+
+normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
+
+
+def generate_alias(tbl):
+ """Generate a table alias, consisting of all upper-case letters in
+ the table name, or, if there are no upper-case letters, the first letter +
+ all letters preceded by _
+ param tbl - unescaped name of the table to alias
+ """
+ return "".join(
+ [l for l in tbl if l.isupper()]
+ or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
+ )
+
+
+class PGCompleter(Completer):
+ # keywords_tree: A dict mapping keywords to well known following keywords.
+ # e.g. 'CREATE': ['TABLE', 'USER', ...],
+ keywords_tree = get_literals("keywords", type_=dict)
+ keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
+ functions = get_literals("functions")
+ datatypes = get_literals("datatypes")
+ reserved_words = set(get_literals("reserved"))
+
+ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
+ super(PGCompleter, self).__init__()
+ self.smart_completion = smart_completion
+ self.pgspecial = pgspecial
+ self.prioritizer = PrevalenceCounter()
+ settings = settings or {}
+ self.signature_arg_style = settings.get(
+ "signature_arg_style", "{arg_name} {arg_type}"
+ )
+ self.call_arg_style = settings.get(
+ "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
+ )
+ self.call_arg_display_style = settings.get(
+ "call_arg_display_style", "{arg_name}"
+ )
+ self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
+ self.search_path_filter = settings.get("search_path_filter")
+ self.generate_aliases = settings.get("generate_aliases")
+ self.casing_file = settings.get("casing_file")
+ self.insert_col_skip_patterns = [
+ re.compile(pattern)
+ for pattern in settings.get(
+ "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
+ )
+ ]
+ self.generate_casing_file = settings.get("generate_casing_file")
+ self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
+ self.asterisk_column_order = settings.get(
+ "asterisk_column_order", "table_order"
+ )
+
+ keyword_casing = settings.get("keyword_casing", "upper").lower()
+ if keyword_casing not in ("upper", "lower", "auto"):
+ keyword_casing = "upper"
+ self.keyword_casing = keyword_casing
+ self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")
+
+ self.databases = []
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+ self.search_path = []
+ self.casing = {}
+
+ self.all_completions = set(self.keywords + self.functions)
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = '"%s"' % name
+
+ return name
+
+ def escape_schema(self, name):
+ return "'{}'".format(self.unescape_name(name))
+
+ def unescape_name(self, name):
+ """ Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schemata):
+
+ # schemata is a list of schema names
+ schemata = self.escaped_names(schemata)
+ metadata = self.dbmetadata["tables"]
+ for schema in schemata:
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ for schema in schemata:
+ metadata[schema] = {}
+
+ self.all_completions.update(schemata)
+
+ def extend_casing(self, words):
+ """extend casing data
+
+ :return:
+ """
+ # casing should be a dict {lowercasename:PreferredCasingName}
+ self.casing = dict((word.lower(), word) for word in words)
+
+ def extend_relations(self, data, kind):
+ """extend metadata for tables or views.
+
+ :param data: list of (schema_name, rel_name) tuples
+ :param kind: either 'tables' or 'views'
+
+ :return:
+
+ """
+
+ data = [self.escaped_names(d) for d in data]
+
+ # dbmetadata['tables']['schema_name']['table_name'] should be an
+ # OrderedDict {column_name:ColumnMetaData}.
+ metadata = self.dbmetadata[kind]
+ for schema, relname in data:
+ try:
+ metadata[schema][relname] = OrderedDict()
+ except KeyError:
+ _logger.error(
+ "%r %r listed in unrecognized schema %r", kind, relname, schema
+ )
+ self.all_completions.add(relname)
+
+ def extend_columns(self, column_data, kind):
+ """extend column metadata.
+
+ :param column_data: list of (schema_name, rel_name, column_name,
+ column_type, has_default, default) tuples
+ :param kind: either 'tables' or 'views'
+
+ :return:
+
+ """
+ metadata = self.dbmetadata[kind]
+ for schema, relname, colname, datatype, has_default, default in column_data:
+ (schema, relname, colname) = self.escaped_names([schema, relname, colname])
+ column = ColumnMetadata(
+ name=colname,
+ datatype=datatype,
+ has_default=has_default,
+ default=default,
+ )
+ metadata[schema][relname][colname] = column
+ self.all_completions.add(colname)
+
+ def extend_functions(self, func_data):
+
+ # func_data is a list of function metadata namedtuples
+
+ # dbmetadata['schema_name']['functions']['function_name'] should return
+ # the function metadata namedtuple for the corresponding function
+ metadata = self.dbmetadata["functions"]
+
+ for f in func_data:
+ schema, func = self.escaped_names([f.schema_name, f.func_name])
+
+ if func in metadata[schema]:
+ metadata[schema][func].append(f)
+ else:
+ metadata[schema][func] = [f]
+
+ self.all_completions.add(func)
+
+ self._refresh_arg_list_cache()
+
+ def _refresh_arg_list_cache(self):
+ # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
+ # This is used when suggesting functions, to avoid the latency that would result
+ # if we'd recalculate the arg lists each time we suggest functions (in large DBs)
+ self._arg_list_cache = {
+ usage: {
+ meta: self._arg_list(meta, usage)
+ for sch, funcs in self.dbmetadata["functions"].items()
+ for func, metas in funcs.items()
+ for meta in metas
+ }
+ for usage in ("call", "call_display", "signature")
+ }
+
+ def extend_foreignkeys(self, fk_data):
+
+ # fk_data is a list of ForeignKey namedtuples, with fields
+ # parentschema, childschema, parenttable, childtable,
+ # parentcolumns, childcolumns
+
+ # These are added as a list of ForeignKey namedtuples to the
+ # ColumnMetadata namedtuple for both the child and parent
+ meta = self.dbmetadata["tables"]
+
+ for fk in fk_data:
+ e = self.escaped_names
+ parentschema, childschema = e([fk.parentschema, fk.childschema])
+ parenttable, childtable = e([fk.parenttable, fk.childtable])
+ childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
+ childcolmeta = meta[childschema][childtable][childcol]
+ parcolmeta = meta[parentschema][parenttable][parcol]
+ fk = ForeignKey(
+ parentschema, parenttable, parcol, childschema, childtable, childcol
+ )
+ childcolmeta.foreignkeys.append((fk))
+ parcolmeta.foreignkeys.append((fk))
+
+ def extend_datatypes(self, type_data):
+
+ # dbmetadata['datatypes'][schema_name][type_name] should store type
+ # metadata, such as composite type field names. Currently, we're not
+ # storing any metadata beyond typename, so just store None
+ meta = self.dbmetadata["datatypes"]
+
+ for t in type_data:
+ schema, type_name = self.escaped_names(t)
+ meta[schema][type_name] = None
+ self.all_completions.add(type_name)
+
+ def extend_query_history(self, text, is_init=False):
+ if is_init:
+ # During completer initialization, only load keyword preferences,
+ # not names
+ self.prioritizer.update_keywords(text)
+ else:
+ self.prioritizer.update(text)
+
+ def set_search_path(self, search_path):
+ self.search_path = self.escaped_names(search_path)
+
+ def reset_completions(self):
+ self.databases = []
+ self.special_commands = []
+ self.search_path = []
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ def find_matches(self, text, collection, mode="fuzzy", meta=None):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ `collection` can be either a list of strings or a list of Candidate
+ namedtuples.
+ `mode` can be either 'fuzzy', or 'strict'
+ 'fuzzy': fuzzy matching, ties broken by name prevalance
+ `keyword`: start only matching, ties broken by keyword prevalance
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+
+ """
+ if not collection:
+ return []
+ prio_order = [
+ "keyword",
+ "function",
+ "view",
+ "table",
+ "datatype",
+ "database",
+ "schema",
+ "column",
+ "table alias",
+ "join",
+ "name join",
+ "fk join",
+ "table format",
+ ]
+ type_priority = prio_order.index(meta) if meta in prio_order else -1
+ text = last_word(text, include="most_punctuations").lower()
+ text_len = len(text)
+
+ if text and text[0] == '"':
+ # text starts with double quote; user is manually escaping a name
+ # Match on everything that follows the double-quote. Note that
+ # text_len is calculated before removing the quote, so the
+ # Completion.position value is correct
+ text = text[1:]
+
+ if mode == "fuzzy":
+ fuzzy = True
+ priority_func = self.prioritizer.name_count
+ else:
+ fuzzy = False
+ priority_func = self.prioritizer.keyword_count
+
+ # Construct a `_match` function for either fuzzy or non-fuzzy matching
+ # The match function returns a 2-tuple used for sorting the matches,
+ # or None if the item doesn't match
+ # Note: higher priority values mean more important, so use negative
+ # signs to flip the direction of the tuple
+ if fuzzy:
+ regex = ".*?".join(map(re.escape, text))
+ pat = re.compile("(%s)" % regex)
+
+ def _match(item):
+ if item.lower()[: len(text) + 1] in (text, text + " "):
+ # Exact match of first word in suggestion
+ # This is to get exact alias matches to the top
+ # E.g. for input `e`, 'Entries E' should be on top
+ # (before e.g. `EndUsers EU`)
+ return float("Infinity"), -1
+ r = pat.search(self.unescape_name(item.lower()))
+ if r:
+ return -len(r.group()), -r.start()
+
+ else:
+ match_end_limit = len(text)
+
+ def _match(item):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ # Use negative infinity to force keywords to sort after all
+ # fuzzy matches
+ return -float("Infinity"), -match_point
+
+ matches = []
+ for cand in collection:
+ if isinstance(cand, _Candidate):
+ item, prio, display_meta, synonyms, prio2, display = cand
+ if display_meta is None:
+ display_meta = meta
+ syn_matches = (_match(x) for x in synonyms)
+ # Nones need to be removed to avoid max() crashing in Python 3
+ syn_matches = [m for m in syn_matches if m]
+ sort_key = max(syn_matches) if syn_matches else None
+ else:
+ item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
+ sort_key = _match(cand)
+
+ if sort_key:
+ if display_meta and len(display_meta) > 50:
+ # Truncate meta-text to 50 characters, if necessary
+ display_meta = display_meta[:47] + "..."
+
+ # Lexical order of items in the collection, used for
+ # tiebreaking items with the same match group length and start
+ # position. Since we use *higher* priority to mean "more
+ # important," we use -ord(c) to prioritize "aa" > "ab" and end
+ # with 1 to prioritize shorter strings (ie "user" > "users").
+ # We first do a case-insensitive sort and then a
+ # case-sensitive one as a tie breaker.
+ # We also use the unescape_name to make sure quoted names have
+ # the same priority as unquoted names.
+ lexical_priority = (
+ tuple(
+ 0 if c in (" _") else -ord(c)
+ for c in self.unescape_name(item.lower())
+ )
+ + (1,)
+ + tuple(c for c in item)
+ )
+
+ item = self.case(item)
+ display = self.case(display)
+ priority = (
+ sort_key,
+ type_priority,
+ prio,
+ priority_func(item),
+ prio2,
+ lexical_priority,
+ )
+ matches.append(
+ Match(
+ completion=Completion(
+ text=item,
+ start_position=-text_len,
+ display_meta=display_meta,
+ display=display,
+ ),
+ priority=priority,
+ )
+ )
+ return matches
+
+ def case(self, word):
+ return self.casing.get(word, word)
+
+ def get_completions(self, document, complete_event, smart_completion=None):
+ word_before_cursor = document.get_word_before_cursor(WORD=True)
+
+ if smart_completion is None:
+ smart_completion = self.smart_completion
+
+ # If smart_completion is off then match any word that starts with
+ # 'word_before_cursor'.
+ if not smart_completion:
+ matches = self.find_matches(
+ word_before_cursor, self.all_completions, mode="strict"
+ )
+ completions = [m.completion for m in matches]
+ return sorted(completions, key=operator.attrgetter("text"))
+
+ matches = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+ suggestion_type = type(suggestion)
+ _logger.debug("Suggestion type: %r", suggestion_type)
+
+ # Map suggestion type to method
+ # e.g. 'table' -> self.get_table_matches
+ matcher = self.suggestion_matchers[suggestion_type]
+ matches.extend(matcher(self, suggestion, word_before_cursor))
+
+ # Sort matches so highest priorities are first
+ matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)
+
+ return [m.completion for m in matches]
+
+ def get_column_matches(self, suggestion, word_before_cursor):
+ tables = suggestion.table_refs
+ do_qualify = suggestion.qualifiable and {
+ "always": True,
+ "never": False,
+ "if_more_than_one_table": len(tables) > 1,
+ }[self.qualify_columns]
+ qualify = lambda col, tbl: (
+ (tbl + "." + self.case(col)) if do_qualify else self.case(col)
+ )
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)
+
+ def make_cand(name, ref):
+ synonyms = (name, generate_alias(self.case(name)))
+ return Candidate(qualify(name, ref), 0, "column", synonyms)
+
+ def flat_cols():
+ return [
+ make_cand(c.name, t.ref)
+ for t, cols in scoped_cols.items()
+ for c in cols
+ ]
+
+ if suggestion.require_last_table:
+ # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
+ # suggest only columns that appear in the last table and one more
+ ltbl = tables[-1].ref
+ other_tbl_cols = set(
+ c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
+ )
+ scoped_cols = {
+ t: [col for col in cols if col.name in other_tbl_cols]
+ for t, cols in scoped_cols.items()
+ if t.ref == ltbl
+ }
+ lastword = last_word(word_before_cursor, include="most_punctuations")
+ if lastword == "*":
+ if suggestion.context == "insert":
+
+ def filter(col):
+ if not col.has_default:
+ return True
+ return not any(
+ p.match(col.default) for p in self.insert_col_skip_patterns
+ )
+
+ scoped_cols = {
+ t: [col for col in cols if filter(col)]
+ for t, cols in scoped_cols.items()
+ }
+ if self.asterisk_column_order == "alphabetic":
+ for cols in scoped_cols.values():
+ cols.sort(key=operator.attrgetter("name"))
+ if (
+ lastword != word_before_cursor
+ and len(tables) == 1
+ and word_before_cursor[-len(lastword) - 1] == "."
+ ):
+ # User typed x.*; replicate "x." for all columns except the
+ # first, which gets the original (as we only replace the "*"")
+ sep = ", " + word_before_cursor[:-1]
+ collist = sep.join(self.case(c.completion) for c in flat_cols())
+ else:
+ collist = ", ".join(
+ qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
+ )
+
+ return [
+ Match(
+ completion=Completion(
+ collist, -1, display_meta="columns", display="*"
+ ),
+ priority=(1, 1, 1),
+ )
+ ]
+
+ return self.find_matches(word_before_cursor, flat_cols(), meta="column")
+
+ def alias(self, tbl, tbls):
+ """Generate a unique table alias
+ tbl - name of the table to alias, quoted if it needs to be
+ tbls - TableReference iterable of tables already in query
+ """
+ tbl = self.case(tbl)
+ tbls = set(normalize_ref(t.ref) for t in tbls)
+ if self.generate_aliases:
+ tbl = generate_alias(self.unescape_name(tbl))
+ if normalize_ref(tbl) not in tbls:
+ return tbl
+ elif tbl[0] == '"':
+ aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
+ else:
+ aliases = (tbl + str(i) for i in count(2))
+ return next(a for a in aliases if normalize_ref(a) not in tbls)
+
+ def get_join_matches(self, suggestion, word_before_cursor):
+ tbls = suggestion.table_refs
+ cols = self.populate_scoped_cols(tbls)
+ # Set up some data structures for efficient access
+ qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
+ ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
+ refs = set(normalize_ref(t.ref) for t in tbls)
+ other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
+ joins = []
+ # Iterate over FKs in existing tables to find potential joins
+ fks = (
+ (fk, rtbl, rcol)
+ for rtbl, rcols in cols.items()
+ for rcol in rcols
+ for fk in rcol.foreignkeys
+ )
+ col = namedtuple("col", "schema tbl col")
+ for fk, rtbl, rcol in fks:
+ right = col(rtbl.schema, rtbl.name, rcol.name)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left = child if parent == right else parent
+ if suggestion.schema and left.schema != suggestion.schema:
+ continue
+ c = self.case
+ if self.generate_aliases or normalize_ref(left.tbl) in refs:
+ lref = self.alias(left.tbl, suggestion.table_refs)
+ join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
+ )
+ else:
+ join = "{0} ON {0}.{1} = {2}.{3}".format(
+ c(left.tbl), c(left.col), rtbl.ref, c(right.col)
+ )
+ alias = generate_alias(self.case(left.tbl))
+ synonyms = [
+ join,
+ "{0} ON {0}.{1} = {2}.{3}".format(
+ alias, c(left.col), rtbl.ref, c(right.col)
+ ),
+ ]
+ # Schema-qualify if (1) new table in same schema as old, and old
+ # is schema-qualified, or (2) new in other schema, except public
+ if not suggestion.schema and (
+ qualified[normalize_ref(rtbl.ref)]
+ and left.schema == right.schema
+ or left.schema not in (right.schema, "public")
+ ):
+ join = left.schema + "." + join
+ prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
+ 0 if (left.schema, left.tbl) in other_tbls else 1
+ )
+ joins.append(Candidate(join, prio, "join", synonyms=synonyms))
+
+ return self.find_matches(word_before_cursor, joins, meta="join")
+
+ def get_join_condition_matches(self, suggestion, word_before_cursor):
+ col = namedtuple("col", "schema tbl col")
+ tbls = self.populate_scoped_cols(suggestion.table_refs).items
+ cols = [(t, c) for t, cs in tbls() for c in cs]
+ try:
+ lref = (suggestion.parent or suggestion.table_refs[-1]).ref
+ ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
+ except IndexError: # The user typed an incorrect table qualifier
+ return []
+ conds, found_conds = [], set()
+
+ def add_cond(lcol, rcol, rref, prio, meta):
+ prefix = "" if suggestion.parent else ltbl.ref + "."
+ case = self.case
+ cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
+ if cond not in found_conds:
+ found_conds.add(cond)
+ conds.append(Candidate(cond, prio + ref_prio[rref], meta))
+
+ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
+ d = defaultdict(list)
+ for pair in pairs:
+ d[pair[0]].append(pair[1])
+ return d
+
+ # Tables that are closer to the cursor get higher prio
+ ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
+ # Map (schema, table, col) to tables
+ coldict = list_dict(
+ ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
+ )
+ # For each fk from the left table, generate a join condition if
+ # the other table is also in the scope
+ fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
+ for fk, lcol in fks:
+ left = col(ltbl.schema, ltbl.name, lcol)
+ child = col(fk.childschema, fk.childtable, fk.childcolumn)
+ par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+ left, right = (child, par) if left == child else (par, child)
+ for rtbl in coldict[right]:
+ add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
+ # For name matching, use a {(colname, coltype): TableReference} dict
+ coltyp = namedtuple("coltyp", "name datatype")
+ col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
+ # Find all name-match join conditions
+ for c in (coltyp(c.name, c.datatype) for c in lcols):
+ for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
+ prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
+ add_cond(c.name, c.name, rtbl.ref, prio, "name join")
+
+ return self.find_matches(word_before_cursor, conds, meta="join")
+
+ def get_function_matches(self, suggestion, word_before_cursor, alias=False):
+
+ if suggestion.usage == "from":
+ # Only suggest functions allowed in FROM clause
+
+ def filt(f):
+ return (
+ not f.is_aggregate
+ and not f.is_window
+ and not f.is_extension
+ and (f.is_public or f.schema_name == suggestion.schema)
+ )
+
+ else:
+ alias = False
+
+ def filt(f):
+ return not f.is_extension and (
+ f.is_public or f.schema_name == suggestion.schema
+ )
+
+ arg_mode = {"signature": "signature", "special": None}.get(
+ suggestion.usage, "call"
+ )
+
+ # Function overloading means we way have multiple functions of the same
+ # name at this point, so keep unique names only
+ all_functions = self.populate_functions(suggestion.schema, filt)
+ funcs = set(
+ self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
+ )
+
+ matches = self.find_matches(word_before_cursor, funcs, meta="function")
+
+ if not suggestion.schema and not suggestion.usage:
+ # also suggest hardcoded functions using startswith matching
+ predefined_funcs = self.find_matches(
+ word_before_cursor, self.functions, mode="strict", meta="function"
+ )
+ matches.extend(predefined_funcs)
+
+ return matches
+
+ def get_schema_matches(self, suggestion, word_before_cursor):
+ schema_names = self.dbmetadata["tables"].keys()
+
+ # Unless we're sure the user really wants them, hide schema names
+ # starting with pg_, which are mostly temporary schemas
+ if not word_before_cursor.startswith("pg_"):
+ schema_names = [s for s in schema_names if not s.startswith("pg_")]
+
+ if suggestion.quoted:
+ schema_names = [self.escape_schema(s) for s in schema_names]
+
+ return self.find_matches(word_before_cursor, schema_names, meta="schema")
+
+ def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+ alias = self.generate_aliases
+ s = suggestion
+ t_sug = Table(s.schema, s.table_refs, s.local_tables)
+ v_sug = View(s.schema, s.table_refs)
+ f_sug = Function(s.schema, s.table_refs, usage="from")
+ return (
+ self.get_table_matches(t_sug, word_before_cursor, alias)
+ + self.get_view_matches(v_sug, word_before_cursor, alias)
+ + self.get_function_matches(f_sug, word_before_cursor, alias)
+ )
+
+ def _arg_list(self, func, usage):
+ """Returns a an arg list string, e.g. `(_foo:=23)` for a func.
+
+ :param func is a FunctionMetadata object
+ :param usage is 'call', 'call_display' or 'signature'
+
+ """
+ template = {
+ "call": self.call_arg_style,
+ "call_display": self.call_arg_display_style,
+ "signature": self.signature_arg_style,
+ }[usage]
+ args = func.args()
+ if not template:
+ return "()"
+ elif usage == "call" and len(args) < 2:
+ return "()"
+ elif usage == "call" and func.has_variadic():
+ return "()"
+ multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
+ max_arg_len = max(len(a.name) for a in args) if multiline else 0
+ args = (
+ self._format_arg(template, arg, arg_num + 1, max_arg_len)
+ for arg_num, arg in enumerate(args)
+ )
+ if multiline:
+ return "(" + ",".join("\n " + a for a in args if a) + "\n)"
+ else:
+ return "(" + ", ".join(a for a in args if a) + ")"
+
+ def _format_arg(self, template, arg, arg_num, max_arg_len):
+ if not template:
+ return None
+ if arg.has_default:
+ arg_default = "NULL" if arg.default is None else arg.default
+ # Remove trailing ::(schema.)type
+ arg_default = arg_default_type_strip_regex.sub("", arg_default)
+ else:
+ arg_default = ""
+ return template.format(
+ max_arg_len=max_arg_len,
+ arg_name=self.case(arg.name),
+ arg_num=arg_num,
+ arg_type=arg.datatype,
+ arg_default=arg_default,
+ )
+
+ def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
+ """Returns a Candidate namedtuple.
+
+ :param tbl is a SchemaObject
+ :param arg_mode determines what type of arg list to suffix for functions.
+ Possible values: call, signature
+
+ """
+ cased_tbl = self.case(tbl.name)
+ if do_alias:
+ alias = self.alias(cased_tbl, suggestion.table_refs)
+ synonyms = (cased_tbl, generate_alias(cased_tbl))
+ maybe_alias = (" " + alias) if do_alias else ""
+ maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
+ suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
+ if arg_mode == "call":
+ display_suffix = self._arg_list_cache["call_display"][tbl.meta]
+ elif arg_mode == "signature":
+ display_suffix = self._arg_list_cache["signature"][tbl.meta]
+ else:
+ display_suffix = ""
+ item = maybe_schema + cased_tbl + suffix + maybe_alias
+ display = maybe_schema + cased_tbl + display_suffix + maybe_alias
+ prio2 = 0 if tbl.schema else 1
+ return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)
+
+ def get_table_matches(self, suggestion, word_before_cursor, alias=False):
+ tables = self.populate_schema_objects(suggestion.schema, "tables")
+ tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)
+
+ # Unless we're sure the user really wants them, don't suggest the
+ # pg_catalog tables that are implicitly on the search path
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ tables = [t for t in tables if not t.name.startswith("pg_")]
+ tables = [self._make_cand(t, alias, suggestion) for t in tables]
+ return self.find_matches(word_before_cursor, tables, meta="table")
+
+ def get_table_formats(self, _, word_before_cursor):
+ formats = TabularOutputFormatter().supported_formats
+ return self.find_matches(word_before_cursor, formats, meta="table format")
+
+ def get_view_matches(self, suggestion, word_before_cursor, alias=False):
+ views = self.populate_schema_objects(suggestion.schema, "views")
+
+ if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
+ views = [v for v in views if not v.name.startswith("pg_")]
+ views = [self._make_cand(v, alias, suggestion) for v in views]
+ return self.find_matches(word_before_cursor, views, meta="view")
+
+ def get_alias_matches(self, suggestion, word_before_cursor):
+ aliases = suggestion.aliases
+ return self.find_matches(word_before_cursor, aliases, meta="table alias")
+
+ def get_database_matches(self, _, word_before_cursor):
+ return self.find_matches(word_before_cursor, self.databases, meta="database")
+
+ def get_keyword_matches(self, suggestion, word_before_cursor):
+ keywords = self.keywords_tree.keys()
+ # Get well known following keywords for the last token. If any, narrow
+ # candidates to this list.
+ next_keywords = self.keywords_tree.get(suggestion.last_token, [])
+ if next_keywords:
+ keywords = next_keywords
+
+ casing = self.keyword_casing
+ if casing == "auto":
+ if word_before_cursor and word_before_cursor[-1].islower():
+ casing = "lower"
+ else:
+ casing = "upper"
+
+ if casing == "upper":
+ keywords = [k.upper() for k in keywords]
+ else:
+ keywords = [k.lower() for k in keywords]
+
+ return self.find_matches(
+ word_before_cursor, keywords, mode="strict", meta="keyword"
+ )
+
+ def get_path_matches(self, _, word_before_cursor):
+ completer = PathCompleter(expanduser=True)
+ document = Document(
+ text=word_before_cursor, cursor_position=len(word_before_cursor)
+ )
+ for c in completer.get_completions(document, None):
+ yield Match(completion=c, priority=(0,))
+
+ def get_special_matches(self, _, word_before_cursor):
+ if not self.pgspecial:
+ return []
+
+ commands = self.pgspecial.commands
+ cmds = commands.keys()
+ cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
+ return self.find_matches(word_before_cursor, cmds, mode="strict")
+
+ def get_datatype_matches(self, suggestion, word_before_cursor):
+ # suggest custom datatypes
+ types = self.populate_schema_objects(suggestion.schema, "datatypes")
+ types = [self._make_cand(t, False, suggestion) for t in types]
+ matches = self.find_matches(word_before_cursor, types, meta="datatype")
+
+ if not suggestion.schema:
+ # Also suggest hardcoded types
+ matches.extend(
+ self.find_matches(
+ word_before_cursor, self.datatypes, mode="strict", meta="datatype"
+ )
+ )
+
+ return matches
+
+ def get_namedquery_matches(self, _, word_before_cursor):
+ return self.find_matches(
+ word_before_cursor, NamedQueries.instance.list(), meta="named query"
+ )
+
+ suggestion_matchers = {
+ FromClauseItem: get_from_clause_item_matches,
+ JoinCondition: get_join_condition_matches,
+ Join: get_join_matches,
+ Column: get_column_matches,
+ Function: get_function_matches,
+ Schema: get_schema_matches,
+ Table: get_table_matches,
+ TableFormat: get_table_formats,
+ View: get_view_matches,
+ Alias: get_alias_matches,
+ Database: get_database_matches,
+ Keyword: get_keyword_matches,
+ Special: get_special_matches,
+ Datatype: get_datatype_matches,
+ NamedQuery: get_namedquery_matches,
+ Path: get_path_matches,
+ }
+
+ def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
+ """Find all columns in a set of scoped_tables.
+
+ :param scoped_tbls: list of TableReference namedtuples
+ :param local_tbls: tuple(TableMetadata)
+ :return: {TableReference:{colname:ColumnMetaData}}
+
+ """
+ ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
+ columns = OrderedDict()
+ meta = self.dbmetadata
+
+ def addcols(schema, rel, alias, reltype, cols):
+ tbl = TableReference(schema, rel, alias, reltype == "functions")
+ if tbl not in columns:
+ columns[tbl] = []
+ columns[tbl].extend(cols)
+
+ for tbl in scoped_tbls:
+ # Local tables should shadow database tables
+ if tbl.schema is None and normalize_ref(tbl.name) in ctes:
+ cols = ctes[normalize_ref(tbl.name)]
+ addcols(None, tbl.name, "CTE", tbl.alias, cols)
+ continue
+ schemas = [tbl.schema] if tbl.schema else self.search_path
+ for schema in schemas:
+ relname = self.escape_name(tbl.name)
+ schema = self.escape_name(schema)
+ if tbl.is_function:
+ # Return column names from a set-returning function
+ # Get an array of FunctionMetadata objects
+ functions = meta["functions"].get(schema, {}).get(relname)
+ for func in functions or []:
+ # func is a FunctionMetadata object
+ cols = func.fields()
+ addcols(schema, relname, tbl.alias, "functions", cols)
+ else:
+ for reltype in ("tables", "views"):
+ cols = meta[reltype].get(schema, {}).get(relname)
+ if cols:
+ cols = cols.values()
+ addcols(schema, relname, tbl.alias, reltype, cols)
+ break
+
+ return columns
+
+ def _get_schemas(self, obj_typ, schema):
+ """Returns a list of schemas from which to suggest objects.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+ metadata = self.dbmetadata[obj_typ]
+ if schema:
+ schema = self.escape_name(schema)
+ return [schema] if schema in metadata else []
+ return self.search_path if self.search_path_filter else metadata.keys()
+
+ def _maybe_schema(self, schema, parent):
+ return None if parent or schema in self.search_path else schema
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns a list of SchemaObjects representing tables or views.
+
+ :param schema is the schema qualification input by the user (if any)
+
+ """
+
+ return [
+ SchemaObject(
+ name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
+ )
+ for sch in self._get_schemas(obj_type, schema)
+ for obj in self.dbmetadata[obj_type][sch].keys()
+ ]
+
+ def populate_functions(self, schema, filter_func):
+ """Returns a list of function SchemaObjects.
+
+ :param filter_func is a function that accepts a FunctionMetadata
+ namedtuple and returns a boolean indicating whether that
+ function should be kept or discarded
+
+ """
+
+ # Because of multiple dispatch, we can have multiple functions
+ # with the same name, which is why `for meta in metas` is necessary
+ # in the comprehensions below
+ return [
+ SchemaObject(
+ name=func,
+ schema=(self._maybe_schema(schema=sch, parent=schema)),
+ meta=meta,
+ )
+ for sch in self._get_schemas("functions", schema)
+ for (func, metas) in self.dbmetadata["functions"][sch].items()
+ for meta in metas
+ if filter_func(meta)
+ ]