summaryrefslogtreecommitdiffstats
path: root/pgcli/completion_refresher.py
blob: 1039d51599f40b02bdf84cb3282330a09aa71225 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import threading
import os
from collections import OrderedDict

from .pgcompleter import PGCompleter


class CompletionRefresher:
    """Rebuild a ``PGCompleter`` in a background thread.

    The individual refresh steps live in the class-level ``refreshers``
    mapping and are run, in insertion order, by the worker thread.
    """

    # name -> callable(completer, executor). Shared by every instance; the
    # module-level ``refresher`` decorator registers into it by default.
    refreshers = OrderedDict()

    def __init__(self):
        # Worker thread of the most recent refresh, or None before the first.
        self._completer_thread = None
        # Set by refresh() when a refresh is requested while one is already
        # running; the worker checks it after each step and starts over.
        self._restart_refresh = threading.Event()

    def refresh(self, executor, special, callbacks, history=None, settings=None):
        """
        Creates a PGCompleter object and populates it with the relevant
        completion suggestions in a background thread.

        executor - PGExecute object, used to extract the credentials to connect
                   to the database.
        special - PGSpecial object used for creating a new completion object.
        callbacks - A function or a list of functions to call after the thread
                    has completed the refresh. The newly created completion
                    object will be passed in as an argument to each callback.
        history - Optional object whose ``get_strings()`` entries (the last
                  100) are fed to ``completer.extend_query_history``.
        settings - dict of settings for completer object; a truthy
                   ``"single_connection"`` makes the refresh use ``executor``
                   directly instead of a ``copy()`` of it.

        Returns a one-element list of ``(None, None, None, message)`` tuples
        describing what happened.
        """
        if executor.is_virtual_database():
            # Refresh is skipped entirely for virtual databases.
            return [(None, None, None, "Auto-completion refresh can't be started.")]

        if self.is_refreshing():
            # Ask the running worker to start over rather than spawning a
            # second one; the arguments of this call are not used here.
            self._restart_refresh.set()
            return [(None, None, None, "Auto-completion refresh restarted.")]

        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        self._completer_thread = threading.Thread(
            target=self._bg_refresh,
            args=(executor, special, callbacks, history, settings),
            name="completion_refresh",
            daemon=True,
        )
        self._completer_thread.start()
        return [
            (None, None, None, "Auto-completion refresh started in the background.")
        ]

    def is_refreshing(self):
        """Return True while the background refresh thread is alive."""
        thread = self._completer_thread
        return thread is not None and thread.is_alive()

    def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
        """Worker body: populate a fresh completer and hand it to the callbacks.

        Runs every registered refresher (starting over if a restart is
        requested), seeds the completer with up to 100 recent history entries,
        then calls each callback with the completer. A connection opened via
        ``pgexecute.copy()`` is always closed, even if a step raises.
        """
        settings = settings or {}
        completer = PGCompleter(
            smart_completion=True, pgspecial=special, settings=settings
        )

        single_connection = settings.get("single_connection")
        # Unless single_connection is requested, populate the completions
        # through a copy of the executor; its connection is closed below.
        executor = pgexecute if single_connection else pgexecute.copy()

        # If callbacks is a single function then push it into a list.
        if callable(callbacks):
            callbacks = [callbacks]

        try:
            self._run_refreshers(completer, executor)

            # Load history into pgcompleter so it can learn user preferences
            n_recent = 100
            if history:
                for recent in history.get_strings()[-n_recent:]:
                    completer.extend_query_history(recent, is_init=True)

            for callback in callbacks:
                callback(completer)
        finally:
            # Close the connection established with pgexecute.copy(), also
            # when a refresher or callback raised.
            if not single_connection and executor.conn:
                executor.conn.close()

    def _run_refreshers(self, completer, executor):
        """Run every registered refresher, starting over on a restart request.

        NOTE(review): the restart flag is only checked here, so a restart
        requested after the last step has run is not honoured by this run and
        leaves the flag set for the next one.
        """
        while True:
            for step in self.refreshers.values():
                step(completer, executor)
                if self._restart_refresh.is_set():
                    self._restart_refresh.clear()
                    break  # restart requested: begin again from the first step
            else:
                # Every step ran without a restart request.
                return


def refresher(name, refreshers=CompletionRefresher.refreshers):
    """Build a decorator that registers a function as a refresh step.

    The decorated function is stored in ``refreshers`` under ``name`` and
    returned unchanged, so it remains directly callable.

    name - key under which the function is registered.
    refreshers - mapping to register into; defaults to the shared
                 ``CompletionRefresher.refreshers`` registry.
    """

    def register(func):
        refreshers[name] = func
        return func

    return register


@refresher("schemata")
def refresh_schemata(completer, executor):
    """Push the executor's search path, then its schemata, into the completer."""
    search_path = executor.search_path()
    completer.set_search_path(search_path)
    schema_names = executor.schemata()
    completer.extend_schemata(schema_names)


@refresher("tables")
def refresh_tables(completer, executor):
    """Load tables, their columns, and foreign keys into the completer."""
    table_rows = executor.tables()
    completer.extend_relations(table_rows, kind="tables")
    column_rows = executor.table_columns()
    completer.extend_columns(column_rows, kind="tables")
    fk_rows = executor.foreignkeys()
    completer.extend_foreignkeys(fk_rows)


@refresher("views")
def refresh_views(completer, executor):
    """Load views and their columns into the completer."""
    view_rows = executor.views()
    completer.extend_relations(view_rows, kind="views")
    column_rows = executor.view_columns()
    completer.extend_columns(column_rows, kind="views")


@refresher("types")
def refresh_types(completer, executor):
    """Load the executor's datatypes into the completer."""
    type_rows = executor.datatypes()
    completer.extend_datatypes(type_rows)


@refresher("databases")
def refresh_databases(completer, executor):
    """Load the executor's database names into the completer."""
    db_names = executor.databases()
    completer.extend_database_names(db_names)


@refresher("casing")
def refresh_casing(completer, executor):
    """Load casing preferences from ``completer.casing_file``.

    Does nothing when no casing file is configured. If
    ``completer.generate_casing_file`` is set and the file does not exist,
    it is first written from ``executor.casing()``, one entry per line. Each
    line of the file (stripped) is then passed to ``completer.extend_casing``.
    """
    casing_path = completer.casing_file
    if not casing_path:
        return

    needs_generation = completer.generate_casing_file and not os.path.isfile(
        casing_path
    )
    if needs_generation:
        # Fetch before opening so a failing query does not leave an empty file.
        contents = "\n".join(executor.casing())
        with open(casing_path, "w") as out:
            out.write(contents)

    if not os.path.isfile(casing_path):
        return
    with open(casing_path) as src:
        completer.extend_casing([entry.strip() for entry in src])


@refresher("functions")
def refresh_functions(completer, executor):
    """Load the executor's functions into the completer."""
    function_rows = executor.functions()
    completer.extend_functions(function_rows)