diff options
Diffstat (limited to 'tests/metadata.py')
-rw-r--r-- | tests/metadata.py | 255 |
1 files changed, 255 insertions, 0 deletions
diff --git a/tests/metadata.py b/tests/metadata.py new file mode 100644 index 0000000..4ebcccd --- /dev/null +++ b/tests/metadata.py @@ -0,0 +1,255 @@ +from functools import partial +from itertools import product +from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from unittest.mock import Mock +import pytest + +parametrize = pytest.mark.parametrize + +qual = ["if_more_than_one_table", "always"] +no_qual = ["if_more_than_one_table", "never"] + + +def escape(name): + if not name.islower() or name in ("select", "localtimestamp"): + return '"' + name + '"' + return name + + +def completion(display_meta, text, pos=0): + return Completion(text, start_position=pos, display_meta=display_meta) + + +def function(text, pos=0, display=None): + return Completion( + text, display=display or text, start_position=pos, display_meta="function" + ) + + +def get_result(completer, text, position=None): + position = len(text) if position is None else position + return completer.get_completions( + Document(text=text, cursor_position=position), Mock() + ) + + +def result_set(completer, text, position=None): + return set(get_result(completer, text, position)) + + +# The code below is quivalent to +# def schema(text, pos=0): +# return completion('schema', text, pos) +# and so on +schema = partial(completion, "schema") +table = partial(completion, "table") +view = partial(completion, "view") +column = partial(completion, "column") +keyword = partial(completion, "keyword") +datatype = partial(completion, "datatype") +alias = partial(completion, "table alias") +name_join = partial(completion, "name join") +fk_join = partial(completion, "fk join") +join = partial(completion, "join") + + +def wildcard_expansion(cols, pos=-1): + return Completion(cols, start_position=pos, display_meta="columns", display="*") + + +class MetaData: + def __init__(self, metadata): + self.metadata = metadata + + def builtin_functions(self, pos=0): + return [function(f, pos) for f in self.completer.functions] + + def builtin_datatypes(self, pos=0): + return [datatype(dt, pos) for dt in self.completer.datatypes] + + def keywords(self, pos=0): + return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] + + def specials(self, pos=0): + return [ + Completion(text=k, start_position=pos, display_meta=v.description) + for k, v in self.completer.pgspecial.commands.items() + ] + + def columns(self, tbl, parent="public", typ="tables", pos=0): + if typ == "functions": + fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0] + cols = fun[1] + else: + cols = self.metadata[typ][parent][tbl] + return [column(escape(col), pos) for col in cols] + + def datatypes(self, parent="public", pos=0): + return [ + datatype(escape(x), pos) + for x in self.metadata.get("datatypes", {}).get(parent, []) + ] + + def tables(self, parent="public", pos=0): + return [ + table(escape(x), pos) + for x in self.metadata.get("tables", {}).get(parent, []) + ] + + def views(self, parent="public", pos=0): + return [ + view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) + ] + + def functions(self, parent="public", pos=0): + return [ + function( + escape(x[0]) + + "(" + + ", ".join( + arg_name + " := " + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + pos, + escape(x[0]) + + "(" + + ", ".join( + arg_name + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + ) + for x in self.metadata.get("functions", {}).get(parent, []) + ] + + def schemas(self, pos=0): + schemas = {sch for schs in self.metadata.values() for sch in schs} + return [schema(escape(s), pos=pos) for s in schemas] + + def functions_and_keywords(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.builtin_functions(pos) + + self.keywords(pos) + ) + + # Note that the filtering parameters here only apply to the columns + def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): + return self.functions_and_keywords(pos=pos) + self.columns( + tbl, parent, typ, pos + ) + + def from_clause_items(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.views(parent, pos) + + self.tables(parent, pos) + ) + + def schemas_and_from_clause_items(self, parent="public", pos=0): + return self.from_clause_items(parent, pos) + self.schemas(pos) + + def types(self, parent="public", pos=0): + return self.datatypes(parent, pos) + self.tables(parent, pos) + + @property + def completer(self): + return self.get_completer() + + def get_completers(self, casing): + """ + Returns a function taking three bools `casing`, `filtr`, `aliasing` and + the list `qualify`, all defaulting to None. + Returns a list of completers. + These parameters specify the allowed values for the corresponding + completer parameters, `None` meaning any, i.e. (None, None, None, None) + results in all 24 possible completers, whereas e.g. + (True, False, True, ['never']) results in the one completer with + casing, without `search_path` filtering of objects, with table + aliasing, and without column qualification. + """ + + def _cfg(_casing, filtr, aliasing, qualify): + cfg = {"settings": {}} + if _casing: + cfg["casing"] = casing + cfg["settings"]["search_path_filter"] = filtr + cfg["settings"]["generate_aliases"] = aliasing + cfg["settings"]["qualify_columns"] = qualify + return cfg + + def _cfgs(casing, filtr, aliasing, qualify): + casings = [True, False] if casing is None else [casing] + filtrs = [True, False] if filtr is None else [filtr] + aliases = [True, False] if aliasing is None else [aliasing] + qualifys = qualify or ["always", "if_more_than_one_table", "never"] + return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)] + + def completers(casing=None, filtr=None, aliasing=None, qualify=None): + get_comp = self.get_completer + return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)] + + return completers + + def _make_col(self, sch, tbl, col): + defaults = self.metadata.get("defaults", {}).get(sch, {}) + return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col))) + + def get_completer(self, settings=None, casing=None): + metadata = self.metadata + from pgcli.pgcompleter import PGCompleter + from pgspecial import PGSpecial + + comp = PGCompleter( + smart_completion=True, settings=settings, pgspecial=PGSpecial() + ) + + schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] + + for sch, tbls in metadata["tables"].items(): + schemata.append(sch) + + for tbl, cols in tbls.items(): + tables.append((sch, tbl)) + # Let all columns be text columns + tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + for sch, tbls in metadata.get("views", {}).items(): + for tbl, cols in tbls.items(): + views.append((sch, tbl)) + # Let all columns be text columns + view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + functions = [ + FunctionMetadata(sch, *func_meta, arg_defaults=None) + for sch, funcs in metadata["functions"].items() + for func_meta in funcs + ] + + datatypes = [ + (sch, typ) + for sch, datatypes in metadata["datatypes"].items() + for typ in datatypes + ] + + foreignkeys = [ + ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks + ] + + comp.extend_schemata(schemata) + comp.extend_relations(tables, kind="tables") + comp.extend_relations(views, kind="views") + comp.extend_columns(tbl_cols, kind="tables") + comp.extend_columns(view_cols, kind="views") + comp.extend_functions(functions) + comp.extend_datatypes(datatypes) + comp.extend_foreignkeys(foreignkeys) + comp.set_search_path(["public"]) + comp.extend_casing(casing or []) + + return comp |