diff options
Diffstat (limited to 'pgcli/pgcompleter.py')
-rw-r--r-- | pgcli/pgcompleter.py | 1046 |
1 files changed, 1046 insertions, 0 deletions
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py new file mode 100644 index 0000000..9c95a01 --- /dev/null +++ b/pgcli/pgcompleter.py @@ -0,0 +1,1046 @@ +import logging +import re +from itertools import count, repeat, chain +import operator +from collections import namedtuple, defaultdict, OrderedDict +from cli_helpers.tabular_output import TabularOutputFormatter +from pgspecial.namedqueries import NamedQueries +from prompt_toolkit.completion import Completer, Completion, PathCompleter +from prompt_toolkit.document import Document +from .packages.sqlcompletion import ( + FromClauseItem, + suggest_type, + Special, + Database, + Schema, + Table, + TableFormat, + Function, + Column, + View, + Keyword, + NamedQuery, + Datatype, + Alias, + Path, + JoinCondition, + Join, +) +from .packages.parseutils.meta import ColumnMetadata, ForeignKey +from .packages.parseutils.utils import last_word +from .packages.parseutils.tables import TableReference +from .packages.pgliterals.main import get_literals +from .packages.prioritization import PrevalenceCounter +from .config import load_config, config_location + +_logger = logging.getLogger(__name__) + +Match = namedtuple("Match", ["completion", "priority"]) + +_SchemaObject = namedtuple("SchemaObject", "name schema meta") + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) + + +_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display") + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") + +normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' + + +def generate_alias(tbl): + """Generate a table alias, consisting of all upper-case letters in + the table name, or, if there are no upper-case letters, the first letter + + all letters preceded by _ + param tbl - unescaped name of the table to alias + """ + return "".join( + [l for l in tbl if l.isupper()] + or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] + ) + + +class PGCompleter(Completer): + # keywords_tree: A dict mapping keywords to well known following keywords. + # e.g. 'CREATE': ['TABLE', 'USER', ...], + keywords_tree = get_literals("keywords", type_=dict) + keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values()))) + functions = get_literals("functions") + datatypes = get_literals("datatypes") + reserved_words = set(get_literals("reserved")) + + def __init__(self, smart_completion=True, pgspecial=None, settings=None): + super(PGCompleter, self).__init__() + self.smart_completion = smart_completion + self.pgspecial = pgspecial + self.prioritizer = PrevalenceCounter() + settings = settings or {} + self.signature_arg_style = settings.get( + "signature_arg_style", "{arg_name} {arg_type}" + ) + self.call_arg_style = settings.get( + "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}" + ) + self.call_arg_display_style = settings.get( + "call_arg_display_style", "{arg_name}" + ) + self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) + self.search_path_filter = settings.get("search_path_filter") + self.generate_aliases = settings.get("generate_aliases") + self.casing_file = settings.get("casing_file") + self.insert_col_skip_patterns = [ + re.compile(pattern) + for pattern in settings.get( + "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("] + ) + ] + self.generate_casing_file = settings.get("generate_casing_file") + self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table") + self.asterisk_column_order = settings.get( + "asterisk_column_order", "table_order" + ) + + keyword_casing = settings.get("keyword_casing", "upper").lower() + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "upper" + self.keyword_casing = keyword_casing + self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$") + + self.databases = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.search_path = [] + self.casing = {} + + self.all_completions = set(self.keywords + self.functions) + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = '"%s"' % name + + return name + + def escape_schema(self, name): + return "'{}'".format(self.unescape_name(name)) + + def unescape_name(self, name): + """ Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schemata): + + # schemata is a list of schema names + schemata = self.escaped_names(schemata) + metadata = self.dbmetadata["tables"] + for schema in schemata: + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + for schema in schemata: + metadata[schema] = {} + + self.all_completions.update(schemata) + + def extend_casing(self, words): + """extend casing data + + :return: + """ + # casing should be a dict {lowercasename:PreferredCasingName} + self.casing = dict((word.lower(), word) for word in words) + + def extend_relations(self, data, kind): + """extend metadata for tables or views. + + :param data: list of (schema_name, rel_name) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + + data = [self.escaped_names(d) for d in data] + + # dbmetadata['tables']['schema_name']['table_name'] should be an + # OrderedDict {column_name:ColumnMetaData}. + metadata = self.dbmetadata[kind] + for schema, relname in data: + try: + metadata[schema][relname] = OrderedDict() + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", kind, relname, schema + ) + self.all_completions.add(relname) + + def extend_columns(self, column_data, kind): + """extend column metadata. + + :param column_data: list of (schema_name, rel_name, column_name, + column_type, has_default, default) tuples + :param kind: either 'tables' or 'views' + + :return: + + """ + metadata = self.dbmetadata[kind] + for schema, relname, colname, datatype, has_default, default in column_data: + (schema, relname, colname) = self.escaped_names([schema, relname, colname]) + column = ColumnMetadata( + name=colname, + datatype=datatype, + has_default=has_default, + default=default, + ) + metadata[schema][relname][colname] = column + self.all_completions.add(colname) + + def extend_functions(self, func_data): + + # func_data is a list of function metadata namedtuples + + # dbmetadata['schema_name']['functions']['function_name'] should return + # the function metadata namedtuple for the corresponding function + metadata = self.dbmetadata["functions"] + + for f in func_data: + schema, func = self.escaped_names([f.schema_name, f.func_name]) + + if func in metadata[schema]: + metadata[schema][func].append(f) + else: + metadata[schema][func] = [f] + + self.all_completions.add(func) + + self._refresh_arg_list_cache() + + def _refresh_arg_list_cache(self): + # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}} + # This is used when suggesting functions, to avoid the latency that would result + # if we'd recalculate the arg lists each time we suggest functions (in large DBs) + self._arg_list_cache = { + usage: { + meta: self._arg_list(meta, usage) + for sch, funcs in self.dbmetadata["functions"].items() + for func, metas in funcs.items() + for meta in metas + } + for usage in ("call", "call_display", "signature") + } + + def extend_foreignkeys(self, fk_data): + + # fk_data is a list of ForeignKey namedtuples, with fields + # parentschema, childschema, parenttable, childtable, + # parentcolumns, childcolumns + + # These are added as a list of ForeignKey namedtuples to the + # ColumnMetadata namedtuple for both the child and parent + meta = self.dbmetadata["tables"] + + for fk in fk_data: + e = self.escaped_names + parentschema, childschema = e([fk.parentschema, fk.childschema]) + parenttable, childtable = e([fk.parenttable, fk.childtable]) + childcol, parcol = e([fk.childcolumn, fk.parentcolumn]) + childcolmeta = meta[childschema][childtable][childcol] + parcolmeta = meta[parentschema][parenttable][parcol] + fk = ForeignKey( + parentschema, parenttable, parcol, childschema, childtable, childcol + ) + childcolmeta.foreignkeys.append((fk)) + parcolmeta.foreignkeys.append((fk)) + + def extend_datatypes(self, type_data): + + # dbmetadata['datatypes'][schema_name][type_name] should store type + # metadata, such as composite type field names. Currently, we're not + # storing any metadata beyond typename, so just store None + meta = self.dbmetadata["datatypes"] + + for t in type_data: + schema, type_name = self.escaped_names(t) + meta[schema][type_name] = None + self.all_completions.add(type_name) + + def extend_query_history(self, text, is_init=False): + if is_init: + # During completer initialization, only load keyword preferences, + # not names + self.prioritizer.update_keywords(text) + else: + self.prioritizer.update(text) + + def set_search_path(self, search_path): + self.search_path = self.escaped_names(search_path) + + def reset_completions(self): + self.databases = [] + self.special_commands = [] + self.search_path = [] + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}} + self.all_completions = set(self.keywords + self.functions) + + def find_matches(self, text, collection, mode="fuzzy", meta=None): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + `collection` can be either a list of strings or a list of Candidate + namedtuples. + `mode` can be either 'fuzzy', or 'strict' + 'fuzzy': fuzzy matching, ties broken by name prevalance + `keyword`: start only matching, ties broken by keyword prevalance + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + + """ + if not collection: + return [] + prio_order = [ + "keyword", + "function", + "view", + "table", + "datatype", + "database", + "schema", + "column", + "table alias", + "join", + "name join", + "fk join", + "table format", + ] + type_priority = prio_order.index(meta) if meta in prio_order else -1 + text = last_word(text, include="most_punctuations").lower() + text_len = len(text) + + if text and text[0] == '"': + # text starts with double quote; user is manually escaping a name + # Match on everything that follows the double-quote. Note that + # text_len is calculated before removing the quote, so the + # Completion.position value is correct + text = text[1:] + + if mode == "fuzzy": + fuzzy = True + priority_func = self.prioritizer.name_count + else: + fuzzy = False + priority_func = self.prioritizer.keyword_count + + # Construct a `_match` function for either fuzzy or non-fuzzy matching + # The match function returns a 2-tuple used for sorting the matches, + # or None if the item doesn't match + # Note: higher priority values mean more important, so use negative + # signs to flip the direction of the tuple + if fuzzy: + regex = ".*?".join(map(re.escape, text)) + pat = re.compile("(%s)" % regex) + + def _match(item): + if item.lower()[: len(text) + 1] in (text, text + " "): + # Exact match of first word in suggestion + # This is to get exact alias matches to the top + # E.g. for input `e`, 'Entries E' should be on top + # (before e.g. `EndUsers EU`) + return float("Infinity"), -1 + r = pat.search(self.unescape_name(item.lower())) + if r: + return -len(r.group()), -r.start() + + else: + match_end_limit = len(text) + + def _match(item): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + # Use negative infinity to force keywords to sort after all + # fuzzy matches + return -float("Infinity"), -match_point + + matches = [] + for cand in collection: + if isinstance(cand, _Candidate): + item, prio, display_meta, synonyms, prio2, display = cand + if display_meta is None: + display_meta = meta + syn_matches = (_match(x) for x in synonyms) + # Nones need to be removed to avoid max() crashing in Python 3 + syn_matches = [m for m in syn_matches if m] + sort_key = max(syn_matches) if syn_matches else None + else: + item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand + sort_key = _match(cand) + + if sort_key: + if display_meta and len(display_meta) > 50: + # Truncate meta-text to 50 characters, if necessary + display_meta = display_meta[:47] + "..." + + # Lexical order of items in the collection, used for + # tiebreaking items with the same match group length and start + # position. Since we use *higher* priority to mean "more + # important," we use -ord(c) to prioritize "aa" > "ab" and end + # with 1 to prioritize shorter strings (ie "user" > "users"). + # We first do a case-insensitive sort and then a + # case-sensitive one as a tie breaker. + # We also use the unescape_name to make sure quoted names have + # the same priority as unquoted names. + lexical_priority = ( + tuple( + 0 if c in (" _") else -ord(c) + for c in self.unescape_name(item.lower()) + ) + + (1,) + + tuple(c for c in item) + ) + + item = self.case(item) + display = self.case(display) + priority = ( + sort_key, + type_priority, + prio, + priority_func(item), + prio2, + lexical_priority, + ) + matches.append( + Match( + completion=Completion( + text=item, + start_position=-text_len, + display_meta=display_meta, + display=display, + ), + priority=priority, + ) + ) + return matches + + def case(self, word): + return self.casing.get(word, word) + + def get_completions(self, document, complete_event, smart_completion=None): + word_before_cursor = document.get_word_before_cursor(WORD=True) + + if smart_completion is None: + smart_completion = self.smart_completion + + # If smart_completion is off then match any word that starts with + # 'word_before_cursor'. + if not smart_completion: + matches = self.find_matches( + word_before_cursor, self.all_completions, mode="strict" + ) + completions = [m.completion for m in matches] + return sorted(completions, key=operator.attrgetter("text")) + + matches = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + suggestion_type = type(suggestion) + _logger.debug("Suggestion type: %r", suggestion_type) + + # Map suggestion type to method + # e.g. 'table' -> self.get_table_matches + matcher = self.suggestion_matchers[suggestion_type] + matches.extend(matcher(self, suggestion, word_before_cursor)) + + # Sort matches so highest priorities are first + matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True) + + return [m.completion for m in matches] + + def get_column_matches(self, suggestion, word_before_cursor): + tables = suggestion.table_refs + do_qualify = suggestion.qualifiable and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + qualify = lambda col, tbl: ( + (tbl + "." + self.case(col)) if do_qualify else self.case(col) + ) + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables) + + def make_cand(name, ref): + synonyms = (name, generate_alias(self.case(name))) + return Candidate(qualify(name, ref), 0, "column", synonyms) + + def flat_cols(): + return [ + make_cand(c.name, t.ref) + for t, cols in scoped_cols.items() + for c in cols + ] + + if suggestion.require_last_table: + # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should + # suggest only columns that appear in the last table and one more + ltbl = tables[-1].ref + other_tbl_cols = set( + c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs + ) + scoped_cols = { + t: [col for col in cols if col.name in other_tbl_cols] + for t, cols in scoped_cols.items() + if t.ref == ltbl + } + lastword = last_word(word_before_cursor, include="most_punctuations") + if lastword == "*": + if suggestion.context == "insert": + + def filter(col): + if not col.has_default: + return True + return not any( + p.match(col.default) for p in self.insert_col_skip_patterns + ) + + scoped_cols = { + t: [col for col in cols if filter(col)] + for t, cols in scoped_cols.items() + } + if self.asterisk_column_order == "alphabetic": + for cols in scoped_cols.values(): + cols.sort(key=operator.attrgetter("name")) + if ( + lastword != word_before_cursor + and len(tables) == 1 + and word_before_cursor[-len(lastword) - 1] == "." + ): + # User typed x.*; replicate "x." for all columns except the + # first, which gets the original (as we only replace the "*"") + sep = ", " + word_before_cursor[:-1] + collist = sep.join(self.case(c.completion) for c in flat_cols()) + else: + collist = ", ".join( + qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs + ) + + return [ + Match( + completion=Completion( + collist, -1, display_meta="columns", display="*" + ), + priority=(1, 1, 1), + ) + ] + + return self.find_matches(word_before_cursor, flat_cols(), meta="column") + + def alias(self, tbl, tbls): + """Generate a unique table alias + tbl - name of the table to alias, quoted if it needs to be + tbls - TableReference iterable of tables already in query + """ + tbl = self.case(tbl) + tbls = set(normalize_ref(t.ref) for t in tbls) + if self.generate_aliases: + tbl = generate_alias(self.unescape_name(tbl)) + if normalize_ref(tbl) not in tbls: + return tbl + elif tbl[0] == '"': + aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2)) + else: + aliases = (tbl + str(i) for i in count(2)) + return next(a for a in aliases if normalize_ref(a) not in tbls) + + def get_join_matches(self, suggestion, word_before_cursor): + tbls = suggestion.table_refs + cols = self.populate_scoped_cols(tbls) + # Set up some data structures for efficient access + qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls) + ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls)) + refs = set(normalize_ref(t.ref) for t in tbls) + other_tbls = set((t.schema, t.name) for t in list(cols)[:-1]) + joins = [] + # Iterate over FKs in existing tables to find potential joins + fks = ( + (fk, rtbl, rcol) + for rtbl, rcols in cols.items() + for rcol in rcols + for fk in rcol.foreignkeys + ) + col = namedtuple("col", "schema tbl col") + for fk, rtbl, rcol in fks: + right = col(rtbl.schema, rtbl.name, rcol.name) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left = child if parent == right else parent + if suggestion.schema and left.schema != suggestion.schema: + continue + c = self.case + if self.generate_aliases or normalize_ref(left.tbl) in refs: + lref = self.alias(left.tbl, suggestion.table_refs) + join = "{0} {4} ON {4}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref + ) + else: + join = "{0} ON {0}.{1} = {2}.{3}".format( + c(left.tbl), c(left.col), rtbl.ref, c(right.col) + ) + alias = generate_alias(self.case(left.tbl)) + synonyms = [ + join, + "{0} ON {0}.{1} = {2}.{3}".format( + alias, c(left.col), rtbl.ref, c(right.col) + ), + ] + # Schema-qualify if (1) new table in same schema as old, and old + # is schema-qualified, or (2) new in other schema, except public + if not suggestion.schema and ( + qualified[normalize_ref(rtbl.ref)] + and left.schema == right.schema + or left.schema not in (right.schema, "public") + ): + join = left.schema + "." + join + prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + ( + 0 if (left.schema, left.tbl) in other_tbls else 1 + ) + joins.append(Candidate(join, prio, "join", synonyms=synonyms)) + + return self.find_matches(word_before_cursor, joins, meta="join") + + def get_join_condition_matches(self, suggestion, word_before_cursor): + col = namedtuple("col", "schema tbl col") + tbls = self.populate_scoped_cols(suggestion.table_refs).items + cols = [(t, c) for t, cs in tbls() for c in cs] + try: + lref = (suggestion.parent or suggestion.table_refs[-1]).ref + ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1] + except IndexError: # The user typed an incorrect table qualifier + return [] + conds, found_conds = [], set() + + def add_cond(lcol, rcol, rref, prio, meta): + prefix = "" if suggestion.parent else ltbl.ref + "." + case = self.case + cond = prefix + case(lcol) + " = " + rref + "." + case(rcol) + if cond not in found_conds: + found_conds.add(cond) + conds.append(Candidate(cond, prio + ref_prio[rref], meta)) + + def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]} + d = defaultdict(list) + for pair in pairs: + d[pair[0]].append(pair[1]) + return d + + # Tables that are closer to the cursor get higher prio + ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs)) + # Map (schema, table, col) to tables + coldict = list_dict( + ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref + ) + # For each fk from the left table, generate a join condition if + # the other table is also in the scope + fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys) + for fk, lcol in fks: + left = col(ltbl.schema, ltbl.name, lcol) + child = col(fk.childschema, fk.childtable, fk.childcolumn) + par = col(fk.parentschema, fk.parenttable, fk.parentcolumn) + left, right = (child, par) if left == child else (par, child) + for rtbl in coldict[right]: + add_cond(left.col, right.col, rtbl.ref, 2000, "fk join") + # For name matching, use a {(colname, coltype): TableReference} dict + coltyp = namedtuple("coltyp", "name datatype") + col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols) + # Find all name-match join conditions + for c in (coltyp(c.name, c.datatype) for c in lcols): + for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref): + prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0 + add_cond(c.name, c.name, rtbl.ref, prio, "name join") + + return self.find_matches(word_before_cursor, conds, meta="join") + + def get_function_matches(self, suggestion, word_before_cursor, alias=False): + + if suggestion.usage == "from": + # Only suggest functions allowed in FROM clause + + def filt(f): + return ( + not f.is_aggregate + and not f.is_window + and not f.is_extension + and (f.is_public or f.schema_name == suggestion.schema) + ) + + else: + alias = False + + def filt(f): + return not f.is_extension and ( + f.is_public or f.schema_name == suggestion.schema + ) + + arg_mode = {"signature": "signature", "special": None}.get( + suggestion.usage, "call" + ) + + # Function overloading means we way have multiple functions of the same + # name at this point, so keep unique names only + all_functions = self.populate_functions(suggestion.schema, filt) + funcs = set( + self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions + ) + + matches = self.find_matches(word_before_cursor, funcs, meta="function") + + if not suggestion.schema and not suggestion.usage: + # also suggest hardcoded functions using startswith matching + predefined_funcs = self.find_matches( + word_before_cursor, self.functions, mode="strict", meta="function" + ) + matches.extend(predefined_funcs) + + return matches + + def get_schema_matches(self, suggestion, word_before_cursor): + schema_names = self.dbmetadata["tables"].keys() + + # Unless we're sure the user really wants them, hide schema names + # starting with pg_, which are mostly temporary schemas + if not word_before_cursor.startswith("pg_"): + schema_names = [s for s in schema_names if not s.startswith("pg_")] + + if suggestion.quoted: + schema_names = [self.escape_schema(s) for s in schema_names] + + return self.find_matches(word_before_cursor, schema_names, meta="schema") + + def get_from_clause_item_matches(self, suggestion, word_before_cursor): + alias = self.generate_aliases + s = suggestion + t_sug = Table(s.schema, s.table_refs, s.local_tables) + v_sug = View(s.schema, s.table_refs) + f_sug = Function(s.schema, s.table_refs, usage="from") + return ( + self.get_table_matches(t_sug, word_before_cursor, alias) + + self.get_view_matches(v_sug, word_before_cursor, alias) + + self.get_function_matches(f_sug, word_before_cursor, alias) + ) + + def _arg_list(self, func, usage): + """Returns a an arg list string, e.g. `(_foo:=23)` for a func. + + :param func is a FunctionMetadata object + :param usage is 'call', 'call_display' or 'signature' + + """ + template = { + "call": self.call_arg_style, + "call_display": self.call_arg_display_style, + "signature": self.signature_arg_style, + }[usage] + args = func.args() + if not template: + return "()" + elif usage == "call" and len(args) < 2: + return "()" + elif usage == "call" and func.has_variadic(): + return "()" + multiline = usage == "call" and len(args) > self.call_arg_oneliner_max + max_arg_len = max(len(a.name) for a in args) if multiline else 0 + args = ( + self._format_arg(template, arg, arg_num + 1, max_arg_len) + for arg_num, arg in enumerate(args) + ) + if multiline: + return "(" + ",".join("\n " + a for a in args if a) + "\n)" + else: + return "(" + ", ".join(a for a in args if a) + ")" + + def _format_arg(self, template, arg, arg_num, max_arg_len): + if not template: + return None + if arg.has_default: + arg_default = "NULL" if arg.default is None else arg.default + # Remove trailing ::(schema.)type + arg_default = arg_default_type_strip_regex.sub("", arg_default) + else: + arg_default = "" + return template.format( + max_arg_len=max_arg_len, + arg_name=self.case(arg.name), + arg_num=arg_num, + arg_type=arg.datatype, + arg_default=arg_default, + ) + + def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None): + """Returns a Candidate namedtuple. + + :param tbl is a SchemaObject + :param arg_mode determines what type of arg list to suffix for functions. + Possible values: call, signature + + """ + cased_tbl = self.case(tbl.name) + if do_alias: + alias = self.alias(cased_tbl, suggestion.table_refs) + synonyms = (cased_tbl, generate_alias(cased_tbl)) + maybe_alias = (" " + alias) if do_alias else "" + maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else "" + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else "" + if arg_mode == "call": + display_suffix = self._arg_list_cache["call_display"][tbl.meta] + elif arg_mode == "signature": + display_suffix = self._arg_list_cache["signature"][tbl.meta] + else: + display_suffix = "" + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias + prio2 = 0 if tbl.schema else 1 + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) + + def get_table_matches(self, suggestion, word_before_cursor, alias=False): + tables = self.populate_schema_objects(suggestion.schema, "tables") + tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables) + + # Unless we're sure the user really wants them, don't suggest the + # pg_catalog tables that are implicitly on the search path + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + tables = [t for t in tables if not t.name.startswith("pg_")] + tables = [self._make_cand(t, alias, suggestion) for t in tables] + return self.find_matches(word_before_cursor, tables, meta="table") + + def get_table_formats(self, _, word_before_cursor): + formats = TabularOutputFormatter().supported_formats + return self.find_matches(word_before_cursor, formats, meta="table format") + + def get_view_matches(self, suggestion, word_before_cursor, alias=False): + views = self.populate_schema_objects(suggestion.schema, "views") + + if not suggestion.schema and (not word_before_cursor.startswith("pg_")): + views = [v for v in views if not v.name.startswith("pg_")] + views = [self._make_cand(v, alias, suggestion) for v in views] + return self.find_matches(word_before_cursor, views, meta="view") + + def get_alias_matches(self, suggestion, word_before_cursor): + aliases = suggestion.aliases + return self.find_matches(word_before_cursor, aliases, meta="table alias") + + def get_database_matches(self, _, word_before_cursor): + return self.find_matches(word_before_cursor, self.databases, meta="database") + + def get_keyword_matches(self, suggestion, word_before_cursor): + keywords = self.keywords_tree.keys() + # Get well known following keywords for the last token. If any, narrow + # candidates to this list. + next_keywords = self.keywords_tree.get(suggestion.last_token, []) + if next_keywords: + keywords = next_keywords + + casing = self.keyword_casing + if casing == "auto": + if word_before_cursor and word_before_cursor[-1].islower(): + casing = "lower" + else: + casing = "upper" + + if casing == "upper": + keywords = [k.upper() for k in keywords] + else: + keywords = [k.lower() for k in keywords] + + return self.find_matches( + word_before_cursor, keywords, mode="strict", meta="keyword" + ) + + def get_path_matches(self, _, word_before_cursor): + completer = PathCompleter(expanduser=True) + document = Document( + text=word_before_cursor, cursor_position=len(word_before_cursor) + ) + for c in completer.get_completions(document, None): + yield Match(completion=c, priority=(0,)) + + def get_special_matches(self, _, word_before_cursor): + if not self.pgspecial: + return [] + + commands = self.pgspecial.commands + cmds = commands.keys() + cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds] + return self.find_matches(word_before_cursor, cmds, mode="strict") + + def get_datatype_matches(self, suggestion, word_before_cursor): + # suggest custom datatypes + types = self.populate_schema_objects(suggestion.schema, "datatypes") + types = [self._make_cand(t, False, suggestion) for t in types] + matches = self.find_matches(word_before_cursor, types, meta="datatype") + + if not suggestion.schema: + # Also suggest hardcoded types + matches.extend( + self.find_matches( + word_before_cursor, self.datatypes, mode="strict", meta="datatype" + ) + ) + + return matches + + def get_namedquery_matches(self, _, word_before_cursor): + return self.find_matches( + word_before_cursor, NamedQueries.instance.list(), meta="named query" + ) + + suggestion_matchers = { + FromClauseItem: get_from_clause_item_matches, + JoinCondition: get_join_condition_matches, + Join: get_join_matches, + Column: get_column_matches, + Function: get_function_matches, + Schema: get_schema_matches, + Table: get_table_matches, + TableFormat: get_table_formats, + View: get_view_matches, + Alias: get_alias_matches, + Database: get_database_matches, + Keyword: get_keyword_matches, + Special: get_special_matches, + Datatype: get_datatype_matches, + NamedQuery: get_namedquery_matches, + Path: get_path_matches, + } + + def populate_scoped_cols(self, scoped_tbls, local_tbls=()): + """Find all columns in a set of scoped_tables. + + :param scoped_tbls: list of TableReference namedtuples + :param local_tbls: tuple(TableMetadata) + :return: {TableReference:{colname:ColumnMetaData}} + + """ + ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls) + columns = OrderedDict() + meta = self.dbmetadata + + def addcols(schema, rel, alias, reltype, cols): + tbl = TableReference(schema, rel, alias, reltype == "functions") + if tbl not in columns: + columns[tbl] = [] + columns[tbl].extend(cols) + + for tbl in scoped_tbls: + # Local tables should shadow database tables + if tbl.schema is None and normalize_ref(tbl.name) in ctes: + cols = ctes[normalize_ref(tbl.name)] + addcols(None, tbl.name, "CTE", tbl.alias, cols) + continue + schemas = [tbl.schema] if tbl.schema else self.search_path + for schema in schemas: + relname = self.escape_name(tbl.name) + schema = self.escape_name(schema) + if tbl.is_function: + # Return column names from a set-returning function + # Get an array of FunctionMetadata objects + functions = meta["functions"].get(schema, {}).get(relname) + for func in functions or []: + # func is a FunctionMetadata object + cols = func.fields() + addcols(schema, relname, tbl.alias, "functions", cols) + else: + for reltype in ("tables", "views"): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + cols = cols.values() + addcols(schema, relname, tbl.alias, reltype, cols) + break + + return columns + + def _get_schemas(self, obj_typ, schema): + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + + """ + metadata = self.dbmetadata[obj_typ] + if schema: + schema = self.escape_name(schema) + return [schema] if schema in metadata else [] + return self.search_path if self.search_path_filter else metadata.keys() + + def _maybe_schema(self, schema, parent): + return None if parent or schema in self.search_path else schema + + def populate_schema_objects(self, schema, obj_type): + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + + """ + + return [ + SchemaObject( + name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)) + ) + for sch in self._get_schemas(obj_type, schema) + for obj in self.dbmetadata[obj_type][sch].keys() + ] + + def populate_functions(self, schema, filter_func): + """Returns a list of function SchemaObjects. + + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded + + """ + + # Because of multiple dispatch, we can have multiple functions + # with the same name, which is why `for meta in metas` is necessary + # in the comprehensions below + return [ + SchemaObject( + name=func, + schema=(self._maybe_schema(schema=sch, parent=schema)), + meta=meta, + ) + for sch in self._get_schemas("functions", schema) + for (func, metas) in self.dbmetadata["functions"][sch].items() + for meta in metas + if filter_func(meta) + ] |