author    Daniel Baumann <daniel.baumann@progress-linux.org>  2021-02-08 11:28:14 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2021-02-08 11:28:14 +0000
commit    b678a621c57a6d3fdfac14bdbbef0ed743ab1742 (patch)
tree      5481c14ce75dfda9c55721de033992b45ab0e1dc /mycli/packages
parent    Initial commit. (diff)
download  mycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.tar.xz
          mycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.zip
Adding upstream version 1.22.2. (upstream/1.22.2)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'mycli/packages')
-rw-r--r--  mycli/packages/__init__.py                      0
-rw-r--r--  mycli/packages/completion_engine.py           295
-rw-r--r--  mycli/packages/filepaths.py                   106
-rw-r--r--  mycli/packages/paramiko_stub/__init__.py       28
-rw-r--r--  mycli/packages/parseutils.py                  267
-rw-r--r--  mycli/packages/prompt_utils.py                 54
-rw-r--r--  mycli/packages/special/__init__.py             10
-rw-r--r--  mycli/packages/special/dbcommands.py          157
-rw-r--r--  mycli/packages/special/delimitercommand.py     80
-rw-r--r--  mycli/packages/special/favoritequeries.py      63
-rw-r--r--  mycli/packages/special/iocommands.py          453
-rw-r--r--  mycli/packages/special/main.py                118
-rw-r--r--  mycli/packages/special/utils.py                46
-rw-r--r--  mycli/packages/tabular_output/__init__.py       0
-rw-r--r--  mycli/packages/tabular_output/sql_format.py    63
15 files changed, 1740 insertions, 0 deletions
diff --git a/mycli/packages/__init__.py b/mycli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/__init__.py
diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py
new file mode 100644
index 0000000..2b19c32
--- /dev/null
+++ b/mycli/packages/completion_engine.py
@@ -0,0 +1,295 @@
+import os
+import sys
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from sqlparse.compat import text_type
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ word_before_cursor = last_word(text_before_cursor,
+ include='many_punctuations')
+
+ identifier = None
+
+ # This try/except should be removed once sqlparse has been fixed
+ try:
+ # If we've partially typed a word then word_before_cursor won't be an empty
+ # string. In that case we want to remove the partially typed string before
+ # sending it to the sqlparser. Otherwise the last token will always be the
+ # partially typed string which renders the smart completion useless because
+ # it will always return the list of keywords as completion.
+ if word_before_cursor:
+ if word_before_cursor.endswith(
+ '(') or word_before_cursor.startswith('\\'):
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ parsed = sqlparse.parse(
+ text_before_cursor[:-len(word_before_cursor)])
+
+ # word_before_cursor may include a schema qualification, like
+ # "schema_name.partial_name" or "schema_name.", so parse it
+ # separately
+ p = sqlparse.parse(word_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[0], Identifier):
+ identifier = p.tokens[0]
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+ except (TypeError, AttributeError):
+ return [{'type': 'keyword'}]
+
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds the
+ # current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(text_type(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ statement = None
+
+ # Check for special commands and handle those separately
+ if statement:
+ # Be careful here because trivial whitespace is parsed as a statement,
+ # but the statement won't have a first token
+ tok1 = statement.token_first()
+ if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
+ return suggest_special(text_before_cursor)
+
+ last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''
+
+ return suggest_based_on_last_token(last_token, text_before_cursor,
+ full_text, identifier)
+
+
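For illustration, a rough sketch of what suggest_type is expected to return for a partly typed query (assuming sqlparse tokenizes as described above; the exact dicts depend on the sqlparse version and are not verified here):

    >>> suggest_type('SELECT  FROM users u', 'SELECT ')
    [{'type': 'column', 'tables': [(None, 'users', 'u')]},
     {'type': 'function', 'schema': []},
     {'type': 'alias', 'aliases': ['u']},
     {'type': 'keyword'}]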
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return [{'type': 'special'}]
+
+ if cmd in ('\\u', '\\r'):
+ return [{'type': 'database'}]
+
+ if cmd in ('\\T'):
+ return [{'type': 'table_format'}]
+
+ if cmd in ['\\f', '\\fs', '\\fd']:
+ return [{'type': 'favoritequery'}]
+
+ if cmd in ['\\dt', '\\dt+']:
+ return [
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'},
+ ]
+ elif cmd in ['\\.', 'source']:
+ return [{'type': 'file_name'}]
+
+ return [{'type': 'keyword'}, {'type': 'special'}]
+
+
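Likewise, a hedged sketch of how suggest_special maps client commands to completion types:

    >>> suggest_special(r'\u ')       # complete a database name for \u
    [{'type': 'database'}]
    >>> suggest_special(r'\f')        # cursor still on the command itself
    [{'type': 'special'}]
    >>> suggest_special('source ')    # complete a file path for source
    [{'type': 'file_name'}]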
+def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
+ if isinstance(token, str):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ return suggest_based_on_last_token(prev_keyword, text_before_cursor,
+ full_text, identifier)
+ else:
+ token_v = token.value.lower()
+
+ is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])
+
+ if not token:
+ return [{'type': 'keyword'}, {'type': 'special'}]
+ elif token_v.endswith('('):
+ p = sqlparse.parse(text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token('where',
+ text_before_cursor, full_text, identifier)
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ idx, prev_tok = where.token_prev(len(where.tokens) - 1)
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == 'exists':
+ return [{'type': 'keyword'}]
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ idx, prev_tok = p.token_prev(len(p.tokens) - 1)
+ if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = extract_tables(full_text)
+
+ # suggest columns that are present in more than one table
+ return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
+ elif p.token_first().value.lower() == 'select':
+ # If the lparen is preceded by a space, chances are we're about to
+ # do a sub-select.
+ if last_word(text_before_cursor,
+ 'all_punctuations').startswith('('):
+ return [{'type': 'keyword'}]
+ elif p.token_first().value.lower() == 'show':
+ return [{'type': 'show'}]
+
+ # We're probably in a function argument list
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
+ elif token_v in ('set', 'order by', 'distinct'):
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
+ elif token_v == 'as':
+ # Don't suggest anything for an alias
+ return []
+ elif token_v in ('show'):
+ return [{'type': 'show'}]
+ elif token_v in ('to',):
+ p = sqlparse.parse(text_before_cursor)[0]
+ if p.token_first().value.lower() == 'change':
+ return [{'type': 'change'}]
+ else:
+ return [{'type': 'user'}]
+ elif token_v in ('user', 'for'):
+ return [{'type': 'user'}]
+ elif token_v in ('select', 'where', 'having'):
+ # Check for a table alias or schema qualification
+ parent = (identifier and identifier.get_parent_name()) or []
+
+ tables = extract_tables(full_text)
+ if parent:
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'table', 'schema': parent},
+ {'type': 'view', 'schema': parent},
+ {'type': 'function', 'schema': parent}]
+ else:
+ aliases = [alias or table for (schema, table, alias) in tables]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'function', 'schema': []},
+ {'type': 'alias', 'aliases': aliases},
+ {'type': 'keyword'}]
+ elif (token_v.endswith('join') and token.is_keyword) or (token_v in
+ ('copy', 'from', 'update', 'into', 'describe', 'truncate',
+ 'desc', 'explain')):
+ schema = (identifier and identifier.get_parent_name()) or []
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = [{'type': 'table', 'schema': schema}]
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, {'type': 'schema'})
+
+ # Only tables can be TRUNCATED, otherwise suggest views
+ if token_v != 'truncate':
+ suggest.append({'type': 'view', 'schema': schema})
+
+ return suggest
+
+ elif token_v in ('table', 'view', 'function'):
+ # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablename>'
+ rel_type = token_v
+ schema = (identifier and identifier.get_parent_name()) or []
+ if schema:
+ return [{'type': rel_type, 'schema': schema}]
+ else:
+ return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
+ elif token_v == 'on':
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+ parent = (identifier and identifier.get_parent_name()) or []
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'table', 'schema': parent},
+ {'type': 'view', 'schema': parent},
+ {'type': 'function', 'schema': parent}]
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = [alias or table for (schema, table, alias) in tables]
+ suggest = [{'type': 'alias', 'aliases': aliases}]
+
+ # The lists of 'aliases' could be empty if we're trying to complete
+ # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
+ # In that case we just suggest all tables.
+ if not aliases:
+ suggest.append({'type': 'table', 'schema': parent})
+ return suggest
+
+ elif token_v in ('use', 'database', 'template', 'connect'):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return [{'type': 'database'}]
+ elif token_v == 'tableformat':
+ return [{'type': 'table_format'}]
+ elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ if prev_keyword:
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier)
+ else:
+ return []
+ else:
+ return [{'type': 'keyword'}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (
+ schema and (id == schema + '.' + table))
diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py
new file mode 100644
index 0000000..79fe26d
--- /dev/null
+++ b/mycli/packages/filepaths.py
@@ -0,0 +1,106 @@
+import os
+import platform
+
+
+if os.name == "posix":
+ if platform.system() == "Darwin":
+ DEFAULT_SOCKET_DIRS = ("/tmp",)
+ else:
+ DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib")
+else:
+ DEFAULT_SOCKET_DIRS = ()
+
+
+def list_path(root_dir):
+ """List directory if exists.
+
+ :param root_dir: str
+ :return: list
+
+ """
+ res = []
+ if os.path.isdir(root_dir):
+ for name in os.listdir(root_dir):
+ res.append(name)
+ return res
+
+
+def complete_path(curr_dir, last_dir):
+ """Return the path to complete that matches the last entered component.
+
+ If the last entered component is ~, the expanded path would not
+ match, so return all of the available paths.
+
+ :param curr_dir: str
+ :param last_dir: str
+ :return: str
+
+ """
+ if not last_dir or curr_dir.startswith(last_dir):
+ return curr_dir
+ elif last_dir == '~':
+ return os.path.join(last_dir, curr_dir)
+
+
+def parse_path(root_dir):
+ """Split path into head and last component for the completer.
+
+ Also return position where last component starts.
+
+ :param root_dir: str path
+ :return: tuple of (string, string, int)
+
+ """
+ base_dir, last_dir, position = '', '', 0
+ if root_dir:
+ base_dir, last_dir = os.path.split(root_dir)
+ position = -len(last_dir) if last_dir else 0
+ return base_dir, last_dir, position
+
+
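A small, assumed example of how parse_path feeds the path completer (the negative position marks where the last component starts, relative to the cursor):

    >>> parse_path('/usr/local/li')
    ('/usr/local', 'li', -2)
    >>> parse_path('')
    ('', '', 0)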
+def suggest_path(root_dir):
+ """List all files and subdirectories in a directory.
+
+ If the directory is not specified, suggest root directory,
+ user directory, current and parent directory.
+
+ :param root_dir: string: directory to list
+ :return: list
+
+ """
+ if not root_dir:
+ return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]
+
+ if '~' in root_dir:
+ root_dir = os.path.expanduser(root_dir)
+
+ if not os.path.exists(root_dir):
+ root_dir, _ = os.path.split(root_dir)
+
+ return list_path(root_dir)
+
+
+def dir_path_exists(path):
+ """Check if the directory path exists for a given file.
+
+ For example, for a file /home/user/.cache/mycli/log, check if
+ /home/user/.cache/mycli exists.
+
+ :param str path: The file path.
+ :return: Whether or not the directory path exists.
+
+ """
+ return os.path.exists(os.path.dirname(path))
+
+
+def guess_socket_location():
+ """Try to guess the location of the default mysql socket file."""
+ socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS)
+ for directory in socket_dirs:
+ for r, dirs, files in os.walk(directory, topdown=True):
+ for filename in files:
+ name, ext = os.path.splitext(filename)
+ if name.startswith("mysql") and ext in ('.socket', '.sock'):
+ return os.path.join(r, filename)
+ dirs[:] = [d for d in dirs if d.startswith("mysql")]
+ return None
diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py
new file mode 100644
index 0000000..045b00e
--- /dev/null
+++ b/mycli/packages/paramiko_stub/__init__.py
@@ -0,0 +1,28 @@
+"""A module to import instead of paramiko when it is not available (to avoid
+checking for paramiko all over the place).
+
+When paramiko is first invoked, it simply shuts down mycli, telling the
+user that they either have to install paramiko or should not use SSH
+features.
+
+"""
+
+
+class Paramiko:
+ def __getattr__(self, name):
+ import sys
+ from textwrap import dedent
+ print(dedent("""
+ To enable certain SSH features you need to install paramiko:
+
+ pip install paramiko
+
+ It is required for the following configuration options:
+ --list-ssh-config
+ --ssh-config-host
+ --ssh-host
+ """))
+ sys.exit(1)
+
+
+paramiko = Paramiko()
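In practice any attribute access on the stub triggers the message and exits; a minimal sketch (the attribute name is arbitrary):

    from mycli.packages.paramiko_stub import paramiko
    paramiko.SSHClient()   # prints the install hint above, then sys.exit(1)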
diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py
new file mode 100644
index 0000000..e3b383e
--- /dev/null
+++ b/mycli/packages/parseutils.py
@@ -0,0 +1,267 @@
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ 'alphanum_underscore': re.compile(r'(\w+)$'),
+ # This matches everything except spaces, parens, colon, and comma
+ 'many_punctuations': re.compile(r'([^():,\s]+)$'),
+ # This matches everything except spaces, parens, colon, comma, and period
+ 'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
+ # This matches everything except a space.
+ 'all_punctuations': re.compile('([^\s]+)$'),
+ }
+
+def last_word(text, include='alphanum_underscore'):
+ """
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ """
+
+ if not text: # Empty string
+ return ''
+
+ if text[-1].isspace():
+ return ''
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ''
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
+ 'UPDATE', 'CREATE', 'DELETE'):
+ return True
+ return False
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item, stop_at_punctuation):
+ yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # StopIteration. So we need to ignore the keyword if it is
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif item.ttype is Keyword and (
+ not item.value.upper() == 'FROM') and (
+ not item.value.upper().endswith('JOIN')):
+ return
+ else:
+ yield item
+ elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
+ item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if (identifier.ttype is Keyword and
+ identifier.value.upper() == 'FROM'):
+ tbl_prefix_seen = True
+ break
+
+def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ except AttributeError:
+ continue
+ if real_name:
+ yield (schema_name, real_name, identifier.get_alias())
+ elif isinstance(item, Identifier):
+ real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
+
+ if real_name:
+ yield (schema_name, real_name, item.get_alias())
+ else:
+ name = item.get_name()
+ yield (None, name, item.get_alias() or name)
+ elif isinstance(item, Function):
+ yield (None, item.get_name(), item.get_name())
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statment.
+
+ Returns a list of (schema, table, alias) tuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return []
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == 'insert'
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+ return list(extract_table_identifiers(stream))
+
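Assuming sqlparse resolves aliases as usual, extract_tables should behave roughly like this (indicative output, not verified against this exact sqlparse version):

    >>> extract_tables('select a, b from tbl t')
    [(None, 'tbl', 't')]
    >>> extract_tables('select * from abc a join def d on a.id = d.num')
    [(None, 'abc', 'a'), (None, 'def', 'd')]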
+def find_prev_keyword(sql):
+ """ Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ''
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
+
+ for t in reversed(flattened):
+ if t.value == '(' or (t.is_keyword and (
+ t.value.upper() not in logical_operators)):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index throws an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = ''.join(tok.value for tok in flattened[:idx+1])
+ return t, text
+
+ return None, ''
+
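An assumed illustration of the return value (the first element is a sqlparse token, shown here via its value):

    >>> tok, text = find_prev_keyword('select * from foo where bar > 5 and ')
    >>> tok.value, text
    ('where', 'select * from foo where')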
+
+def query_starts_with(query, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+ """Check if any queries start with any item from *prefixes*."""
+ for query in sqlparse.split(queries):
+ if query and query_starts_with(query, prefixes) is True:
+ return True
+ return False
+
+
+def query_has_where_clause(query):
+ """Check if the query contains a where-clause."""
+ return any(
+ isinstance(token, sqlparse.sql.Where)
+ for token_list in sqlparse.parse(query)
+ for token in token_list
+ )
+
+
+def is_destructive(queries):
+ """Returns if any of the queries in *queries* is destructive."""
+ keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
+ for query in sqlparse.split(queries):
+ if query:
+ if query_starts_with(query, keywords) is True:
+ return True
+ elif query_starts_with(
+ query, ['update']
+ ) is True and not query_has_where_clause(query):
+ return True
+
+ return False
+
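A hedged sketch of the destructiveness checks above (assuming the statements split cleanly):

    >>> is_destructive('drop table users; select 1')
    True
    >>> is_destructive('update users set active = 0')                 # UPDATE without WHERE
    True
    >>> is_destructive('update users set active = 0 where id = 1')
    False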
+
+def is_open_quote(sql):
+ """Returns true if the query contains an unclosed quote."""
+
+ # parsed can contain one or more semi-colon separated commands
+ parsed = sqlparse.parse(sql)
+ return any(_parsed_is_open_quote(p) for p in parsed)
+
+
+if __name__ == '__main__':
+ sql = 'select * from (select t. from tabl t'
+ print (extract_tables(sql))
+
+
+def is_dropping_database(queries, dbname):
+ """Determine if the query is dropping a specific database."""
+ result = False
+ if dbname is None:
+ return False
+
+ def normalize_db_name(db):
+ return db.lower().strip('`"')
+
+ dbname = normalize_db_name(dbname)
+
+ for query in sqlparse.parse(queries):
+ keywords = [t for t in query.tokens if t.is_keyword]
+ if len(keywords) < 2:
+ continue
+ if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in (
+ "database",
+ "schema",
+ ):
+ database_token = next(
+ (t for t in query.tokens if isinstance(t, Identifier)), None
+ )
+ if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
+ result = keywords[0].normalized == "DROP"
+ else:
+ return result
diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py
new file mode 100644
index 0000000..fb1e431
--- /dev/null
+++ b/mycli/packages/prompt_utils.py
@@ -0,0 +1,54 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
+class ConfirmBoolParamType(click.ParamType):
+ name = 'confirmation'
+
+ def convert(self, value, param, ctx):
+ if isinstance(value, bool):
+ return bool(value)
+ value = value.lower()
+ if value in ('yes', 'y'):
+ return True
+ elif value in ('no', 'n'):
+ return False
+ self.fail('%s is not a valid boolean' % value, param, ctx)
+
+ def __repr__(self):
+ return 'BOOL'
+
+
+BOOLEAN_TYPE = ConfirmBoolParamType()
+
+
+def confirm_destructive_query(queries):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ prompt_text = ("You're about to run a destructive command.\n"
+ "Do you want to proceed? (y/n)")
+ if is_destructive(queries) and sys.stdin.isatty():
+ return prompt(prompt_text, type=BOOLEAN_TYPE)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py
new file mode 100644
index 0000000..92bcca6
--- /dev/null
+++ b/mycli/packages/special/__init__.py
@@ -0,0 +1,10 @@
+__all__ = []
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+from . import dbcommands
+from . import iocommands
diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py
new file mode 100644
index 0000000..ed90e4c
--- /dev/null
+++ b/mycli/packages/special/dbcommands.py
@@ -0,0 +1,157 @@
+import logging
+import os
+import platform
+from mycli import __version__
+from mycli.packages.special import iocommands
+from mycli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY
+from pymysql import ProgrammingError
+
+log = logging.getLogger(__name__)
+
+
+@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
+ arg_type=PARSED_QUERY, case_sensitive=True)
+def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ query = 'SHOW FIELDS FROM {0}'.format(arg)
+ else:
+ query = 'SHOW TABLES'
+ log.debug(query)
+ cur.execute(query)
+ tables = cur.fetchall()
+ status = ''
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, '')]
+
+ if verbose and arg:
+ query = 'SHOW CREATE TABLE {0}'.format(arg)
+ log.debug(query)
+ cur.execute(query)
+ status = cur.fetchone()[1]
+
+ return [(None, tables, headers, status)]
+
+@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
+def list_databases(cur, **_):
+ query = 'SHOW DATABASES'
+ log.debug(query)
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, '')]
+ else:
+ return [(None, None, None, '')]
+
+@special_command('status', '\\s', 'Get status information from the server.',
+ arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
+def status(cur, **_):
+ query = 'SHOW GLOBAL STATUS;'
+ log.debug(query)
+ try:
+ cur.execute(query)
+ except ProgrammingError:
+ # Fall back in case the query fails, as it does with MySQL 4
+ query = 'SHOW STATUS;'
+ log.debug(query)
+ cur.execute(query)
+ status = dict(cur.fetchall())
+
+ query = 'SHOW GLOBAL VARIABLES;'
+ log.debug(query)
+ cur.execute(query)
+ variables = dict(cur.fetchall())
+
+ # Prepare in case keys are bytes, as with Python 3 and MySQL 4
+ if (isinstance(list(variables)[0], bytes) and
+ isinstance(list(status)[0], bytes)):
+ variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
+ in variables.items()}
+ status = {k.decode('utf-8'): v.decode('utf-8') for k, v
+ in status.items()}
+
+ # Create output buffers.
+ title = []
+ output = []
+ footer = []
+
+ title.append('--------------')
+
+ # Output the mycli client information.
+ implementation = platform.python_implementation()
+ version = platform.python_version()
+ client_info = []
+ client_info.append('mycli {0},'.format(__version__))
+ client_info.append('running on {0} {1}'.format(implementation, version))
+ title.append(' '.join(client_info) + '\n')
+
+ # Build the output that will be displayed as a table.
+ output.append(('Connection id:', cur.connection.thread_id()))
+
+ query = 'SELECT DATABASE(), USER();'
+ log.debug(query)
+ cur.execute(query)
+ db, user = cur.fetchone()
+ if db is None:
+ db = ''
+
+ output.append(('Current database:', db))
+ output.append(('Current user:', user))
+
+ if iocommands.is_pager_enabled():
+ if 'PAGER' in os.environ:
+ pager = os.environ['PAGER']
+ else:
+ pager = 'System default'
+ else:
+ pager = 'stdout'
+ output.append(('Current pager:', pager))
+
+ output.append(('Server version:', '{0} {1}'.format(
+ variables['version'], variables['version_comment'])))
+ output.append(('Protocol version:', variables['protocol_version']))
+
+ if 'unix' in cur.connection.host_info.lower():
+ host_info = cur.connection.host_info
+ else:
+ host_info = '{0} via TCP/IP'.format(cur.connection.host)
+
+ output.append(('Connection:', host_info))
+
+ query = ('SELECT @@character_set_server, @@character_set_database, '
+ '@@character_set_client, @@character_set_connection LIMIT 1;')
+ log.debug(query)
+ cur.execute(query)
+ charset = cur.fetchone()
+ output.append(('Server characterset:', charset[0]))
+ output.append(('Db characterset:', charset[1]))
+ output.append(('Client characterset:', charset[2]))
+ output.append(('Conn. characterset:', charset[3]))
+
+ if 'TCP/IP' in host_info:
+ output.append(('TCP port:', cur.connection.port))
+ else:
+ output.append(('UNIX socket:', variables['socket']))
+
+ output.append(('Uptime:', format_uptime(status['Uptime'])))
+
+ # Print the current server statistics.
+ stats = []
+ stats.append('Connections: {0}'.format(status['Threads_connected']))
+ if 'Queries' in status:
+ stats.append('Queries: {0}'.format(status['Queries']))
+ stats.append('Slow queries: {0}'.format(status['Slow_queries']))
+ stats.append('Opens: {0}'.format(status['Opened_tables']))
+ stats.append('Flush tables: {0}'.format(status['Flush_commands']))
+ stats.append('Open tables: {0}'.format(status['Open_tables']))
+ if 'Queries' in status:
+ queries_per_second = int(status['Queries']) / int(status['Uptime'])
+ stats.append('Queries per second avg: {:.3f}'.format(
+ queries_per_second))
+ stats = ' '.join(stats)
+ footer.append('\n' + stats)
+
+ footer.append('--------------')
+ return [('\n'.join(title), output, '', '\n'.join(footer))]
diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py
new file mode 100644
index 0000000..994b134
--- /dev/null
+++ b/mycli/packages/special/delimitercommand.py
@@ -0,0 +1,80 @@
+import re
+import sqlparse
+
+
+class DelimiterCommand(object):
+ def __init__(self):
+ self._delimiter = ';'
+
+ def _split(self, sql):
+ """Temporary workaround until sqlparse.split() learns about custom
+ delimiters."""
+
+ placeholder = "\ufffc" # unicode object replacement character
+
+ if self._delimiter == ';':
+ return sqlparse.split(sql)
+
+ # We must find a string that original sql does not contain.
+ # Most likely, our placeholder is enough, but if not, keep looking
+ while placeholder in sql:
+ placeholder += placeholder[0]
+ sql = sql.replace(';', placeholder)
+ sql = sql.replace(self._delimiter, ';')
+
+ split = sqlparse.split(sql)
+
+ return [
+ stmt.replace(';', self._delimiter).replace(placeholder, ';')
+ for stmt in split
+ ]
+
+ def queries_iter(self, input):
+ """Iterate over queries in the input string."""
+
+ queries = self._split(input)
+ while queries:
+ for sql in queries:
+ delimiter = self._delimiter
+ sql = queries.pop(0)
+ if sql.endswith(delimiter):
+ trailing_delimiter = True
+ sql = sql.strip(delimiter)
+ else:
+ trailing_delimiter = False
+
+ yield sql
+
+ # if the delimiter was changed by the last command,
+ # re-split everything, and if we previously stripped
+ # the delimiter, append it to the end
+ if self._delimiter != delimiter:
+ combined_statement = ' '.join([sql] + queries)
+ if trailing_delimiter:
+ combined_statement += delimiter
+ queries = self._split(combined_statement)[1:]
+
+ def set(self, arg, **_):
+ """Change delimiter.
+
+ Since `arg` is everything that follows the DELIMITER token
+ after sqlparse (it may include other statements separated by
+ the new delimiter), we want to set the delimiter to the first
+ word of it.
+
+ """
+ match = arg and re.search(r'[^\s]+', arg)
+ if not match:
+ message = 'Missing required argument, delimiter'
+ return [(None, None, None, message)]
+
+ delimiter = match.group()
+ if delimiter.lower() == 'delimiter':
+ return [(None, None, None, 'Invalid delimiter "delimiter"')]
+
+ self._delimiter = delimiter
+ return [(None, None, None, "Changed delimiter to {}".format(delimiter))]
+
+ @property
+ def current(self):
+ return self._delimiter
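A rough usage sketch of the delimiter handling (assuming sqlparse splits on ';' as usual):

    >>> dc = DelimiterCommand()
    >>> dc.set('$$')
    [(None, None, None, 'Changed delimiter to $$')]
    >>> list(dc.queries_iter('select 1$$select 2$$'))
    ['select 1', 'select 2']
    >>> dc.current
    '$$'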
diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..0b91400
--- /dev/null
+++ b/mycli/packages/special/favoritequeries.py
@@ -0,0 +1,63 @@
+class FavoriteQueries(object):
+
+ section_name = 'favorite_queries'
+
+ usage = '''
+Favorite Queries are a way to save frequently used queries
+with a short name.
+Examples:
+
+ # Save a new favorite query.
+ > \\fs simple select * from abc where a is not Null;
+
+ # List all favorite queries.
+ > \\f
+ ╒════════╤═══════════════════════════════════════╕
+ │ Name │ Query │
+ ╞════════╪═══════════════════════════════════════╡
+ │ simple │ SELECT * FROM abc where a is not NULL │
+ ╘════════╧═══════════════════════════════════════╛
+
+ # Run a favorite query.
+ > \\f simple
+ ╒════════╤════════╕
+ │ a │ b │
+ ╞════════╪════════╡
+ │ 日本語 │ 日本語 │
+ ╘════════╧════════╛
+
+ # Delete a favorite query.
+ > \\fd simple
+ simple: Deleted
+'''
+
+ # Class-level variable, for convenience to use as a singleton.
+ instance = None
+
+ def __init__(self, config):
+ self.config = config
+
+ @classmethod
+ def from_config(cls, config):
+ return FavoriteQueries(config)
+
+ def list(self):
+ return self.config.get(self.section_name, [])
+
+ def get(self, name):
+ return self.config.get(self.section_name, {}).get(name, None)
+
+ def save(self, name, query):
+ self.config.encoding = 'utf-8'
+ if self.section_name not in self.config:
+ self.config[self.section_name] = {}
+ self.config[self.section_name][name] = query
+ self.config.write()
+
+ def delete(self, name):
+ try:
+ del self.config[self.section_name][name]
+ except KeyError:
+ return '%s: Not Found.' % name
+ self.config.write()
+ return '%s: Deleted' % name
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
new file mode 100644
index 0000000..11dca8d
--- /dev/null
+++ b/mycli/packages/special/iocommands.py
@@ -0,0 +1,453 @@
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .delimitercommand import DelimiterCommand
+from .utils import handle_cd_command
+from mycli.packages.prompt_utils import confirm_destructive_query
+
+TIMING_ENABLED = False
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = None
+written_to_once_file = False
+delimiter_command = DelimiterCommand()
+
+
+@export
+def set_timing_enabled(val):
+ global TIMING_ENABLED
+ TIMING_ENABLED = val
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+@export
+@special_command('pager', '\\P [command]',
+ 'Set PAGER. Print the query results via PAGER.',
+ arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
+def set_pager(arg, **_):
+ if arg:
+ os.environ['PAGER'] = arg
+ msg = 'PAGER set to %s.' % arg
+ set_pager_enabled(True)
+ else:
+ if 'PAGER' in os.environ:
+ msg = 'PAGER set to %s.' % os.environ['PAGER']
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = 'Pager enabled.'
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+@export
+@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
+ arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, 'Pager disabled.')]
+
+@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
+def toggle_timing():
+ global TIMING_ENABLED
+ TIMING_ENABLED = not TIMING_ENABLED
+ message = "Timing is "
+ message += "on." if TIMING_ENABLED else "off."
+ return [(None, None, None, message)]
+
+@export
+def is_timing_enabled():
+ return TIMING_ENABLED
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+_logger = logging.getLogger(__name__)
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith('\\e') or command.strip().startswith('\\e')
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith('\\e'):
+ command, _, filename = sql.partition(' ')
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it would also strip the "e" at the end of the sql!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile('(^\\\e|\\\e$)')
+ while pattern.search(sql):
+ sql = pattern.sub('', sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(' ', 1)[0] if filename else None
+
+ sql = sql or ''
+ MARKER = '# Type your query above this line.\n'
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
+ filename=filename, extension='.sql')
+
+ if filename:
+ try:
+ with open(filename) as f:
+ query = f.read()
+ except IOError:
+ message = 'Error reading file: %s.' % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip('\n')
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
+def execute_favorite_query(cur, arg, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == '':
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(' ')
+ args = shlex.split(arg_str)
+
+ query = FavoriteQueries.instance.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(';')
+ title = '> %s' % (sql)
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, FavoriteQueries.instance.get(r))
+ for r in FavoriteQueries.instance.list()]
+
+ if not rows:
+ status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
+ else:
+ status = ''
+ return [('', rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ subst_var = '$' + str(idx + 1)
+ if subst_var not in query:
+ return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
+
+ query = query.replace(subst_var, val)
+
+ match = re.search('\\$\d+', query)
+ if match:
+ return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
+
+ return [query, None]
+
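A sketch of the positional-parameter substitution (illustrative; on a missing or extra parameter the first element comes back as None and the second carries the error message):

    >>> subst_favorite_query_args('select * from users where id = $1', ['42'])
    ['select * from users where id = 42', None]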
+@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(' ')
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None,
+ usage + 'Err: Both name and query are required.')]
+
+ FavoriteQueries.instance.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query."""
+ usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = FavoriteQueries.instance.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command('system', 'system [command]',
+ 'Execute a system shell command.')
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith('cd'):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, '')]
+
+ args = arg.split(' ')
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, 'OSError: %s' % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith('-o '):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = 'a'
+ filename = arg
+
+ if not filename:
+ raise TypeError('You must provide a filename.')
+
+ return {'file': os.path.expanduser(filename), 'mode': mode}
+
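For reference, an assumed illustration of how the tee/once argument is parsed:

    >>> parseargfile('-o results.txt')
    {'file': 'results.txt', 'mode': 'w'}
    >>> parseargfile('results.txt')
    {'file': 'results.txt', 'mode': 'a'}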
+
+@special_command('tee', 'tee [-o] filename',
+ 'Append all results to an output file (overwrite using -o).')
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command('notee', 'notee', 'Stop writing results to an output file.')
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo(u'\n', file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command('\\once', '\\o [-o] filename',
+ 'Append next result to an output file (overwrite using -o).',
+ aliases=('\\o', ))
+def set_once(arg, **_):
+ global once_file, written_to_once_file
+
+ once_file = parseargfile(arg)
+ written_to_once_file = False
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ try:
+ f = open(**once_file)
+ except (IOError, OSError) as e:
+ once_file = None
+ raise OSError("Cannot write to file '{}': {}".format(
+ e.filename, e.strerror))
+ with f:
+ click.echo(output, file=f, nl=False)
+ click.echo(u"\n", file=f, nl=False)
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file
+ if written_to_once_file:
+ once_file = None
+
+
+@special_command(
+ 'watch',
+ 'watch [seconds] [-c] query',
+ 'Executes the query every [seconds] seconds (by default 5).'
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at which the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ return
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ return
+ (current_arg, _, arg) = arg.partition(' ')
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == '-c':
+ clear_screen = True
+ continue
+ statement = '{0!s} {1!s}'.format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ return
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs['cur']
+ sql_list = [
+ (sql.rstrip(';'), "> {0!s}".format(sql))
+ for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager is activated after every yield,
+ # so we disable it on every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character on its own line, which prevents
+ # printing a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ return
+ finally:
+ set_pager_enabled(old_pager_enabled)
+
+
+@export
+@special_command('delimiter', None, 'Change SQL delimiter.')
+def set_delimiter(arg, **_):
+ return delimiter_command.set(arg)
+
+
+@export
+def get_current_delimiter():
+ return delimiter_command.current
+
+
+@export
+def split_queries(input):
+ for query in delimiter_command.queries_iter(input):
+ yield query
diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py
new file mode 100644
index 0000000..dddba66
--- /dev/null
+++ b/mycli/packages/special/main.py
@@ -0,0 +1,118 @@
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple('SpecialCommand',
+ ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
+ 'case_sensitive'])
+
+COMMANDS = {}
+
+@export
+class CommandNotFound(Exception):
+ pass
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(' ')
+ verbose = '+' in command
+ command = command.strip().replace('+', '')
+ return (command, verbose, arg.strip())
+
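An indicative example of how a raw input line is decomposed by parse_special_command:

    >>> parse_special_command(r'\dt+ some_table')
    ('\\dt', True, 'some_table')
    >>> parse_special_command('status')
    ('status', False, '')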
+@export
+def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
+ hidden=False, case_sensitive=False, aliases=()):
+ def wrapper(wrapped):
+ register_special_command(wrapped, command, shortcut, description,
+ arg_type, hidden, case_sensitive, aliases)
+ return wrapped
+ return wrapper
+
+@export
+def register_special_command(handler, command, shortcut, description,
+ arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
+ arg_type, hidden, case_sensitive)
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
+ arg_type, case_sensitive=case_sensitive,
+ hidden=True)
+
+@export
+def execute(cur, sql):
+ """Execute a special command and return the results. If the special command
+ is not supported, a CommandNotFound exception is raised.
+ """
+ command, verbose, arg = parse_special_command(sql)
+
+ if (command not in COMMANDS) and (command.lower() not in COMMANDS):
+ raise CommandNotFound
+
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise CommandNotFound('Command not found: %s' % command)
+
+ # "help <SQL KEYWORD> is a special case. We want built-in help, not
+ # mycli help here.
+ if command == 'help' and arg:
+ return show_keyword_help(cur=cur, arg=arg)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
+def show_help(): # All the parameters are ignored.
+ headers = ['Command', 'Shortcut', 'Description']
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.command, value.shortcut, value.description))
+ return [(None, result, headers, None)]
+
+def show_keyword_help(cur, arg):
+ """
+ Call the built-in "show <command>", to display help for an SQL keyword.
+ :param cur: cursor
+ :param arg: string
+ :return: list
+ """
+ keyword = arg.strip('"').strip("'")
+ query = "help '{0}'".format(keyword)
+ log.debug(query)
+ cur.execute(query)
+ if cur.description and cur.rowcount > 0:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, '')]
+ else:
+ return [(None, None, None, 'No help found for {0}.'.format(keyword))]
+
+
+@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
+@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
+def quit(*_args):
+ raise EOFError
+
+
+@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
+ arg_type=NO_QUERY, case_sensitive=True)
+@special_command('\\G', '\\G', 'Display current query results vertically.',
+ arg_type=NO_QUERY, case_sensitive=True)
+def stub():
+ raise NotImplementedError
diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py
new file mode 100644
index 0000000..ef96093
--- /dev/null
+++ b/mycli/packages/special/utils.py
@@ -0,0 +1,46 @@
+import os
+import subprocess
+
+def handle_cd_command(arg):
+ """Handles a `cd` shell command by calling python's os.chdir."""
+ CD_CMD = 'cd'
+ tokens = arg.split(CD_CMD + ' ')
+ directory = tokens[-1] if len(tokens) > 1 else None
+ if not directory:
+ return False, "No folder name was provided."
+ try:
+ os.chdir(directory)
+ subprocess.call(['pwd'])
+ return True, None
+ except OSError as e:
+ return False, e.strerror
+
+def format_uptime(uptime_in_seconds):
+ """Format number of seconds into human-readable string.
+
+ :param uptime_in_seconds: The server uptime in seconds.
+ :returns: A human-readable string representing the uptime.
+
+ >>> uptime = format_uptime('56892')
+ >>> print(uptime)
+ 15 hours 48 min 12 sec
+ """
+
+ m, s = divmod(int(uptime_in_seconds), 60)
+ h, m = divmod(m, 60)
+ d, h = divmod(h, 24)
+
+ uptime_values = []
+
+ for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')):
+ if value == 0 and not uptime_values:
+ # Don't include a value/unit if the unit isn't applicable to
+ # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
+ continue
+ elif value == 1 and unit.endswith('s'):
+ # Remove the "s" if the unit is singular.
+ unit = unit[:-1]
+ uptime_values.append('{0} {1}'.format(value, unit))
+
+ uptime = ' '.join(uptime_values)
+ return uptime
diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/tabular_output/__init__.py
diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py
new file mode 100644
index 0000000..730e633
--- /dev/null
+++ b/mycli/packages/tabular_output/sql_format.py
@@ -0,0 +1,63 @@
+"""Format adapter for sql."""
+
+from cli_helpers.utils import filter_dict_by_key
+from mycli.packages.parseutils import extract_tables
+
+supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
+ 'sql-update-2', )
+
+preprocessors = ()
+
+
+def escape_for_sql_statement(value):
+ if isinstance(value, bytes):
+ return f"X'{value.hex()}'"
+ else:
+ return formatter.mycli.sqlexecute.conn.escape(value)
+
+
+def adapter(data, headers, table_format=None, **kwargs):
+ tables = extract_tables(formatter.query)
+ if len(tables) > 0:
+ table = tables[0]
+ if table[0]:
+ table_name = "{}.{}".format(*table[:2])
+ else:
+ table_name = table[1]
+ else:
+ table_name = "`DUAL`"
+ if table_format == 'sql-insert':
+ h = "`, `".join(headers)
+ yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
+ prefix = " "
+ for d in data:
+ values = ", ".join(escape_for_sql_statement(v)
+ for i, v in enumerate(d))
+ yield "{}({})".format(prefix, values)
+ if prefix == " ":
+ prefix = ", "
+ yield ";"
+ if table_format.startswith('sql-update'):
+ s = table_format.split('-')
+ keys = 1
+ if len(s) > 2:
+ keys = int(s[-1])
+ for d in data:
+ yield "UPDATE {} SET".format(table_name)
+ prefix = " "
+ for i, v in enumerate(d[keys:], keys):
+ yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v))
+ if prefix == " ":
+ prefix = ", "
+ f = "`{}` = {}"
+ where = (f.format(headers[i], escape_for_sql_statement(
+ d[i])) for i in range(keys))
+ yield "WHERE {};".format(" AND ".join(where))
+
+
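Roughly, with table_format='sql-insert', headers ['id', 'name'] and two rows, the generator above should emit lines along these lines (values are escaped through the live connection and the table name is taken from the last executed query; shape only, not verified):

    INSERT INTO users (`id`, `name`) VALUES
      (1, 'foo')
    , (2, 'bar')
    ;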
+def register_new_formatter(TabularOutputFormatter):
+ global formatter
+ formatter = TabularOutputFormatter
+ for sql_format in supported_formats:
+ TabularOutputFormatter.register_new_formatter(
+ sql_format, adapter, preprocessors, {'table_format': sql_format})