summaryrefslogtreecommitdiffstats
path: root/mycli/packages/special/iocommands.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/packages/special/iocommands.py')
-rw-r--r--mycli/packages/special/iocommands.py453
1 files changed, 453 insertions, 0 deletions
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
new file mode 100644
index 0000000..11dca8d
--- /dev/null
+++ b/mycli/packages/special/iocommands.py
@@ -0,0 +1,453 @@
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .delimitercommand import DelimiterCommand
+from .utils import handle_cd_command
+from mycli.packages.prompt_utils import confirm_destructive_query
+
+TIMING_ENABLED = False
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = None
+written_to_once_file = False
+delimiter_command = DelimiterCommand()
+
+
+@export
+def set_timing_enabled(val):
+ global TIMING_ENABLED
+ TIMING_ENABLED = val
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+@export
+@special_command('pager', '\\P [command]',
+ 'Set PAGER. Print the query results via PAGER.',
+ arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
+def set_pager(arg, **_):
+ if arg:
+ os.environ['PAGER'] = arg
+ msg = 'PAGER set to %s.' % arg
+ set_pager_enabled(True)
+ else:
+ if 'PAGER' in os.environ:
+ msg = 'PAGER set to %s.' % os.environ['PAGER']
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = 'Pager enabled.'
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+@export
+@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
+ arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, 'Pager disabled.')]
+
+@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
+def toggle_timing():
+ global TIMING_ENABLED
+ TIMING_ENABLED = not TIMING_ENABLED
+ message = "Timing is "
+ message += "on." if TIMING_ENABLED else "off."
+ return [(None, None, None, message)]
+
+@export
+def is_timing_enabled():
+ return TIMING_ENABLED
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+_logger = logging.getLogger(__name__)
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith('\\e') or command.strip().startswith('\\e')
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith('\\e'):
+ command, _, filename = sql.partition(' ')
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile('(^\\\e|\\\e$)')
+ while pattern.search(sql):
+ sql = pattern.sub('', sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(' ', 1)[0] if filename else None
+
+ sql = sql or ''
+ MARKER = '# Type your query above this line.\n'
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
+ filename=filename, extension='.sql')
+
+ if filename:
+ try:
+ with open(filename) as f:
+ query = f.read()
+ except IOError:
+ message = 'Error reading file: %s.' % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip('\n')
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
+def execute_favorite_query(cur, arg, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == '':
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(' ')
+ args = shlex.split(arg_str)
+
+ query = FavoriteQueries.instance.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(';')
+ title = '> %s' % (sql)
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, FavoriteQueries.instance.get(r))
+ for r in FavoriteQueries.instance.list()]
+
+ if not rows:
+ status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
+ else:
+ status = ''
+ return [('', rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ subst_var = '$' + str(idx + 1)
+ if subst_var not in query:
+ return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
+
+ query = query.replace(subst_var, val)
+
+ match = re.search('\\$\d+', query)
+ if match:
+ return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
+
+ return [query, None]
+
+@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(' ')
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None,
+ usage + 'Err: Both name and query are required.')]
+
+ FavoriteQueries.instance.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query."""
+ usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = FavoriteQueries.instance.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command('system', 'system [command]',
+ 'Execute a system shell commmand.')
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith('cd'):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, '')]
+
+ args = arg.split(' ')
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, 'OSError: %s' % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith('-o '):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = 'a'
+ filename = arg
+
+ if not filename:
+ raise TypeError('You must provide a filename.')
+
+ return {'file': os.path.expanduser(filename), 'mode': mode}
+
+
+@special_command('tee', 'tee [-o] filename',
+ 'Append all results to an output file (overwrite using -o).')
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command('notee', 'notee', 'Stop writing results to an output file.')
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo(u'\n', file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command('\\once', '\\o [-o] filename',
+ 'Append next result to an output file (overwrite using -o).',
+ aliases=('\\o', ))
+def set_once(arg, **_):
+ global once_file, written_to_once_file
+
+ once_file = parseargfile(arg)
+ written_to_once_file = False
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ try:
+ f = open(**once_file)
+ except (IOError, OSError) as e:
+ once_file = None
+ raise OSError("Cannot write to file '{}': {}".format(
+ e.filename, e.strerror))
+ with f:
+ click.echo(output, file=f, nl=False)
+ click.echo(u"\n", file=f, nl=False)
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file
+ if written_to_once_file:
+ once_file = None
+
+
+@special_command(
+ 'watch',
+ 'watch [seconds] [-c] query',
+ 'Executes the query every [seconds] seconds (by default 5).'
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ return
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ return
+ (current_arg, _, arg) = arg.partition(' ')
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == '-c':
+ clear_screen = True
+ continue
+ statement = '{0!s} {1!s}'.format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ return
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs['cur']
+ sql_list = [
+ (sql.rstrip(';'), "> {0!s}".format(sql))
+ for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager its activated after every yield,
+ # so we disable it in every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character in its own line, which prevents
+ # to print a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ return
+ finally:
+ set_pager_enabled(old_pager_enabled)
+
+
+@export
+@special_command('delimiter', None, 'Change SQL delimiter.')
+def set_delimiter(arg, **_):
+ return delimiter_command.set(arg)
+
+
+@export
+def get_current_delimiter():
+ return delimiter_command.current
+
+
+@export
+def split_queries(input):
+ for query in delimiter_command.queries_iter(input):
+ yield query