summaryrefslogtreecommitdiffstats
path: root/mycli/packages/special/iocommands.py
blob: 11dca8d0bdbcf6af61bd8ff49781a6201742bc69 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import os
import re
import locale
import logging
import subprocess
import shlex
from io import open
from time import sleep

import click
import sqlparse

from . import export
from .main import special_command, NO_QUERY, PARSED_QUERY
from .favoritequeries import FavoriteQueries
from .delimitercommand import DelimiterCommand
from .utils import handle_cd_command
from mycli.packages.prompt_utils import confirm_destructive_query

TIMING_ENABLED = False
use_expanded_output = False
PAGER_ENABLED = True
tee_file = None
once_file = None
written_to_once_file = False
delimiter_command = DelimiterCommand()


@export
def set_timing_enabled(val):
    global TIMING_ENABLED
    TIMING_ENABLED = val

@export
def set_pager_enabled(val):
    global PAGER_ENABLED
    PAGER_ENABLED = val


@export
def is_pager_enabled():
    return PAGER_ENABLED

@export
@special_command('pager', '\\P [command]',
                 'Set PAGER. Print the query results via PAGER.',
                 arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
def set_pager(arg, **_):
    if arg:
        os.environ['PAGER'] = arg
        msg = 'PAGER set to %s.' % arg
        set_pager_enabled(True)
    else:
        if 'PAGER' in os.environ:
            msg = 'PAGER set to %s.' % os.environ['PAGER']
        else:
            # This uses click's default per echo_via_pager.
            msg = 'Pager enabled.'
        set_pager_enabled(True)

    return [(None, None, None, msg)]

@export
@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
                 arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
def disable_pager():
    set_pager_enabled(False)
    return [(None, None, None, 'Pager disabled.')]

@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
def toggle_timing():
    global TIMING_ENABLED
    TIMING_ENABLED = not TIMING_ENABLED
    message = "Timing is "
    message += "on." if TIMING_ENABLED else "off."
    return [(None, None, None, message)]

@export
def is_timing_enabled():
    return TIMING_ENABLED

@export
def set_expanded_output(val):
    global use_expanded_output
    use_expanded_output = val

@export
def is_expanded_output():
    return use_expanded_output

_logger = logging.getLogger(__name__)

@export
def editor_command(command):
    """
    Is this an external editor command?
    :param command: string
    """
    # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
    # for both conditions.
    return command.strip().endswith('\\e') or command.strip().startswith('\\e')

@export
def get_filename(sql):
    if sql.strip().startswith('\\e'):
        command, _, filename = sql.partition(' ')
        return filename.strip() or None


@export
def get_editor_query(sql):
    """Get the query part of an editor command."""
    sql = sql.strip()

    # The reason we can't simply do .strip('\e') is that it strips characters,
    # not a substring. So it'll strip "e" in the end of the sql also!
    # Ex: "select * from style\e" -> "select * from styl".
    pattern = re.compile('(^\\\e|\\\e$)')
    while pattern.search(sql):
        sql = pattern.sub('', sql)

    return sql


@export
def open_external_editor(filename=None, sql=None):
    """Open external editor, wait for the user to type in their query, return
    the query.

    :return: list with one tuple, query as first element.

    """

    message = None
    filename = filename.strip().split(' ', 1)[0] if filename else None

    sql = sql or ''
    MARKER = '# Type your query above this line.\n'

    # Populate the editor buffer with the partial sql (if available) and a
    # placeholder comment.
    query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
                       filename=filename, extension='.sql')

    if filename:
        try:
            with open(filename) as f:
                query = f.read()
        except IOError:
            message = 'Error reading file: %s.' % filename

    if query is not None:
        query = query.split(MARKER, 1)[0].rstrip('\n')
    else:
        # Don't return None for the caller to deal with.
        # Empty string is ok.
        query = sql

    return (query, message)


@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
def execute_favorite_query(cur, arg, **_):
    """Returns (title, rows, headers, status)"""
    if arg == '':
        for result in list_favorite_queries():
            yield result

    """Parse out favorite name and optional substitution parameters"""
    name, _, arg_str = arg.partition(' ')
    args = shlex.split(arg_str)

    query = FavoriteQueries.instance.get(name)
    if query is None:
        message = "No favorite query: %s" % (name)
        yield (None, None, None, message)
    else:
        query, arg_error = subst_favorite_query_args(query, args)
        if arg_error:
            yield (None, None, None, arg_error)
        else:
            for sql in sqlparse.split(query):
                sql = sql.rstrip(';')
                title = '> %s' % (sql)
                cur.execute(sql)
                if cur.description:
                    headers = [x[0] for x in cur.description]
                    yield (title, cur, headers, None)
                else:
                    yield (title, None, None, None)

def list_favorite_queries():
    """List of all favorite queries.
    Returns (title, rows, headers, status)"""

    headers = ["Name", "Query"]
    rows = [(r, FavoriteQueries.instance.get(r))
            for r in FavoriteQueries.instance.list()]

    if not rows:
        status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
    else:
        status = ''
    return [('', rows, headers, status)]


def subst_favorite_query_args(query, args):
    """replace positional parameters ($1...$N) in query."""
    for idx, val in enumerate(args):
        subst_var = '$' + str(idx + 1)
        if subst_var not in query:
            return [None, 'query does not have substitution parameter ' + subst_var + ':\n  ' + query]

        query = query.replace(subst_var, val)

    match = re.search('\\$\d+', query)
    if match:
        return[None, 'missing substitution for ' + match.group(0) + ' in query:\n  ' + query]

    return [query, None]

@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
def save_favorite_query(arg, **_):
    """Save a new favorite query.
    Returns (title, rows, headers, status)"""

    usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
    if not arg:
        return [(None, None, None, usage)]

    name, _, query = arg.partition(' ')

    # If either name or query is missing then print the usage and complain.
    if (not name) or (not query):
        return [(None, None, None,
            usage + 'Err: Both name and query are required.')]

    FavoriteQueries.instance.save(name, query)
    return [(None, None, None, "Saved.")]


@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
def delete_favorite_query(arg, **_):
    """Delete an existing favorite query."""
    usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
    if not arg:
        return [(None, None, None, usage)]

    status = FavoriteQueries.instance.delete(arg)

    return [(None, None, None, status)]


@special_command('system', 'system [command]',
                 'Execute a system shell commmand.')
def execute_system_command(arg, **_):
    """Execute a system shell command."""
    usage = "Syntax: system [command].\n"

    if not arg:
      return [(None, None, None, usage)]

    try:
        command = arg.strip()
        if command.startswith('cd'):
            ok, error_message = handle_cd_command(arg)
            if not ok:
                return [(None, None, None, error_message)]
            return [(None, None, None, '')]

        args = arg.split(' ')
        process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = process.communicate()
        response = output if not error else error

        # Python 3 returns bytes. This needs to be decoded to a string.
        if isinstance(response, bytes):
            encoding = locale.getpreferredencoding(False)
            response = response.decode(encoding)

        return [(None, None, None, response)]
    except OSError as e:
        return [(None, None, None, 'OSError: %s' % e.strerror)]


def parseargfile(arg):
    if arg.startswith('-o '):
        mode = "w"
        filename = arg[3:]
    else:
        mode = 'a'
        filename = arg

    if not filename:
        raise TypeError('You must provide a filename.')

    return {'file': os.path.expanduser(filename), 'mode': mode}


@special_command('tee', 'tee [-o] filename',
                 'Append all results to an output file (overwrite using -o).')
def set_tee(arg, **_):
    global tee_file

    try:
        tee_file = open(**parseargfile(arg))
    except (IOError, OSError) as e:
        raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))

    return [(None, None, None, "")]

@export
def close_tee():
    global tee_file
    if tee_file:
        tee_file.close()
        tee_file = None


@special_command('notee', 'notee', 'Stop writing results to an output file.')
def no_tee(arg, **_):
    close_tee()
    return [(None, None, None, "")]

@export
def write_tee(output):
    global tee_file
    if tee_file:
        click.echo(output, file=tee_file, nl=False)
        click.echo(u'\n', file=tee_file, nl=False)
        tee_file.flush()


@special_command('\\once', '\\o [-o] filename',
                 'Append next result to an output file (overwrite using -o).',
                 aliases=('\\o', ))
def set_once(arg, **_):
    global once_file, written_to_once_file

    once_file = parseargfile(arg)
    written_to_once_file = False

    return [(None, None, None, "")]


@export
def write_once(output):
    global once_file, written_to_once_file
    if output and once_file:
        try:
            f = open(**once_file)
        except (IOError, OSError) as e:
            once_file = None
            raise OSError("Cannot write to file '{}': {}".format(
                e.filename, e.strerror))
        with f:
            click.echo(output, file=f, nl=False)
            click.echo(u"\n", file=f, nl=False)
        written_to_once_file = True


@export
def unset_once_if_written():
    """Unset the once file, if it has been written to."""
    global once_file
    if written_to_once_file:
        once_file = None


@special_command(
    'watch',
    'watch [seconds] [-c] query',
    'Executes the query every [seconds] seconds (by default 5).'
)
def watch_query(arg, **kwargs):
    usage = """Syntax: watch [seconds] [-c] query.
    * seconds: The interval at the query will be repeated, in seconds.
               By default 5.
    * -c: Clears the screen between every iteration.
"""
    if not arg:
        yield (None, None, None, usage)
        return
    seconds = 5
    clear_screen = False
    statement = None
    while statement is None:
        arg = arg.strip()
        if not arg:
            # Oops, we parsed all the arguments without finding a statement
            yield (None, None, None, usage)
            return
        (current_arg, _, arg) = arg.partition(' ')
        try:
            seconds = float(current_arg)
            continue
        except ValueError:
            pass
        if current_arg == '-c':
            clear_screen = True
            continue
        statement = '{0!s} {1!s}'.format(current_arg, arg)
    destructive_prompt = confirm_destructive_query(statement)
    if destructive_prompt is False:
        click.secho("Wise choice!")
        return
    elif destructive_prompt is True:
        click.secho("Your call!")
    cur = kwargs['cur']
    sql_list = [
        (sql.rstrip(';'), "> {0!s}".format(sql))
        for sql in sqlparse.split(statement)
    ]
    old_pager_enabled = is_pager_enabled()
    while True:
        if clear_screen:
            click.clear()
        try:
            # Somewhere in the code the pager its activated after every yield,
            # so we disable it in every iteration
            set_pager_enabled(False)
            for (sql, title) in sql_list:
                cur.execute(sql)
                if cur.description:
                    headers = [x[0] for x in cur.description]
                    yield (title, cur, headers, None)
                else:
                    yield (title, None, None, None)
            sleep(seconds)
        except KeyboardInterrupt:
            # This prints the Ctrl-C character in its own line, which prevents
            # to print a line with the cursor positioned behind the prompt
            click.secho("", nl=True)
            return
        finally:
            set_pager_enabled(old_pager_enabled)


@export
@special_command('delimiter', None, 'Change SQL delimiter.')
def set_delimiter(arg, **_):
    return delimiter_command.set(arg)


@export
def get_current_delimiter():
    return delimiter_command.current


@export
def split_queries(input):
    for query in delimiter_command.queries_iter(input):
        yield query