import json
import logging
import re
from itertools import count, repeat, chain
import operator
from collections import namedtuple, defaultdict, OrderedDict

from cli_helpers.tabular_output import TabularOutputFormatter
from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion, PathCompleter
from prompt_toolkit.document import Document

from .packages.sqlcompletion import (
    FromClauseItem,
    suggest_type,
    Special,
    Database,
    Schema,
    Table,
    TableFormat,
    Function,
    Column,
    View,
    Keyword,
    NamedQuery,
    Datatype,
    Alias,
    Path,
    JoinCondition,
    Join,
)
from .packages.parseutils.meta import ColumnMetadata, ForeignKey
from .packages.parseutils.utils import last_word
from .packages.parseutils.tables import TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location

_logger = logging.getLogger(__name__)

# A completion candidate paired with its sort priority tuple.
Match = namedtuple("Match", ["completion", "priority"])

_SchemaObject = namedtuple("SchemaObject", "name schema meta")


def SchemaObject(name, schema=None, meta=None):
    """Build a SchemaObject namedtuple with optional schema/meta fields."""
    return _SchemaObject(name, schema, meta)


_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")


def Candidate(
    completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
):
    """Build a Candidate namedtuple, defaulting synonyms/display to the completion."""
    return _Candidate(
        completion, prio, meta, synonyms or [completion], prio2, display or completion
    )


# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")

# Normalize a table reference for comparison: already-quoted refs are kept
# as-is; unquoted refs are lower-cased and wrapped in double quotes.
normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'


def generate_alias(tbl, alias_map=None):
    """Generate a table alias, consisting of all upper-case letters in
    the table name, or, if there are no upper-case letters, the first letter +
    all letters preceded by _
    param tbl - unescaped name of the table to alias
    """
    if alias_map and tbl in alias_map:
        return alias_map[tbl]
    return "".join(
        [l for l in tbl if l.isupper()]
        or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
    )


class InvalidMapFile(ValueError):
    """Raised when an alias_map_file is missing or not valid JSON."""

    pass


def load_alias_map_file(path):
    """Load a JSON table-alias map from *path*.

    :raises InvalidMapFile: if the file does not exist or is not valid JSON
    :return: the parsed alias map (a dict of table name -> alias)
    """
    try:
        with open(path) as fo:
            alias_map = json.load(fo)
    except FileNotFoundError as err:
        raise InvalidMapFile(
            f"Cannot read alias_map_file - {err.filename} does not exist"
        )
    except json.JSONDecodeError:
        raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json")
    else:
        return alias_map


class PGCompleter(Completer):
    """prompt_toolkit Completer producing context-aware SQL completions.

    Holds database metadata (schemas, tables, views, functions, datatypes)
    and matches the user's partial input against it, ranked by a priority
    tuple (see find_matches).
    """

    # keywords_tree: A dict mapping keywords to well known following keywords.
    # e.g. 'CREATE': ['TABLE', 'USER', ...],
    keywords_tree = get_literals("keywords", type_=dict)
    keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
    functions = get_literals("functions")
    datatypes = get_literals("datatypes")
    reserved_words = set(get_literals("reserved"))

    def __init__(self, smart_completion=True, pgspecial=None, settings=None):
        super().__init__()
        self.smart_completion = smart_completion
        self.pgspecial = pgspecial
        self.prioritizer = PrevalenceCounter()
        settings = settings or {}
        self.signature_arg_style = settings.get(
            "signature_arg_style", "{arg_name} {arg_type}"
        )
        self.call_arg_style = settings.get(
            "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
        )
        self.call_arg_display_style = settings.get(
            "call_arg_display_style", "{arg_name}"
        )
        self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
        self.search_path_filter = settings.get("search_path_filter")
        self.generate_aliases = settings.get("generate_aliases")
        alias_map_file = settings.get("alias_map_file")
        if alias_map_file is not None:
            self.alias_map = load_alias_map_file(alias_map_file)
        else:
            self.alias_map = None
        self.casing_file = settings.get("casing_file")
        self.insert_col_skip_patterns = [
            re.compile(pattern)
            for pattern in settings.get(
                "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
            )
        ]
        self.generate_casing_file = settings.get("generate_casing_file")
        self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
        self.asterisk_column_order = settings.get(
            "asterisk_column_order", "table_order"
        )

        keyword_casing = settings.get("keyword_casing", "upper").lower()
        if keyword_casing not in ("upper", "lower", "auto"):
            keyword_casing = "upper"
        self.keyword_casing = keyword_casing
        self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

        self.databases = []
        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
        self.search_path = []
        self.casing = {}

        self.all_completions = set(self.keywords + self.functions)

    def escape_name(self, name):
        """Double-quote *name* if it needs escaping (non-identifier chars,
        reserved word, or clash with a built-in function name)."""
        if name and (
            (not self.name_pattern.match(name))
            or (name.upper() in self.reserved_words)
            or (name.upper() in self.functions)
        ):
            name = '"%s"' % name

        return name

    def escape_schema(self, name):
        """Return *name* unescaped, then wrapped in single quotes."""
        return "'{}'".format(self.unescape_name(name))

    def unescape_name(self, name):
        """Unquote a string."""
        if name and name[0] == '"' and name[-1] == '"':
            name = name[1:-1]

        return name

    def escaped_names(self, names):
        return [self.escape_name(name) for name in names]

    def extend_database_names(self, databases):
        self.databases.extend(databases)

    def extend_keywords(self, additional_keywords):
        # BUGFIX: the class attribute `keywords` is a tuple, so calling
        # .extend() on it raised AttributeError. Rebind an instance-level
        # tuple instead, preserving the attribute's type for callers.
        self.keywords = self.keywords + tuple(additional_keywords)
        self.all_completions.update(additional_keywords)

    def extend_schemata(self, schemata):
        # schemata is a list of schema names
        schemata = self.escaped_names(schemata)
        metadata = self.dbmetadata["tables"]
        for schema in schemata:
            metadata[schema] = {}

        # dbmetadata.values() are the 'tables' and 'functions' dicts
        for metadata in self.dbmetadata.values():
            for schema in schemata:
                metadata[schema] = {}
        self.all_completions.update(schemata)

    def extend_casing(self, words):
        """extend casing data

        :return:
        """
        # casing should be a dict {lowercasename:PreferredCasingName}
        self.casing = {word.lower(): word for word in words}

    def extend_relations(self, data, kind):
        """extend metadata for tables or views.

        :param data: list of (schema_name, rel_name) tuples
        :param kind: either 'tables' or 'views'

        :return:
        """
        data = [self.escaped_names(d) for d in data]

        # dbmetadata['tables']['schema_name']['table_name'] should be an
        # OrderedDict {column_name:ColumnMetaData}.
        metadata = self.dbmetadata[kind]
        for schema, relname in data:
            try:
                metadata[schema][relname] = OrderedDict()
            except KeyError:
                _logger.error(
                    "%r %r listed in unrecognized schema %r", kind, relname, schema
                )
            self.all_completions.add(relname)

    def extend_columns(self, column_data, kind):
        """extend column metadata.

        :param column_data: list of (schema_name, rel_name, column_name,
        column_type, has_default, default) tuples
        :param kind: either 'tables' or 'views'

        :return:
        """
        metadata = self.dbmetadata[kind]
        for schema, relname, colname, datatype, has_default, default in column_data:
            (schema, relname, colname) = self.escaped_names([schema, relname, colname])
            column = ColumnMetadata(
                name=colname,
                datatype=datatype,
                has_default=has_default,
                default=default,
            )
            metadata[schema][relname][colname] = column
            self.all_completions.add(colname)

    def extend_functions(self, func_data):
        # func_data is a list of function metadata namedtuples

        # dbmetadata['schema_name']['functions']['function_name'] should return
        # the function metadata namedtuple for the corresponding function
        metadata = self.dbmetadata["functions"]

        for f in func_data:
            schema, func = self.escaped_names([f.schema_name, f.func_name])

            if func in metadata[schema]:
                metadata[schema][func].append(f)
            else:
                metadata[schema][func] = [f]

            self.all_completions.add(func)

        self._refresh_arg_list_cache()

    def _refresh_arg_list_cache(self):
        # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
        # This is used when suggesting functions, to avoid the latency that would result
        # if we'd recalculate the arg lists each time we suggest functions (in large DBs)
        self._arg_list_cache = {
            usage: {
                meta: self._arg_list(meta, usage)
                for sch, funcs in self.dbmetadata["functions"].items()
                for func, metas in funcs.items()
                for meta in metas
            }
            for usage in ("call", "call_display", "signature")
        }

    def extend_foreignkeys(self, fk_data):
        # fk_data is a list of ForeignKey namedtuples, with fields
        # parentschema, childschema, parenttable, childtable,
        # parentcolumns, childcolumns

        # These are added as a list of ForeignKey namedtuples to the
        # ColumnMetadata namedtuple for both the child and parent
        meta = self.dbmetadata["tables"]

        for fk in fk_data:
            e = self.escaped_names
            parentschema, childschema = e([fk.parentschema, fk.childschema])
            parenttable, childtable = e([fk.parenttable, fk.childtable])
            childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
            childcolmeta = meta[childschema][childtable][childcol]
            parcolmeta = meta[parentschema][parenttable][parcol]
            fk = ForeignKey(
                parentschema, parenttable, parcol, childschema, childtable, childcol
            )
            childcolmeta.foreignkeys.append(fk)
            parcolmeta.foreignkeys.append(fk)

    def extend_datatypes(self, type_data):
        # dbmetadata['datatypes'][schema_name][type_name] should store type
        # metadata, such as composite type field names. Currently, we're not
        # storing any metadata beyond typename, so just store None
        meta = self.dbmetadata["datatypes"]

        for t in type_data:
            schema, type_name = self.escaped_names(t)
            meta[schema][type_name] = None
            self.all_completions.add(type_name)

    def extend_query_history(self, text, is_init=False):
        if is_init:
            # During completer initialization, only load keyword preferences,
            # not names
            self.prioritizer.update_keywords(text)
        else:
            self.prioritizer.update(text)

    def set_search_path(self, search_path):
        self.search_path = self.escaped_names(search_path)

    def reset_completions(self):
        """Drop all database-derived state, keeping only the built-in
        keyword/function completions."""
        self.databases = []
        self.special_commands = []
        self.search_path = []
        self.dbmetadata = {"tables": {}, "views": {}, "functions": {}, "datatypes": {}}
        self.all_completions = set(self.keywords + self.functions)

    def find_matches(self, text, collection, mode="fuzzy", meta=None):
        """Find completion matches for the given text.

        Given the user's input text and a collection of available
        completions, find completions matching the last word of the
        text.

        `collection` can be either a list of strings or a list of Candidate
        namedtuples.
        `mode` can be either 'fuzzy', or 'strict'
            'fuzzy': fuzzy matching, ties broken by name prevalance
            `keyword`: start only matching, ties broken by keyword prevalance

        yields prompt_toolkit Completion instances for any matches found
        in the collection of available completions.

        """
        if not collection:
            return []
        prio_order = [
            "keyword",
            "function",
            "view",
            "table",
            "datatype",
            "database",
            "schema",
            "column",
            "table alias",
            "join",
            "name join",
            "fk join",
            "table format",
        ]
        type_priority = prio_order.index(meta) if meta in prio_order else -1
        text = last_word(text, include="most_punctuations").lower()
        text_len = len(text)

        if text and text[0] == '"':
            # text starts with double quote; user is manually escaping a name
            # Match on everything that follows the double-quote. Note that
            # text_len is calculated before removing the quote, so the
            # Completion.position value is correct
            text = text[1:]

        if mode == "fuzzy":
            fuzzy = True
            priority_func = self.prioritizer.name_count
        else:
            fuzzy = False
            priority_func = self.prioritizer.keyword_count

        # Construct a `_match` function for either fuzzy or non-fuzzy matching
        # The match function returns a 2-tuple used for sorting the matches,
        # or None if the item doesn't match
        # Note: higher priority values mean more important, so use negative
        # signs to flip the direction of the tuple
        if fuzzy:
            regex = ".*?".join(map(re.escape, text))
            pat = re.compile("(%s)" % regex)

            def _match(item):
                if item.lower()[: len(text) + 1] in (text, text + " "):
                    # Exact match of first word in suggestion
                    # This is to get exact alias matches to the top
                    # E.g. for input `e`, 'Entries E' should be on top
                    # (before e.g. `EndUsers EU`)
                    return float("Infinity"), -1
                r = pat.search(self.unescape_name(item.lower()))
                if r:
                    return -len(r.group()), -r.start()

        else:
            match_end_limit = len(text)

            def _match(item):
                match_point = item.lower().find(text, 0, match_end_limit)
                if match_point >= 0:
                    # Use negative infinity to force keywords to sort after all
                    # fuzzy matches
                    return -float("Infinity"), -match_point

        matches = []
        for cand in collection:
            if isinstance(cand, _Candidate):
                item, prio, display_meta, synonyms, prio2, display = cand
                if display_meta is None:
                    display_meta = meta
                syn_matches = (_match(x) for x in synonyms)
                # Nones need to be removed to avoid max() crashing in Python 3
                syn_matches = [m for m in syn_matches if m]
                sort_key = max(syn_matches) if syn_matches else None
            else:
                item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
                sort_key = _match(cand)

            if sort_key:
                if display_meta and len(display_meta) > 50:
                    # Truncate meta-text to 50 characters, if necessary
                    display_meta = display_meta[:47] + "..."

                # Lexical order of items in the collection, used for
                # tiebreaking items with the same match group length and start
                # position. Since we use *higher* priority to mean "more
                # important," we use -ord(c) to prioritize "aa" > "ab" and end
                # with 1 to prioritize shorter strings (ie "user" > "users").
                # We first do a case-insensitive sort and then a
                # case-sensitive one as a tie breaker.
                # We also use the unescape_name to make sure quoted names have
                # the same priority as unquoted names.
                lexical_priority = (
                    tuple(
                        0 if c in " _" else -ord(c)
                        for c in self.unescape_name(item.lower())
                    )
                    + (1,)
                    + tuple(c for c in item)
                )

                item = self.case(item)
                display = self.case(display)
                priority = (
                    sort_key,
                    type_priority,
                    prio,
                    priority_func(item),
                    prio2,
                    lexical_priority,
                )
                matches.append(
                    Match(
                        completion=Completion(
                            text=item,
                            start_position=-text_len,
                            display_meta=display_meta,
                            display=display,
                        ),
                        priority=priority,
                    )
                )
        return matches

    def case(self, word):
        """Return the user's preferred casing for *word*, if known."""
        return self.casing.get(word, word)

    def get_completions(self, document, complete_event, smart_completion=None):
        """prompt_toolkit entry point: yield completions for *document*."""
        word_before_cursor = document.get_word_before_cursor(WORD=True)

        if smart_completion is None:
            smart_completion = self.smart_completion

        # If smart_completion is off then match any word that starts with
        # 'word_before_cursor'.
        if not smart_completion:
            matches = self.find_matches(
                word_before_cursor, self.all_completions, mode="strict"
            )
            completions = [m.completion for m in matches]
            return sorted(completions, key=operator.attrgetter("text"))

        matches = []
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:
            suggestion_type = type(suggestion)
            _logger.debug("Suggestion type: %r", suggestion_type)

            # Map suggestion type to method
            # e.g. 'table' -> self.get_table_matches
            matcher = self.suggestion_matchers[suggestion_type]
            matches.extend(matcher(self, suggestion, word_before_cursor))

        # Sort matches so highest priorities are first
        matches = sorted(matches, key=operator.attrgetter("priority"), reverse=True)

        return [m.completion for m in matches]

    def get_column_matches(self, suggestion, word_before_cursor):
        """Suggest column names visible in the current statement's tables."""
        tables = suggestion.table_refs
        do_qualify = (
            suggestion.qualifiable
            and {
                "always": True,
                "never": False,
                "if_more_than_one_table": len(tables) > 1,
            }[self.qualify_columns]
        )
        qualify = lambda col, tbl: (
            (tbl + "." + self.case(col)) if do_qualify else self.case(col)
        )
        _logger.debug("Completion column scope: %r", tables)
        scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)

        def make_cand(name, ref):
            synonyms = (name, generate_alias(self.case(name)))
            return Candidate(qualify(name, ref), 0, "column", synonyms)

        def flat_cols():
            return [
                make_cand(c.name, t.ref)
                for t, cols in scoped_cols.items()
                for c in cols
            ]

        if suggestion.require_last_table:
            # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
            # suggest only columns that appear in the last table and one more
            ltbl = tables[-1].ref
            other_tbl_cols = {
                c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
            }
            scoped_cols = {
                t: [col for col in cols if col.name in other_tbl_cols]
                for t, cols in scoped_cols.items()
                if t.ref == ltbl
            }
        lastword = last_word(word_before_cursor, include="most_punctuations")
        if lastword == "*":
            if suggestion.context == "insert":

                def filter(col):
                    if not col.has_default:
                        return True
                    return not any(
                        p.match(col.default) for p in self.insert_col_skip_patterns
                    )

                scoped_cols = {
                    t: [col for col in cols if filter(col)]
                    for t, cols in scoped_cols.items()
                }
            if self.asterisk_column_order == "alphabetic":
                for cols in scoped_cols.values():
                    cols.sort(key=operator.attrgetter("name"))
            if (
                lastword != word_before_cursor
                and len(tables) == 1
                and word_before_cursor[-len(lastword) - 1] == "."
            ):
                # User typed x.*; replicate "x." for all columns except the
                # first, which gets the original (as we only replace the "*"")
                sep = ", " + word_before_cursor[:-1]
                collist = sep.join(self.case(c.completion) for c in flat_cols())
            else:
                collist = ", ".join(
                    qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
                )

            return [
                Match(
                    completion=Completion(
                        collist, -1, display_meta="columns", display="*"
                    ),
                    priority=(1, 1, 1),
                )
            ]

        return self.find_matches(word_before_cursor, flat_cols(), meta="column")

    def alias(self, tbl, tbls):
        """Generate a unique table alias
        tbl - name of the table to alias, quoted if it needs to be
        tbls - TableReference iterable of tables already in query
        """
        tbl = self.case(tbl)
        tbls = {normalize_ref(t.ref) for t in tbls}
        if self.generate_aliases:
            tbl = generate_alias(self.unescape_name(tbl))
        if normalize_ref(tbl) not in tbls:
            return tbl
        elif tbl[0] == '"':
            aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
        else:
            aliases = (tbl + str(i) for i in count(2))
        return next(a for a in aliases if normalize_ref(a) not in tbls)

    def get_join_matches(self, suggestion, word_before_cursor):
        """Suggest JOIN targets based on foreign keys of tables in scope."""
        tbls = suggestion.table_refs
        cols = self.populate_scoped_cols(tbls)
        # Set up some data structures for efficient access
        qualified = {normalize_ref(t.ref): t.schema for t in tbls}
        ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
        refs = {normalize_ref(t.ref) for t in tbls}
        other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
        joins = []
        # Iterate over FKs in existing tables to find potential joins
        fks = (
            (fk, rtbl, rcol)
            for rtbl, rcols in cols.items()
            for rcol in rcols
            for fk in rcol.foreignkeys
        )
        col = namedtuple("col", "schema tbl col")
        for fk, rtbl, rcol in fks:
            right = col(rtbl.schema, rtbl.name, rcol.name)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left = child if parent == right else parent
            if suggestion.schema and left.schema != suggestion.schema:
                continue
            c = self.case
            if self.generate_aliases or normalize_ref(left.tbl) in refs:
                lref = self.alias(left.tbl, suggestion.table_refs)
                join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
                )
            else:
                join = "{0} ON {0}.{1} = {2}.{3}".format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col)
                )
            alias = generate_alias(self.case(left.tbl))
            synonyms = [
                join,
                "{0} ON {0}.{1} = {2}.{3}".format(
                    alias, c(left.col), rtbl.ref, c(right.col)
                ),
            ]
            # Schema-qualify if (1) new table in same schema as old, and old
            # is schema-qualified, or (2) new in other schema, except public
            if not suggestion.schema and (
                qualified[normalize_ref(rtbl.ref)]
                and left.schema == right.schema
                or left.schema not in (right.schema, "public")
            ):
                join = left.schema + "." + join
            prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
                0 if (left.schema, left.tbl) in other_tbls else 1
            )
            joins.append(Candidate(join, prio, "join", synonyms=synonyms))
        return self.find_matches(word_before_cursor, joins, meta="join")

    def get_join_condition_matches(self, suggestion, word_before_cursor):
        """Suggest ON-clause conditions, preferring FK joins, then joins on
        columns with matching name and type."""
        col = namedtuple("col", "schema tbl col")
        tbls = self.populate_scoped_cols(suggestion.table_refs).items
        cols = [(t, c) for t, cs in tbls() for c in cs]
        try:
            lref = (suggestion.parent or suggestion.table_refs[-1]).ref
            ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
        except IndexError:  # The user typed an incorrect table qualifier
            return []
        conds, found_conds = [], set()

        def add_cond(lcol, rcol, rref, prio, meta):
            prefix = "" if suggestion.parent else ltbl.ref + "."
            case = self.case
            cond = prefix + case(lcol) + " = " + rref + "." + case(rcol)
            if cond not in found_conds:
                found_conds.add(cond)
                conds.append(Candidate(cond, prio + ref_prio[rref], meta))

        def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
            d = defaultdict(list)
            for pair in pairs:
                d[pair[0]].append(pair[1])
            return d

        # Tables that are closer to the cursor get higher prio
        ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
        # Map (schema, table, col) to tables
        coldict = list_dict(
            ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
        )
        # For each fk from the left table, generate a join condition if
        # the other table is also in the scope
        fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
        for fk, lcol in fks:
            left = col(ltbl.schema, ltbl.name, lcol)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left, right = (child, par) if left == child else (par, child)
            for rtbl in coldict[right]:
                add_cond(left.col, right.col, rtbl.ref, 2000, "fk join")
        # For name matching, use a {(colname, coltype): TableReference} dict
        coltyp = namedtuple("coltyp", "name datatype")
        col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
        # Find all name-match join conditions
        for c in (coltyp(c.name, c.datatype) for c in lcols):
            for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
                prio = 1000 if c.datatype in ("integer", "bigint", "smallint") else 0
                add_cond(c.name, c.name, rtbl.ref, prio, "name join")

        return self.find_matches(word_before_cursor, conds, meta="join")

    def get_function_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest function names, filtered by usage context (FROM clause,
        signature, special command, or ordinary call)."""
        if suggestion.usage == "from":
            # Only suggest functions allowed in FROM clause

            def filt(f):
                return (
                    not f.is_aggregate
                    and not f.is_window
                    and not f.is_extension
                    and (
                        f.is_public
                        or f.schema_name in self.search_path
                        or f.schema_name == suggestion.schema
                    )
                )

        else:
            alias = False

            def filt(f):
                return not f.is_extension and (
                    f.is_public or f.schema_name == suggestion.schema
                )

        arg_mode = {"signature": "signature", "special": None}.get(
            suggestion.usage, "call"
        )

        # Function overloading means we way have multiple functions of the same
        # name at this point, so keep unique names only
        all_functions = self.populate_functions(suggestion.schema, filt)
        funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}

        matches = self.find_matches(word_before_cursor, funcs, meta="function")

        if not suggestion.schema and not suggestion.usage:
            # also suggest hardcoded functions using startswith matching
            predefined_funcs = self.find_matches(
                word_before_cursor, self.functions, mode="strict", meta="function"
            )
            matches.extend(predefined_funcs)

        return matches

    def get_schema_matches(self, suggestion, word_before_cursor):
        """Suggest schema names, hiding pg_* schemas unless explicitly typed."""
        schema_names = self.dbmetadata["tables"].keys()

        # Unless we're sure the user really wants them, hide schema names
        # starting with pg_, which are mostly temporary schemas
        if not word_before_cursor.startswith("pg_"):
            schema_names = [s for s in schema_names if not s.startswith("pg_")]

        if suggestion.quoted:
            schema_names = [self.escape_schema(s) for s in schema_names]

        return self.find_matches(word_before_cursor, schema_names, meta="schema")

    def get_from_clause_item_matches(self, suggestion, word_before_cursor):
        """Suggest tables, views and set-returning functions for FROM."""
        alias = self.generate_aliases
        s = suggestion
        t_sug = Table(s.schema, s.table_refs, s.local_tables)
        v_sug = View(s.schema, s.table_refs)
        f_sug = Function(s.schema, s.table_refs, usage="from")
        return (
            self.get_table_matches(t_sug, word_before_cursor, alias)
            + self.get_view_matches(v_sug, word_before_cursor, alias)
            + self.get_function_matches(f_sug, word_before_cursor, alias)
        )

    def _arg_list(self, func, usage):
        """Returns a an arg list string, e.g. `(_foo:=23)` for a func.

        :param func is a FunctionMetadata object
        :param usage is 'call', 'call_display' or 'signature'
        """
        template = {
            "call": self.call_arg_style,
            "call_display": self.call_arg_display_style,
            "signature": self.signature_arg_style,
        }[usage]
        args = func.args()
        if not template:
            return "()"
        elif usage == "call" and len(args) < 2:
            return "()"
        elif usage == "call" and func.has_variadic():
            return "()"
        multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
        max_arg_len = max(len(a.name) for a in args) if multiline else 0
        args = (
            self._format_arg(template, arg, arg_num + 1, max_arg_len)
            for arg_num, arg in enumerate(args)
        )
        if multiline:
            return "(" + ",".join("\n    " + a for a in args if a) + "\n)"
        else:
            return "(" + ", ".join(a for a in args if a) + ")"

    def _format_arg(self, template, arg, arg_num, max_arg_len):
        """Render one argument through *template*; returns None for an empty
        template so callers can skip it."""
        if not template:
            return None
        if arg.has_default:
            arg_default = "NULL" if arg.default is None else arg.default
            # Remove trailing ::(schema.)type
            arg_default = arg_default_type_strip_regex.sub("", arg_default)
        else:
            arg_default = ""
        return template.format(
            max_arg_len=max_arg_len,
            arg_name=self.case(arg.name),
            arg_num=arg_num,
            arg_type=arg.datatype,
            arg_default=arg_default,
        )

    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
        """Returns a Candidate namedtuple.

        :param tbl is a SchemaObject
        :param arg_mode determines what type of arg list to suffix for
        functions. Possible values: call, signature
        """
        cased_tbl = self.case(tbl.name)
        if do_alias:
            alias = self.alias(cased_tbl, suggestion.table_refs)
        # BUGFIX: synonyms must be computed unconditionally; it was previously
        # assigned only when do_alias was true, raising NameError for the
        # (common) do_alias=False calls.
        synonyms = (cased_tbl, generate_alias(cased_tbl))
        maybe_alias = (" " + alias) if do_alias else ""
        maybe_schema = (self.case(tbl.schema) + ".") if tbl.schema else ""
        suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ""
        if arg_mode == "call":
            display_suffix = self._arg_list_cache["call_display"][tbl.meta]
        elif arg_mode == "signature":
            display_suffix = self._arg_list_cache["signature"][tbl.meta]
        else:
            display_suffix = ""
        item = maybe_schema + cased_tbl + suffix + maybe_alias
        display = maybe_schema + cased_tbl + display_suffix + maybe_alias
        prio2 = 0 if tbl.schema else 1
        return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)

    def get_table_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest table names (including CTE/local tables in scope)."""
        tables = self.populate_schema_objects(suggestion.schema, "tables")
        tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)

        # Unless we're sure the user really wants them, don't suggest the
        # pg_catalog tables that are implicitly on the search path
        if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
            tables = [t for t in tables if not t.name.startswith("pg_")]
        tables = [self._make_cand(t, alias, suggestion) for t in tables]
        return self.find_matches(word_before_cursor, tables, meta="table")

    def get_table_formats(self, _, word_before_cursor):
        """Suggest output format names for \\pset-style commands."""
        formats = TabularOutputFormatter().supported_formats
        return self.find_matches(word_before_cursor, formats, meta="table format")

    def get_view_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest view names."""
        views = self.populate_schema_objects(suggestion.schema, "views")

        if not suggestion.schema and (not word_before_cursor.startswith("pg_")):
            views = [v for v in views if not v.name.startswith("pg_")]
        views = [self._make_cand(v, alias, suggestion) for v in views]
        return self.find_matches(word_before_cursor, views, meta="view")

    def get_alias_matches(self, suggestion, word_before_cursor):
        """Suggest table aliases already present in the query."""
        aliases = suggestion.aliases
        return self.find_matches(word_before_cursor, aliases, meta="table alias")

    def get_database_matches(self, _, word_before_cursor):
        """Suggest database names."""
        return self.find_matches(word_before_cursor, self.databases, meta="database")

    def get_keyword_matches(self, suggestion, word_before_cursor):
        """Suggest SQL keywords, narrowed by the last token and cased per
        the keyword_casing setting ('upper', 'lower' or 'auto')."""
        keywords = self.keywords_tree.keys()
        # Get well known following keywords for the last token. If any, narrow
        # candidates to this list.
        next_keywords = self.keywords_tree.get(suggestion.last_token, [])
        if next_keywords:
            keywords = next_keywords

        casing = self.keyword_casing
        if casing == "auto":
            if word_before_cursor and word_before_cursor[-1].islower():
                casing = "lower"
            else:
                casing = "upper"

        if casing == "upper":
            keywords = [k.upper() for k in keywords]
        else:
            keywords = [k.lower() for k in keywords]

        return self.find_matches(
            word_before_cursor, keywords, mode="strict", meta="keyword"
        )

    def get_path_matches(self, _, word_before_cursor):
        """Yield filesystem path completions via prompt_toolkit's PathCompleter."""
        completer = PathCompleter(expanduser=True)
        document = Document(
            text=word_before_cursor, cursor_position=len(word_before_cursor)
        )
        for c in completer.get_completions(document, None):
            yield Match(completion=c, priority=(0,))

    def get_special_matches(self, _, word_before_cursor):
        """Suggest pgspecial backslash commands, if pgspecial is available."""
        if not self.pgspecial:
            return []

        commands = self.pgspecial.commands
        cmds = commands.keys()
        cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
        return self.find_matches(word_before_cursor, cmds, mode="strict")

    def get_datatype_matches(self, suggestion, word_before_cursor):
        """Suggest custom datatypes, plus built-in types when unqualified."""
        # suggest custom datatypes
        types = self.populate_schema_objects(suggestion.schema, "datatypes")
        types = [self._make_cand(t, False, suggestion) for t in types]
        matches = self.find_matches(word_before_cursor, types, meta="datatype")

        if not suggestion.schema:
            # Also suggest hardcoded types
            matches.extend(
                self.find_matches(
                    word_before_cursor, self.datatypes, mode="strict", meta="datatype"
                )
            )

        return matches

    def get_namedquery_matches(self, _, word_before_cursor):
        """Suggest saved named queries."""
        return self.find_matches(
            word_before_cursor, NamedQueries.instance.list(), meta="named query"
        )

    # Dispatch table from suggestion type to unbound matcher; called as
    # matcher(self, suggestion, word_before_cursor) in get_completions.
    suggestion_matchers = {
        FromClauseItem: get_from_clause_item_matches,
        JoinCondition: get_join_condition_matches,
        Join: get_join_matches,
        Column: get_column_matches,
        Function: get_function_matches,
        Schema: get_schema_matches,
        Table: get_table_matches,
        TableFormat: get_table_formats,
        View: get_view_matches,
        Alias: get_alias_matches,
        Database: get_database_matches,
        Keyword: get_keyword_matches,
        Special: get_special_matches,
        Datatype: get_datatype_matches,
        NamedQuery: get_namedquery_matches,
        Path: get_path_matches,
    }

    def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
        """Find all columns in a set of scoped_tables.

        :param scoped_tbls: list of TableReference namedtuples
        :param local_tbls: tuple(TableMetadata)
        :return: {TableReference:{colname:ColumnMetaData}}

        """
        ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
        columns = OrderedDict()
        meta = self.dbmetadata

        def addcols(schema, rel, alias, reltype, cols):
            tbl = TableReference(schema, rel, alias, reltype == "functions")
            if tbl not in columns:
                columns[tbl] = []
            columns[tbl].extend(cols)

        for tbl in scoped_tbls:
            # Local tables should shadow database tables
            if tbl.schema is None and normalize_ref(tbl.name) in ctes:
                cols = ctes[normalize_ref(tbl.name)]
                # BUGFIX: alias and reltype were swapped, so every CTE got the
                # literal alias "CTE" and the user's alias was misused as the
                # relation type.
                addcols(None, tbl.name, tbl.alias, "CTE", cols)
                continue
            schemas = [tbl.schema] if tbl.schema else self.search_path
            for schema in schemas:
                relname = self.escape_name(tbl.name)
                schema = self.escape_name(schema)
                if tbl.is_function:
                    # Return column names from a set-returning function
                    # Get an array of FunctionMetadata objects
                    functions = meta["functions"].get(schema, {}).get(relname)
                    for func in functions or []:
                        # func is a FunctionMetadata object
                        cols = func.fields()
                        addcols(schema, relname, tbl.alias, "functions", cols)
                else:
                    for reltype in ("tables", "views"):
                        cols = meta[reltype].get(schema, {}).get(relname)
                        if cols:
                            cols = cols.values()
                            addcols(schema, relname, tbl.alias, reltype, cols)
                            break

        return columns

    def _get_schemas(self, obj_typ, schema):
        """Returns a list of schemas from which to suggest objects.

        :param schema is the schema qualification input by the user (if any)

        """
        metadata = self.dbmetadata[obj_typ]
        if schema:
            schema = self.escape_name(schema)
            return [schema] if schema in metadata else []
        return self.search_path if self.search_path_filter else metadata.keys()

    def _maybe_schema(self, schema, parent):
        """Return *schema* unless a parent qualification was given or the
        schema is already on the search path (then no prefix is needed)."""
        return None if parent or schema in self.search_path else schema

    def populate_schema_objects(self, schema, obj_type):
        """Returns a list of SchemaObjects representing tables or views.

        :param schema is the schema qualification input by the user (if any)

        """
        return [
            SchemaObject(
                name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
            )
            for sch in self._get_schemas(obj_type, schema)
            for obj in self.dbmetadata[obj_type][sch].keys()
        ]

    def populate_functions(self, schema, filter_func):
        """Returns a list of function SchemaObjects.

        :param filter_func is a function that accepts a FunctionMetadata
        namedtuple and returns a boolean indicating whether that
        function should be kept or discarded

        """
        # Because of multiple dispatch, we can have multiple functions
        # with the same name, which is why `for meta in metas` is necessary
        # in the comprehensions below
        return [
            SchemaObject(
                name=func,
                schema=(self._maybe_schema(schema=sch, parent=schema)),
                meta=meta,
            )
            for sch in self._get_schemas("functions", schema)
            for (func, metas) in self.dbmetadata["functions"][sch].items()
            for meta in metas
            if filter_func(meta)
        ]