diff options
Diffstat (limited to 'mycli/packages/special/iocommands.py')
-rw-r--r-- | mycli/packages/special/iocommands.py | 453 |
1 files changed, 453 insertions, 0 deletions
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py new file mode 100644 index 0000000..11dca8d --- /dev/null +++ b/mycli/packages/special/iocommands.py @@ -0,0 +1,453 @@ +import os +import re +import locale +import logging +import subprocess +import shlex +from io import open +from time import sleep + +import click +import sqlparse + +from . import export +from .main import special_command, NO_QUERY, PARSED_QUERY +from .favoritequeries import FavoriteQueries +from .delimitercommand import DelimiterCommand +from .utils import handle_cd_command +from mycli.packages.prompt_utils import confirm_destructive_query + +TIMING_ENABLED = False +use_expanded_output = False +PAGER_ENABLED = True +tee_file = None +once_file = None +written_to_once_file = False +delimiter_command = DelimiterCommand() + + +@export +def set_timing_enabled(val): + global TIMING_ENABLED + TIMING_ENABLED = val + +@export +def set_pager_enabled(val): + global PAGER_ENABLED + PAGER_ENABLED = val + + +@export +def is_pager_enabled(): + return PAGER_ENABLED + +@export +@special_command('pager', '\\P [command]', + 'Set PAGER. Print the query results via PAGER.', + arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) +def set_pager(arg, **_): + if arg: + os.environ['PAGER'] = arg + msg = 'PAGER set to %s.' % arg + set_pager_enabled(True) + else: + if 'PAGER' in os.environ: + msg = 'PAGER set to %s.' % os.environ['PAGER'] + else: + # This uses click's default per echo_via_pager. + msg = 'Pager enabled.' + set_pager_enabled(True) + + return [(None, None, None, msg)] + +@export +@special_command('nopager', '\\n', 'Disable pager, print to stdout.', + arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True) +def disable_pager(): + set_pager_enabled(False) + return [(None, None, None, 'Pager disabled.')] + +@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True) +def toggle_timing(): + global TIMING_ENABLED + TIMING_ENABLED = not TIMING_ENABLED + message = "Timing is " + message += "on." if TIMING_ENABLED else "off." + return [(None, None, None, message)] + +@export +def is_timing_enabled(): + return TIMING_ENABLED + +@export +def set_expanded_output(val): + global use_expanded_output + use_expanded_output = val + +@export +def is_expanded_output(): + return use_expanded_output + +_logger = logging.getLogger(__name__) + +@export +def editor_command(command): + """ + Is this an external editor command? + :param command: string + """ + # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check + # for both conditions. + return command.strip().endswith('\\e') or command.strip().startswith('\\e') + +@export +def get_filename(sql): + if sql.strip().startswith('\\e'): + command, _, filename = sql.partition(' ') + return filename.strip() or None + + +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\e') is that it strips characters, + # not a substring. So it'll strip "e" in the end of the sql also! + # Ex: "select * from style\e" -> "select * from styl". + pattern = re.compile('(^\\\e|\\\e$)') + while pattern.search(sql): + sql = pattern.sub('', sql) + + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + + message = None + filename = filename.strip().split(' ', 1)[0] if filename else None + + sql = sql or '' + MARKER = '# Type your query above this line.\n' + + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), + filename=filename, extension='.sql') + + if filename: + try: + with open(filename) as f: + query = f.read() + except IOError: + message = 'Error reading file: %s.' % filename + + if query is not None: + query = query.split(MARKER, 1)[0].rstrip('\n') + else: + # Don't return None for the caller to deal with. + # Empty string is ok. + query = sql + + return (query, message) + + +@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) +def execute_favorite_query(cur, arg, **_): + """Returns (title, rows, headers, status)""" + if arg == '': + for result in list_favorite_queries(): + yield result + + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(' ') + args = shlex.split(arg_str) + + query = FavoriteQueries.instance.get(name) + if query is None: + message = "No favorite query: %s" % (name) + yield (None, None, None, message) + else: + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(';') + title = '> %s' % (sql) + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + +def list_favorite_queries(): + """List of all favorite queries. + Returns (title, rows, headers, status)""" + + headers = ["Name", "Query"] + rows = [(r, FavoriteQueries.instance.get(r)) + for r in FavoriteQueries.instance.list()] + + if not rows: + status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage + else: + status = '' + return [('', rows, headers, status)] + + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + subst_var = '$' + str(idx + 1) + if subst_var not in query: + return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + + query = query.replace(subst_var, val) + + match = re.search('\\$\d+', query) + if match: + return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + + return [query, None] + +@special_command('\\fs', '\\fs name query', 'Save a favorite query.') +def save_favorite_query(arg, **_): + """Save a new favorite query. + Returns (title, rows, headers, status)""" + + usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + name, _, query = arg.partition(' ') + + # If either name or query is missing then print the usage and complain. + if (not name) or (not query): + return [(None, None, None, + usage + 'Err: Both name and query are required.')] + + FavoriteQueries.instance.save(name, query) + return [(None, None, None, "Saved.")] + + +@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') +def delete_favorite_query(arg, **_): + """Delete an existing favorite query.""" + usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + status = FavoriteQueries.instance.delete(arg) + + return [(None, None, None, status)] + + +@special_command('system', 'system [command]', + 'Execute a system shell commmand.') +def execute_system_command(arg, **_): + """Execute a system shell command.""" + usage = "Syntax: system [command].\n" + + if not arg: + return [(None, None, None, usage)] + + try: + command = arg.strip() + if command.startswith('cd'): + ok, error_message = handle_cd_command(arg) + if not ok: + return [(None, None, None, error_message)] + return [(None, None, None, '')] + + args = arg.split(' ') + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + response = output if not error else error + + # Python 3 returns bytes. This needs to be decoded to a string. + if isinstance(response, bytes): + encoding = locale.getpreferredencoding(False) + response = response.decode(encoding) + + return [(None, None, None, response)] + except OSError as e: + return [(None, None, None, 'OSError: %s' % e.strerror)] + + +def parseargfile(arg): + if arg.startswith('-o '): + mode = "w" + filename = arg[3:] + else: + mode = 'a' + filename = arg + + if not filename: + raise TypeError('You must provide a filename.') + + return {'file': os.path.expanduser(filename), 'mode': mode} + + +@special_command('tee', 'tee [-o] filename', + 'Append all results to an output file (overwrite using -o).') +def set_tee(arg, **_): + global tee_file + + try: + tee_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + + return [(None, None, None, "")] + +@export +def close_tee(): + global tee_file + if tee_file: + tee_file.close() + tee_file = None + + +@special_command('notee', 'notee', 'Stop writing results to an output file.') +def no_tee(arg, **_): + close_tee() + return [(None, None, None, "")] + +@export +def write_tee(output): + global tee_file + if tee_file: + click.echo(output, file=tee_file, nl=False) + click.echo(u'\n', file=tee_file, nl=False) + tee_file.flush() + + +@special_command('\\once', '\\o [-o] filename', + 'Append next result to an output file (overwrite using -o).', + aliases=('\\o', )) +def set_once(arg, **_): + global once_file, written_to_once_file + + once_file = parseargfile(arg) + written_to_once_file = False + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + try: + f = open(**once_file) + except (IOError, OSError) as e: + once_file = None + raise OSError("Cannot write to file '{}': {}".format( + e.filename, e.strerror)) + with f: + click.echo(output, file=f, nl=False) + click.echo(u"\n", file=f, nl=False) + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file + if written_to_once_file: + once_file = None + + +@special_command( + 'watch', + 'watch [seconds] [-c] query', + 'Executes the query every [seconds] seconds (by default 5).' +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + return + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + return + (current_arg, _, arg) = arg.partition(' ') + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == '-c': + clear_screen = True + continue + statement = '{0!s} {1!s}'.format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + return + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs['cur'] + sql_list = [ + (sql.rstrip(';'), "> {0!s}".format(sql)) + for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + return + finally: + set_pager_enabled(old_pager_enabled) + + +@export +@special_command('delimiter', None, 'Change SQL delimiter.') +def set_delimiter(arg, **_): + return delimiter_command.set(arg) + + +@export +def get_current_delimiter(): + return delimiter_command.current + + +@export +def split_queries(input): + for query in delimiter_command.queries_iter(input): + yield query |