summaryrefslogtreecommitdiffstats
path: root/tests/metadata.py
blob: 4ebcccd07dd4fca727e04b21dd3e040337ef8e2d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
from functools import partial
from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from unittest.mock import Mock
import pytest

# Shorthand for ``pytest.mark.parametrize``.
parametrize = pytest.mark.parametrize

# Subsets of the ``qualify_columns`` setting values enumerated in
# ``MetaData.get_completers``; presumably passed as its ``qualify`` filter.
# Judging by the names, ``qual`` selects the settings that lead to qualified
# column completions and ``no_qual`` those that do not
# (``if_more_than_one_table`` appears in both) — TODO confirm against the tests.
qual = ["if_more_than_one_table", "always"]
no_qual = ["if_more_than_one_table", "never"]


def escape(name):
    """Return ``name`` wrapped in double quotes when it needs quoting.

    A name is left as-is only if ``str.islower`` is true for it (so mixed
    case, upper case and names with no cased characters get quoted) and it is
    not one of the words ``select`` or ``localtimestamp``.  Embedded double
    quotes are not doubled.
    """
    if name.islower() and name not in ("select", "localtimestamp"):
        return name
    return '"' + name + '"'


def completion(display_meta, text, pos=0):
    """Build the ``Completion`` expected for ``text`` labelled ``display_meta``.

    ``display_meta`` is the first parameter, which lets the module-level
    ``partial(completion, "<label>")`` factories fix it and leave ``text`` and
    ``pos`` (the start position, default 0) to the caller.
    """
    return Completion(text=text, display_meta=display_meta, start_position=pos)


def function(text, pos=0, display=None):
    """Build the ``Completion`` expected for a function named by ``text``.

    ``display`` falls back to ``text`` when it is omitted or empty; the meta
    label is always ``"function"``.
    """
    shown = display if display else text
    return Completion(
        text,
        start_position=pos,
        display=shown,
        display_meta="function",
    )


def get_result(completer, text, position=None):
    """Ask ``completer`` for completions of ``text`` at cursor ``position``.

    The cursor defaults to the end of ``text``.  A ``Mock`` stands in for the
    complete event, and whatever ``get_completions`` returns is passed back
    unchanged.
    """
    if position is None:
        position = len(text)
    doc = Document(text=text, cursor_position=position)
    return completer.get_completions(doc, Mock())


def result_set(completer, text, position=None):
    """Same as ``get_result`` but with the completions gathered into a set."""
    completions = get_result(completer, text, position)
    return set(completions)


# The code below is equivalent to
# def schema(text, pos=0):
#   return completion('schema', text, pos)
# and so on
# Completion factories, one per ``display_meta`` label; each one is called as
# ``factory(text, pos=0)`` and returns ``completion(label, text, pos)``.
schema = partial(completion, "schema")
table = partial(completion, "table")
view = partial(completion, "view")
column = partial(completion, "column")
keyword = partial(completion, "keyword")
datatype = partial(completion, "datatype")
# Note the label differs from the factory name for the three below.
alias = partial(completion, "table alias")
name_join = partial(completion, "name join")
fk_join = partial(completion, "fk join")
join = partial(completion, "join")


def wildcard_expansion(cols, pos=-1):
    """Build the completion that expands ``*`` into the column text ``cols``.

    The entry is displayed as ``*`` with meta ``"columns"``; the default
    ``pos`` of -1 presumably targets the single ``*`` before the cursor.
    """
    return Completion(
        cols,
        display="*",
        display_meta="columns",
        start_position=pos,
    )


class MetaData:
    """Build expected completions and configured completers from a fixture.

    ``metadata`` is a dict keyed by object kind.  As read by this class:

    * ``"tables"`` / ``"views"``: schema -> {relation name: column names}
    * ``"functions"``: schema -> sequence of tuples whose items 0, 1 and 3 are
      the function name, its argument names and its argument modes; each
      whole tuple is also splatted into ``FunctionMetadata``
    * ``"datatypes"``: schema -> iterable of type names
    * ``"foreignkeys"``: schema -> iterable of ``ForeignKey`` argument tuples
    * ``"defaults"`` (optional): schema -> mapping of ``(table, column)`` to
      that column's default

    The expectation helpers (``tables``, ``columns``, ...) return the
    ``Completion`` objects a test should expect; ``get_completer`` builds a
    ``PGCompleter`` loaded with the same objects.
    """

    def __init__(self, metadata):
        self.metadata = metadata

    def builtin_functions(self, pos=0):
        """Expected completions for the completer's built-in function names."""
        return [function(f, pos) for f in self.completer.functions]

    def builtin_datatypes(self, pos=0):
        """Expected completions for the completer's built-in data types."""
        return [datatype(dt, pos) for dt in self.completer.datatypes]

    def keywords(self, pos=0):
        """Expected completions for every keyword known to the completer."""
        return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]

    def specials(self, pos=0):
        """Expected completions for pgspecial commands, meta = description."""
        return [
            Completion(text=k, start_position=pos, display_meta=v.description)
            for k, v in self.completer.pgspecial.commands.items()
        ]

    def columns(self, tbl, parent="public", typ="tables", pos=0):
        """Expected column completions for relation ``tbl`` in schema ``parent``.

        ``typ`` is the metadata kind to look in.  For ``"functions"`` the names
        come from item 1 of the first function tuple named ``tbl``; otherwise
        they are ``metadata[typ][parent][tbl]``.

        Raises KeyError when the kind/schema/relation is missing and
        IndexError when no function is named ``tbl``.
        """
        if typ == "functions":
            fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
            cols = fun[1]
        else:
            cols = self.metadata[typ][parent][tbl]
        return [column(escape(col), pos) for col in cols]

    def datatypes(self, parent="public", pos=0):
        """Expected completions for user data types in schema ``parent``."""
        return [
            datatype(escape(x), pos)
            for x in self.metadata.get("datatypes", {}).get(parent, [])
        ]

    def tables(self, parent="public", pos=0):
        """Expected completions for tables in schema ``parent``."""
        return [
            table(escape(x), pos)
            for x in self.metadata.get("tables", {}).get(parent, [])
        ]

    def views(self, parent="public", pos=0):
        """Expected completions for views in schema ``parent``."""
        return [
            view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
        ]

    def functions(self, parent="public", pos=0):
        """Expected completions for user functions in schema ``parent``.

        Each function is inserted as ``name(arg := , ...)`` and displayed as
        ``name(arg, ...)``; only arguments whose mode is ``"b"`` or ``"i"``
        are listed.
        """
        result = []
        for func in self.metadata.get("functions", {}).get(parent, []):
            name = escape(func[0])
            arg_names = [
                arg_name
                for arg_name, arg_mode in zip(func[1], func[3])
                if arg_mode in ("b", "i")
            ]
            text = name + "(" + ", ".join(a + " := " for a in arg_names) + ")"
            display = name + "(" + ", ".join(arg_names) + ")"
            result.append(function(text, pos, display))
        return result

    def schemas(self, pos=0):
        """Expected completions for every schema named anywhere in the fixture.

        Names are collected from all metadata kinds (including ``"defaults"``
        when present) and de-duplicated; the list order is unspecified.
        """
        schemas = {sch for schs in self.metadata.values() for sch in schs}
        return [schema(escape(s), pos=pos) for s in schemas]

    def functions_and_keywords(self, parent="public", pos=0):
        """User functions of ``parent``, then built-in functions, then keywords."""
        return (
            self.functions(parent, pos)
            + self.builtin_functions(pos)
            + self.keywords(pos)
        )

    # Note that the filtering parameters here only apply to the columns
    def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
        """``public`` functions and keywords followed by the columns of ``tbl``."""
        return self.functions_and_keywords(pos=pos) + self.columns(
            tbl, parent, typ, pos
        )

    def from_clause_items(self, parent="public", pos=0):
        """Functions, views and tables of ``parent`` (things valid after FROM)."""
        return (
            self.functions(parent, pos)
            + self.views(parent, pos)
            + self.tables(parent, pos)
        )

    def schemas_and_from_clause_items(self, parent="public", pos=0):
        """``from_clause_items`` of ``parent`` followed by all schemas."""
        return self.from_clause_items(parent, pos) + self.schemas(pos)

    def types(self, parent="public", pos=0):
        """User data types followed by tables of schema ``parent``."""
        return self.datatypes(parent, pos) + self.tables(parent, pos)

    @property
    def completer(self):
        """A completer with default settings, rebuilt on every access."""
        return self.get_completer()

    def get_completers(self, casing):
        """Return a factory that builds completers for combinations of settings.

        The returned ``completers(casing=None, filtr=None, aliasing=None,
        qualify=None)`` takes three bools and a list of ``qualify_columns``
        values.  Each argument restricts the corresponding setting and ``None``
        means "every value", so ``completers()`` yields all 24 combinations
        (2 casing x 2 search-path filter x 2 aliasing x 3 qualify).  For
        example ``completers(True, False, True, ["never"])`` yields the single
        completer with casing, without ``search_path`` filtering, with table
        aliasing and without column qualification.

        ``casing`` is the casing word list passed to ``get_completer`` for the
        combinations that enable casing; the others get no ``casing`` key.
        """

        def _cfg(use_casing, filtr, aliasing, qualify):
            # Keyword arguments for a single ``get_completer`` call.
            cfg = {"settings": {}}
            if use_casing:
                cfg["casing"] = casing
            cfg["settings"]["search_path_filter"] = filtr
            cfg["settings"]["generate_aliases"] = aliasing
            cfg["settings"]["qualify_columns"] = qualify
            return cfg

        def _cfgs(want_casing, filtr, aliasing, qualify):
            # Expand each ``None`` into every allowed value, then take the
            # cartesian product of the choices.
            casings = [True, False] if want_casing is None else [want_casing]
            filtrs = [True, False] if filtr is None else [filtr]
            aliases = [True, False] if aliasing is None else [aliasing]
            qualifys = qualify or ["always", "if_more_than_one_table", "never"]
            return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]

        def completers(casing=None, filtr=None, aliasing=None, qualify=None):
            get_comp = self.get_completer
            return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]

        return completers

    def _make_col(self, sch, tbl, col):
        """Column tuple for ``extend_columns``; every column is typed ``"text"``.

        Returns ``(sch, tbl, col, "text", in_defaults, default)`` where
        ``in_defaults`` says whether ``(tbl, col)`` is a key of
        ``metadata["defaults"][sch]`` and ``default`` is its value (or None).
        """
        defaults = self.metadata.get("defaults", {}).get(sch, {})
        key = (tbl, col)
        return (sch, tbl, col, "text", key in defaults, defaults.get(key))

    def get_completer(self, settings=None, casing=None):
        """Build a ``PGCompleter`` populated with every object in the fixture.

        ``settings`` is forwarded to ``PGCompleter``; ``casing`` (default: no
        words) is passed to ``extend_casing``.  The search path is
        ``["public"]``.  Metadata kinds missing from the fixture are treated as
        empty, matching the ``.get(..., {})`` lookups of the helpers above.
        """
        from pgcli.pgcompleter import PGCompleter
        from pgspecial import PGSpecial

        metadata = self.metadata
        comp = PGCompleter(
            smart_completion=True, settings=settings, pgspecial=PGSpecial()
        )

        schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []

        # Only schemas listed under "tables" are registered as schemata.
        for sch, tbls in metadata.get("tables", {}).items():
            schemata.append(sch)
            for tbl, cols in tbls.items():
                tables.append((sch, tbl))
                # Let all columns be text columns
                tbl_cols.extend(self._make_col(sch, tbl, col) for col in cols)

        for sch, tbls in metadata.get("views", {}).items():
            for tbl, cols in tbls.items():
                views.append((sch, tbl))
                # Let all columns be text columns
                view_cols.extend(self._make_col(sch, tbl, col) for col in cols)

        functions = [
            FunctionMetadata(sch, *func_meta, arg_defaults=None)
            for sch, funcs in metadata.get("functions", {}).items()
            for func_meta in funcs
        ]

        datatypes = [
            (sch, typ)
            for sch, typs in metadata.get("datatypes", {}).items()
            for typ in typs
        ]

        foreignkeys = [
            ForeignKey(*fk)
            for fks in metadata.get("foreignkeys", {}).values()
            for fk in fks
        ]

        comp.extend_schemata(schemata)
        comp.extend_relations(tables, kind="tables")
        comp.extend_relations(views, kind="views")
        comp.extend_columns(tbl_cols, kind="tables")
        comp.extend_columns(view_cols, kind="views")
        comp.extend_functions(functions)
        comp.extend_datatypes(datatypes)
        comp.extend_foreignkeys(foreignkeys)
        comp.set_search_path(["public"])
        comp.extend_casing(casing or [])

        return comp