From b678a621c57a6d3fdfac14bdbbef0ed743ab1742 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 8 Feb 2021 12:28:14 +0100 Subject: Adding upstream version 1.22.2. Signed-off-by: Daniel Baumann --- mycli/packages/__init__.py | 0 mycli/packages/completion_engine.py | 295 ++++++++++++++++++ mycli/packages/filepaths.py | 106 +++++++ mycli/packages/paramiko_stub/__init__.py | 28 ++ mycli/packages/parseutils.py | 267 ++++++++++++++++ mycli/packages/prompt_utils.py | 54 ++++ mycli/packages/special/__init__.py | 10 + mycli/packages/special/dbcommands.py | 157 ++++++++++ mycli/packages/special/delimitercommand.py | 80 +++++ mycli/packages/special/favoritequeries.py | 63 ++++ mycli/packages/special/iocommands.py | 453 ++++++++++++++++++++++++++++ mycli/packages/special/main.py | 118 ++++++++ mycli/packages/special/utils.py | 46 +++ mycli/packages/tabular_output/__init__.py | 0 mycli/packages/tabular_output/sql_format.py | 63 ++++ 15 files changed, 1740 insertions(+) create mode 100644 mycli/packages/__init__.py create mode 100644 mycli/packages/completion_engine.py create mode 100644 mycli/packages/filepaths.py create mode 100644 mycli/packages/paramiko_stub/__init__.py create mode 100644 mycli/packages/parseutils.py create mode 100644 mycli/packages/prompt_utils.py create mode 100644 mycli/packages/special/__init__.py create mode 100644 mycli/packages/special/dbcommands.py create mode 100644 mycli/packages/special/delimitercommand.py create mode 100644 mycli/packages/special/favoritequeries.py create mode 100644 mycli/packages/special/iocommands.py create mode 100644 mycli/packages/special/main.py create mode 100644 mycli/packages/special/utils.py create mode 100644 mycli/packages/tabular_output/__init__.py create mode 100644 mycli/packages/tabular_output/sql_format.py (limited to 'mycli/packages') diff --git a/mycli/packages/__init__.py b/mycli/packages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py new file mode 100644 index 0000000..2b19c32 --- /dev/null +++ b/mycli/packages/completion_engine.py @@ -0,0 +1,295 @@ +import os +import sys +import sqlparse +from sqlparse.sql import Comparison, Identifier, Where +from sqlparse.compat import text_type +from .parseutils import last_word, extract_tables, find_prev_keyword +from .special import parse_special_command + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + word_before_cursor = last_word(text_before_cursor, + include='many_punctuations') + + identifier = None + + # here should be removed once sqlparse has been fixed + try: + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. + if word_before_cursor: + if word_before_cursor.endswith( + '(') or word_before_cursor.startswith('\\'): + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse( + text_before_cursor[:-len(word_before_cursor)]) + + # word_before_cursor may include a schema qualification, like + # "schema_name.partial_name" or "schema_name.", so parse it + # separately + p = sqlparse.parse(word_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[0], Identifier): + identifier = p.tokens[0] + else: + parsed = sqlparse.parse(text_before_cursor) + except (TypeError, AttributeError): + return [{'type': 'keyword'}] + + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds the + # current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(text_type(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + statement = None + + # Check for special commands and handle those separately + if statement: + # Be careful here because trivial whitespace is parsed as a statement, + # but the statement won't have a first token + tok1 = statement.token_first() + if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')): + return suggest_special(text_before_cursor) + + last_token = statement and statement.token_prev(len(statement.tokens))[1] or '' + + return suggest_based_on_last_token(last_token, text_before_cursor, + full_text, identifier) + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return [{'type': 'special'}] + + if cmd in ('\\u', '\\r'): + return [{'type': 'database'}] + + if cmd in ('\\T'): + return [{'type': 'table_format'}] + + if cmd in ['\\f', '\\fs', '\\fd']: + return [{'type': 'favoritequery'}] + + if cmd in ['\\dt', '\\dt+']: + return [ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}, + ] + elif cmd in ['\\.', 'source']: + return[{'type': 'file_name'}] + + return [{'type': 'keyword'}, {'type': 'special'}] + + +def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, + full_text, identifier) + else: + token_v = token.value.lower() + + is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']]) + + if not token: + return [{'type': 'keyword'}, {'type': 'special'}] + elif token_v.endswith('('): + p = sqlparse.parse(text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token('where', + text_before_cursor, full_text, identifier) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + idx, prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == 'exists': + return [{'type': 'keyword'}] + else: + return column_suggestions + + # Get the token before the parens + idx, prev_tok = p.token_prev(len(p.tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = extract_tables(full_text) + + # suggest columns that are present in more than one table + return [{'type': 'column', 'tables': tables, 'drop_unique': True}] + elif p.token_first().value.lower() == 'select': + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(text_before_cursor, + 'all_punctuations').startswith('('): + return [{'type': 'keyword'}] + elif p.token_first().value.lower() == 'show': + return [{'type': 'show'}] + + # We're probably in a function argument list + return [{'type': 'column', 'tables': extract_tables(full_text)}] + elif token_v in ('set', 'order by', 'distinct'): + return [{'type': 'column', 'tables': extract_tables(full_text)}] + elif token_v == 'as': + # Don't suggest anything for an alias + return [] + elif token_v in ('show'): + return [{'type': 'show'}] + elif token_v in ('to',): + p = sqlparse.parse(text_before_cursor)[0] + if p.token_first().value.lower() == 'change': + return [{'type': 'change'}] + else: + return [{'type': 'user'}] + elif token_v in ('user', 'for'): + return [{'type': 'user'}] + elif token_v in ('select', 'where', 'having'): + # Check for a table alias or schema qualification + parent = (identifier and identifier.get_parent_name()) or [] + + tables = extract_tables(full_text) + if parent: + tables = [t for t in tables if identifies(parent, *t)] + return [{'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}] + else: + aliases = [alias or table for (schema, table, alias) in tables] + return [{'type': 'column', 'tables': tables}, + {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': aliases}, + {'type': 'keyword'}] + elif (token_v.endswith('join') and token.is_keyword) or (token_v in + ('copy', 'from', 'update', 'into', 'describe', 'truncate', + 'desc', 'explain')): + schema = (identifier and identifier.get_parent_name()) or [] + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [{'type': 'table', 'schema': schema}] + + if not schema: + # Suggest schemas + suggest.insert(0, {'type': 'schema'}) + + # Only tables can be TRUNCATED, otherwise suggest views + if token_v != 'truncate': + suggest.append({'type': 'view', 'schema': schema}) + + return suggest + + elif token_v in ('table', 'view', 'function'): + # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' + rel_type = token_v + schema = (identifier and identifier.get_parent_name()) or [] + if schema: + return [{'type': rel_type, 'schema': schema}] + else: + return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] + elif token_v == 'on': + tables = extract_tables(full_text) # [(schema, table, alias), ...] + parent = (identifier and identifier.get_parent_name()) or [] + if parent: + # "ON parent." + # parent can be either a schema name or table alias + tables = [t for t in tables if identifies(parent, *t)] + return [{'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}] + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = [alias or table for (schema, table, alias) in tables] + suggest = [{'type': 'alias', 'aliases': aliases}] + + # The lists of 'aliases' could be empty if we're trying to complete + # a GRANT query. eg: GRANT SELECT, INSERT ON + # In that case we just suggest all tables. + if not aliases: + suggest.append({'type': 'table', 'schema': parent}) + return suggest + + elif token_v in ('use', 'database', 'template', 'connect'): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return [{'type': 'database'}] + elif token_v == 'tableformat': + return [{'type': 'table_format'}] + elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']: + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + if prev_keyword: + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier) + else: + return [] + else: + return [{'type': 'keyword'}] + + +def identifies(id, schema, table, alias): + return id == alias or id == table or ( + schema and (id == schema + '.' + table)) diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py new file mode 100644 index 0000000..79fe26d --- /dev/null +++ b/mycli/packages/filepaths.py @@ -0,0 +1,106 @@ +import os +import platform + + +if os.name == "posix": + if platform.system() == "Darwin": + DEFAULT_SOCKET_DIRS = ("/tmp",) + else: + DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") +else: + DEFAULT_SOCKET_DIRS = () + + +def list_path(root_dir): + """List directory if exists. + + :param root_dir: str + :return: list + + """ + res = [] + if os.path.isdir(root_dir): + for name in os.listdir(root_dir): + res.append(name) + return res + + +def complete_path(curr_dir, last_dir): + """Return the path to complete that matches the last entered component. + + If the last entered component is ~, expanded path would not + match, so return all of the available paths. + + :param curr_dir: str + :param last_dir: str + :return: str + + """ + if not last_dir or curr_dir.startswith(last_dir): + return curr_dir + elif last_dir == '~': + return os.path.join(last_dir, curr_dir) + + +def parse_path(root_dir): + """Split path into head and last component for the completer. + + Also return position where last component starts. + + :param root_dir: str path + :return: tuple of (string, string, int) + + """ + base_dir, last_dir, position = '', '', 0 + if root_dir: + base_dir, last_dir = os.path.split(root_dir) + position = -len(last_dir) if last_dir else 0 + return base_dir, last_dir, position + + +def suggest_path(root_dir): + """List all files and subdirectories in a directory. + + If the directory is not specified, suggest root directory, + user directory, current and parent directory. + + :param root_dir: string: directory to list + :return: list + + """ + if not root_dir: + return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] + + if '~' in root_dir: + root_dir = os.path.expanduser(root_dir) + + if not os.path.exists(root_dir): + root_dir, _ = os.path.split(root_dir) + + return list_path(root_dir) + + +def dir_path_exists(path): + """Check if the directory path exists for a given file. + + For example, for a file /home/user/.cache/mycli/log, check if + /home/user/.cache/mycli exists. + + :param str path: The file path. + :return: Whether or not the directory path exists. + + """ + return os.path.exists(os.path.dirname(path)) + + +def guess_socket_location(): + """Try to guess the location of the default mysql socket file.""" + socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) + for directory in socket_dirs: + for r, dirs, files in os.walk(directory, topdown=True): + for filename in files: + name, ext = os.path.splitext(filename) + if name.startswith("mysql") and ext in ('.socket', '.sock'): + return os.path.join(r, filename) + dirs[:] = [d for d in dirs if d.startswith("mysql")] + return None diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py new file mode 100644 index 0000000..045b00e --- /dev/null +++ b/mycli/packages/paramiko_stub/__init__.py @@ -0,0 +1,28 @@ +"""A module to import instead of paramiko when it is not available (to avoid +checking for paramiko all over the place). + +When paramiko is first envoked, it simply shuts down mycli, telling +user they either have to install paramiko or should not use SSH +features. + +""" + + +class Paramiko: + def __getattr__(self, name): + import sys + from textwrap import dedent + print(dedent(""" + To enable certain SSH features you need to install paramiko: + + pip install paramiko + + It is required for the following configuration options: + --list-ssh-config + --ssh-config-host + --ssh-host + """)) + sys.exit(1) + + +paramiko = Paramiko() diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py new file mode 100644 index 0000000..e3b383e --- /dev/null +++ b/mycli/packages/parseutils.py @@ -0,0 +1,267 @@ +import re +import sqlparse +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +cleanup_regex = { + # This matches only alphanumerics and underscores. + 'alphanum_underscore': re.compile(r'(\w+)$'), + # This matches everything except spaces, parens, colon, and comma + 'many_punctuations': re.compile(r'([^():,\s]+)$'), + # This matches everything except spaces, parens, colon, comma, and period + 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), + # This matches everything except a space. + 'all_punctuations': re.compile('([^\s]+)$'), + } + +def last_word(text, include='alphanum_underscore'): + """ + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + """ + + if not text: # Empty string + return '' + + if text[-1].isspace(): + return '' + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return '' + + +# This code is borrowed from sqlparse example script. +# +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', + 'UPDATE', 'CREATE', 'DELETE'): + return True + return False + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # StopIteration. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif item.ttype is Keyword and ( + not item.value.upper() == 'FROM') and ( + not item.value.upper().endswith('JOIN')): + return + else: + yield item + elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and + item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if (identifier.ttype is Keyword and + identifier.value.upper() == 'FROM'): + tbl_prefix_seen = True + break + +def extract_table_identifiers(token_stream): + """yields tuples of (schema_name, table_name, table_alias)""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + except AttributeError: + continue + if real_name: + yield (schema_name, real_name, identifier.get_alias()) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + + if real_name: + yield (schema_name, real_name, item.get_alias()) + else: + name = item.get_name() + yield (None, name, item.get_alias() or name) + elif isinstance(item, Function): + yield (None, item.get_name(), item.get_name()) + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statment. + + Returns a list of (schema, table, alias) tuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == 'insert' + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + return list(extract_table_identifiers(stream)) + +def find_prev_keyword(sql): + """ Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, '' + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + + for t in reversed(flattened): + if t.value == '(' or (t.is_keyword and ( + t.value.upper() not in logical_operators)): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = ''.join(tok.value for tok in flattened[:idx+1]) + return t, text + + return None, '' + + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def query_has_where_clause(query): + """Check if the query contains a where-clause.""" + return any( + isinstance(token, sqlparse.sql.Where) + for token_list in sqlparse.parse(query) + for token in token_list + ) + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') + for query in sqlparse.split(queries): + if query: + if query_starts_with(query, keywords) is True: + return True + elif query_starts_with( + query, ['update'] + ) is True and not query_has_where_clause(query): + return True + + return False + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote.""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +if __name__ == '__main__': + sql = 'select * from (select t. from tabl t' + print (extract_tables(sql)) + + +def is_dropping_database(queries, dbname): + """Determine if the query is dropping a specific database.""" + result = False + if dbname is None: + return False + + def normalize_db_name(db): + return db.lower().strip('`"') + + dbname = normalize_db_name(dbname) + + for query in sqlparse.parse(queries): + keywords = [t for t in query.tokens if t.is_keyword] + if len(keywords) < 2: + continue + if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( + "database", + "schema", + ): + database_token = next( + (t for t in query.tokens if isinstance(t, Identifier)), None + ) + if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: + result = keywords[0].normalized == "DROP" + else: + return result diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py new file mode 100644 index 0000000..fb1e431 --- /dev/null +++ b/mycli/packages/prompt_utils.py @@ -0,0 +1,54 @@ +import sys +import click +from .parseutils import is_destructive + + +class ConfirmBoolParamType(click.ParamType): + name = 'confirmation' + + def convert(self, value, param, ctx): + if isinstance(value, bool): + return bool(value) + value = value.lower() + if value in ('yes', 'y'): + return True + elif value in ('no', 'n'): + return False + self.fail('%s is not a valid boolean' % value, param, ctx) + + def __repr__(self): + return 'BOOL' + + +BOOLEAN_TYPE = ConfirmBoolParamType() + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ("You're about to run a destructive command.\n" + "Do you want to proceed? (y/n)") + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=BOOLEAN_TYPE) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py new file mode 100644 index 0000000..92bcca6 --- /dev/null +++ b/mycli/packages/special/__init__.py @@ -0,0 +1,10 @@ +__all__ = [] + +def export(defn): + """Decorator to explicitly mark functions that are exposed in a lib.""" + globals()[defn.__name__] = defn + __all__.append(defn.__name__) + return defn + +from . import dbcommands +from . import iocommands diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py new file mode 100644 index 0000000..ed90e4c --- /dev/null +++ b/mycli/packages/special/dbcommands.py @@ -0,0 +1,157 @@ +import logging +import os +import platform +from mycli import __version__ +from mycli.packages.special import iocommands +from mycli.packages.special.utils import format_uptime +from .main import special_command, RAW_QUERY, PARSED_QUERY +from pymysql import ProgrammingError + +log = logging.getLogger(__name__) + + +@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.', + arg_type=PARSED_QUERY, case_sensitive=True) +def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): + if arg: + query = 'SHOW FIELDS FROM {0}'.format(arg) + else: + query = 'SHOW TABLES' + log.debug(query) + cur.execute(query) + tables = cur.fetchall() + status = '' + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, '')] + + if verbose and arg: + query = 'SHOW CREATE TABLE {0}'.format(arg) + log.debug(query) + cur.execute(query) + status = cur.fetchone()[1] + + return [(None, tables, headers, status)] + +@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True) +def list_databases(cur, **_): + query = 'SHOW DATABASES' + log.debug(query) + cur.execute(query) + if cur.description: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, '')] + else: + return [(None, None, None, '')] + +@special_command('status', '\\s', 'Get status information from the server.', + arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True) +def status(cur, **_): + query = 'SHOW GLOBAL STATUS;' + log.debug(query) + try: + cur.execute(query) + except ProgrammingError: + # Fallback in case query fail, as it does with Mysql 4 + query = 'SHOW STATUS;' + log.debug(query) + cur.execute(query) + status = dict(cur.fetchall()) + + query = 'SHOW GLOBAL VARIABLES;' + log.debug(query) + cur.execute(query) + variables = dict(cur.fetchall()) + + # prepare in case keys are bytes, as with Python 3 and Mysql 4 + if (isinstance(list(variables)[0], bytes) and + isinstance(list(status)[0], bytes)): + variables = {k.decode('utf-8'): v.decode('utf-8') for k, v + in variables.items()} + status = {k.decode('utf-8'): v.decode('utf-8') for k, v + in status.items()} + + # Create output buffers. + title = [] + output = [] + footer = [] + + title.append('--------------') + + # Output the mycli client information. + implementation = platform.python_implementation() + version = platform.python_version() + client_info = [] + client_info.append('mycli {0},'.format(__version__)) + client_info.append('running on {0} {1}'.format(implementation, version)) + title.append(' '.join(client_info) + '\n') + + # Build the output that will be displayed as a table. + output.append(('Connection id:', cur.connection.thread_id())) + + query = 'SELECT DATABASE(), USER();' + log.debug(query) + cur.execute(query) + db, user = cur.fetchone() + if db is None: + db = '' + + output.append(('Current database:', db)) + output.append(('Current user:', user)) + + if iocommands.is_pager_enabled(): + if 'PAGER' in os.environ: + pager = os.environ['PAGER'] + else: + pager = 'System default' + else: + pager = 'stdout' + output.append(('Current pager:', pager)) + + output.append(('Server version:', '{0} {1}'.format( + variables['version'], variables['version_comment']))) + output.append(('Protocol version:', variables['protocol_version'])) + + if 'unix' in cur.connection.host_info.lower(): + host_info = cur.connection.host_info + else: + host_info = '{0} via TCP/IP'.format(cur.connection.host) + + output.append(('Connection:', host_info)) + + query = ('SELECT @@character_set_server, @@character_set_database, ' + '@@character_set_client, @@character_set_connection LIMIT 1;') + log.debug(query) + cur.execute(query) + charset = cur.fetchone() + output.append(('Server characterset:', charset[0])) + output.append(('Db characterset:', charset[1])) + output.append(('Client characterset:', charset[2])) + output.append(('Conn. characterset:', charset[3])) + + if 'TCP/IP' in host_info: + output.append(('TCP port:', cur.connection.port)) + else: + output.append(('UNIX socket:', variables['socket'])) + + output.append(('Uptime:', format_uptime(status['Uptime']))) + + # Print the current server statistics. + stats = [] + stats.append('Connections: {0}'.format(status['Threads_connected'])) + if 'Queries' in status: + stats.append('Queries: {0}'.format(status['Queries'])) + stats.append('Slow queries: {0}'.format(status['Slow_queries'])) + stats.append('Opens: {0}'.format(status['Opened_tables'])) + stats.append('Flush tables: {0}'.format(status['Flush_commands'])) + stats.append('Open tables: {0}'.format(status['Open_tables'])) + if 'Queries' in status: + queries_per_second = int(status['Queries']) / int(status['Uptime']) + stats.append('Queries per second avg: {:.3f}'.format( + queries_per_second)) + stats = ' '.join(stats) + footer.append('\n' + stats) + + footer.append('--------------') + return [('\n'.join(title), output, '', '\n'.join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py new file mode 100644 index 0000000..994b134 --- /dev/null +++ b/mycli/packages/special/delimitercommand.py @@ -0,0 +1,80 @@ +import re +import sqlparse + + +class DelimiterCommand(object): + def __init__(self): + self._delimiter = ';' + + def _split(self, sql): + """Temporary workaround until sqlparse.split() learns about custom + delimiters.""" + + placeholder = "\ufffc" # unicode object replacement character + + if self._delimiter == ';': + return sqlparse.split(sql) + + # We must find a string that original sql does not contain. + # Most likely, our placeholder is enough, but if not, keep looking + while placeholder in sql: + placeholder += placeholder[0] + sql = sql.replace(';', placeholder) + sql = sql.replace(self._delimiter, ';') + + split = sqlparse.split(sql) + + return [ + stmt.replace(';', self._delimiter).replace(placeholder, ';') + for stmt in split + ] + + def queries_iter(self, input): + """Iterate over queries in the input string.""" + + queries = self._split(input) + while queries: + for sql in queries: + delimiter = self._delimiter + sql = queries.pop(0) + if sql.endswith(delimiter): + trailing_delimiter = True + sql = sql.strip(delimiter) + else: + trailing_delimiter = False + + yield sql + + # if the delimiter was changed by the last command, + # re-split everything, and if we previously stripped + # the delimiter, append it to the end + if self._delimiter != delimiter: + combined_statement = ' '.join([sql] + queries) + if trailing_delimiter: + combined_statement += delimiter + queries = self._split(combined_statement)[1:] + + def set(self, arg, **_): + """Change delimiter. + + Since `arg` is everything that follows the DELIMITER token + after sqlparse (it may include other statements separated by + the new delimiter), we want to set the delimiter to the first + word of it. + + """ + match = arg and re.search(r'[^\s]+', arg) + if not match: + message = 'Missing required argument, delimiter' + return [(None, None, None, message)] + + delimiter = match.group() + if delimiter.lower() == 'delimiter': + return [(None, None, None, 'Invalid delimiter "delimiter"')] + + self._delimiter = delimiter + return [(None, None, None, "Changed delimiter to {}".format(delimiter))] + + @property + def current(self): + return self._delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py new file mode 100644 index 0000000..0b91400 --- /dev/null +++ b/mycli/packages/special/favoritequeries.py @@ -0,0 +1,63 @@ +class FavoriteQueries(object): + + section_name = 'favorite_queries' + + usage = ''' +Favorite Queries are a way to save frequently used queries +with a short name. +Examples: + + # Save a new favorite query. + > \\fs simple select * from abc where a is not Null; + + # List all favorite queries. + > \\f + ╒════════╤═══════════════════════════════════════╕ + │ Name │ Query │ + ╞════════╪═══════════════════════════════════════╡ + │ simple │ SELECT * FROM abc where a is not NULL │ + ╘════════╧═══════════════════════════════════════╛ + + # Run a favorite query. + > \\f simple + ╒════════╤════════╕ + │ a │ b │ + ╞════════╪════════╡ + │ 日本語 │ 日本語 │ + ╘════════╧════════╛ + + # Delete a favorite query. + > \\fd simple + simple: Deleted +''' + + # Class-level variable, for convenience to use as a singleton. + instance = None + + def __init__(self, config): + self.config = config + + @classmethod + def from_config(cls, config): + return FavoriteQueries(config) + + def list(self): + return self.config.get(self.section_name, []) + + def get(self, name): + return self.config.get(self.section_name, {}).get(name, None) + + def save(self, name, query): + self.config.encoding = 'utf-8' + if self.section_name not in self.config: + self.config[self.section_name] = {} + self.config[self.section_name][name] = query + self.config.write() + + def delete(self, name): + try: + del self.config[self.section_name][name] + except KeyError: + return '%s: Not Found.' % name + self.config.write() + return '%s: Deleted' % name diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py new file mode 100644 index 0000000..11dca8d --- /dev/null +++ b/mycli/packages/special/iocommands.py @@ -0,0 +1,453 @@ +import os +import re +import locale +import logging +import subprocess +import shlex +from io import open +from time import sleep + +import click +import sqlparse + +from . import export +from .main import special_command, NO_QUERY, PARSED_QUERY +from .favoritequeries import FavoriteQueries +from .delimitercommand import DelimiterCommand +from .utils import handle_cd_command +from mycli.packages.prompt_utils import confirm_destructive_query + +TIMING_ENABLED = False +use_expanded_output = False +PAGER_ENABLED = True +tee_file = None +once_file = None +written_to_once_file = False +delimiter_command = DelimiterCommand() + + +@export +def set_timing_enabled(val): + global TIMING_ENABLED + TIMING_ENABLED = val + +@export +def set_pager_enabled(val): + global PAGER_ENABLED + PAGER_ENABLED = val + + +@export +def is_pager_enabled(): + return PAGER_ENABLED + +@export +@special_command('pager', '\\P [command]', + 'Set PAGER. Print the query results via PAGER.', + arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) +def set_pager(arg, **_): + if arg: + os.environ['PAGER'] = arg + msg = 'PAGER set to %s.' % arg + set_pager_enabled(True) + else: + if 'PAGER' in os.environ: + msg = 'PAGER set to %s.' % os.environ['PAGER'] + else: + # This uses click's default per echo_via_pager. + msg = 'Pager enabled.' + set_pager_enabled(True) + + return [(None, None, None, msg)] + +@export +@special_command('nopager', '\\n', 'Disable pager, print to stdout.', + arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True) +def disable_pager(): + set_pager_enabled(False) + return [(None, None, None, 'Pager disabled.')] + +@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True) +def toggle_timing(): + global TIMING_ENABLED + TIMING_ENABLED = not TIMING_ENABLED + message = "Timing is " + message += "on." if TIMING_ENABLED else "off." + return [(None, None, None, message)] + +@export +def is_timing_enabled(): + return TIMING_ENABLED + +@export +def set_expanded_output(val): + global use_expanded_output + use_expanded_output = val + +@export +def is_expanded_output(): + return use_expanded_output + +_logger = logging.getLogger(__name__) + +@export +def editor_command(command): + """ + Is this an external editor command? + :param command: string + """ + # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check + # for both conditions. + return command.strip().endswith('\\e') or command.strip().startswith('\\e') + +@export +def get_filename(sql): + if sql.strip().startswith('\\e'): + command, _, filename = sql.partition(' ') + return filename.strip() or None + + +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\e') is that it strips characters, + # not a substring. So it'll strip "e" in the end of the sql also! + # Ex: "select * from style\e" -> "select * from styl". + pattern = re.compile('(^\\\e|\\\e$)') + while pattern.search(sql): + sql = pattern.sub('', sql) + + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + + message = None + filename = filename.strip().split(' ', 1)[0] if filename else None + + sql = sql or '' + MARKER = '# Type your query above this line.\n' + + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), + filename=filename, extension='.sql') + + if filename: + try: + with open(filename) as f: + query = f.read() + except IOError: + message = 'Error reading file: %s.' % filename + + if query is not None: + query = query.split(MARKER, 1)[0].rstrip('\n') + else: + # Don't return None for the caller to deal with. + # Empty string is ok. + query = sql + + return (query, message) + + +@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) +def execute_favorite_query(cur, arg, **_): + """Returns (title, rows, headers, status)""" + if arg == '': + for result in list_favorite_queries(): + yield result + + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(' ') + args = shlex.split(arg_str) + + query = FavoriteQueries.instance.get(name) + if query is None: + message = "No favorite query: %s" % (name) + yield (None, None, None, message) + else: + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(';') + title = '> %s' % (sql) + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + +def list_favorite_queries(): + """List of all favorite queries. + Returns (title, rows, headers, status)""" + + headers = ["Name", "Query"] + rows = [(r, FavoriteQueries.instance.get(r)) + for r in FavoriteQueries.instance.list()] + + if not rows: + status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage + else: + status = '' + return [('', rows, headers, status)] + + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + subst_var = '$' + str(idx + 1) + if subst_var not in query: + return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + + query = query.replace(subst_var, val) + + match = re.search('\\$\d+', query) + if match: + return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + + return [query, None] + +@special_command('\\fs', '\\fs name query', 'Save a favorite query.') +def save_favorite_query(arg, **_): + """Save a new favorite query. + Returns (title, rows, headers, status)""" + + usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + name, _, query = arg.partition(' ') + + # If either name or query is missing then print the usage and complain. + if (not name) or (not query): + return [(None, None, None, + usage + 'Err: Both name and query are required.')] + + FavoriteQueries.instance.save(name, query) + return [(None, None, None, "Saved.")] + + +@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') +def delete_favorite_query(arg, **_): + """Delete an existing favorite query.""" + usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + status = FavoriteQueries.instance.delete(arg) + + return [(None, None, None, status)] + + +@special_command('system', 'system [command]', + 'Execute a system shell commmand.') +def execute_system_command(arg, **_): + """Execute a system shell command.""" + usage = "Syntax: system [command].\n" + + if not arg: + return [(None, None, None, usage)] + + try: + command = arg.strip() + if command.startswith('cd'): + ok, error_message = handle_cd_command(arg) + if not ok: + return [(None, None, None, error_message)] + return [(None, None, None, '')] + + args = arg.split(' ') + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + response = output if not error else error + + # Python 3 returns bytes. This needs to be decoded to a string. + if isinstance(response, bytes): + encoding = locale.getpreferredencoding(False) + response = response.decode(encoding) + + return [(None, None, None, response)] + except OSError as e: + return [(None, None, None, 'OSError: %s' % e.strerror)] + + +def parseargfile(arg): + if arg.startswith('-o '): + mode = "w" + filename = arg[3:] + else: + mode = 'a' + filename = arg + + if not filename: + raise TypeError('You must provide a filename.') + + return {'file': os.path.expanduser(filename), 'mode': mode} + + +@special_command('tee', 'tee [-o] filename', + 'Append all results to an output file (overwrite using -o).') +def set_tee(arg, **_): + global tee_file + + try: + tee_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + + return [(None, None, None, "")] + +@export +def close_tee(): + global tee_file + if tee_file: + tee_file.close() + tee_file = None + + +@special_command('notee', 'notee', 'Stop writing results to an output file.') +def no_tee(arg, **_): + close_tee() + return [(None, None, None, "")] + +@export +def write_tee(output): + global tee_file + if tee_file: + click.echo(output, file=tee_file, nl=False) + click.echo(u'\n', file=tee_file, nl=False) + tee_file.flush() + + +@special_command('\\once', '\\o [-o] filename', + 'Append next result to an output file (overwrite using -o).', + aliases=('\\o', )) +def set_once(arg, **_): + global once_file, written_to_once_file + + once_file = parseargfile(arg) + written_to_once_file = False + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + try: + f = open(**once_file) + except (IOError, OSError) as e: + once_file = None + raise OSError("Cannot write to file '{}': {}".format( + e.filename, e.strerror)) + with f: + click.echo(output, file=f, nl=False) + click.echo(u"\n", file=f, nl=False) + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file + if written_to_once_file: + once_file = None + + +@special_command( + 'watch', + 'watch [seconds] [-c] query', + 'Executes the query every [seconds] seconds (by default 5).' +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + return + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + return + (current_arg, _, arg) = arg.partition(' ') + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == '-c': + clear_screen = True + continue + statement = '{0!s} {1!s}'.format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + return + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs['cur'] + sql_list = [ + (sql.rstrip(';'), "> {0!s}".format(sql)) + for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + return + finally: + set_pager_enabled(old_pager_enabled) + + +@export +@special_command('delimiter', None, 'Change SQL delimiter.') +def set_delimiter(arg, **_): + return delimiter_command.set(arg) + + +@export +def get_current_delimiter(): + return delimiter_command.current + + +@export +def split_queries(input): + for query in delimiter_command.queries_iter(input): + yield query diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py new file mode 100644 index 0000000..dddba66 --- /dev/null +++ b/mycli/packages/special/main.py @@ -0,0 +1,118 @@ +import logging +from collections import namedtuple + +from . import export + +log = logging.getLogger(__name__) + +NO_QUERY = 0 +PARSED_QUERY = 1 +RAW_QUERY = 2 + +SpecialCommand = namedtuple('SpecialCommand', + ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden', + 'case_sensitive']) + +COMMANDS = {} + +@export +class CommandNotFound(Exception): + pass + +@export +def parse_special_command(sql): + command, _, arg = sql.partition(' ') + verbose = '+' in command + command = command.strip().replace('+', '') + return (command, verbose, arg.strip()) + +@export +def special_command(command, shortcut, description, arg_type=PARSED_QUERY, + hidden=False, case_sensitive=False, aliases=()): + def wrapper(wrapped): + register_special_command(wrapped, command, shortcut, description, + arg_type, hidden, case_sensitive, aliases) + return wrapped + return wrapper + +@export +def register_special_command(handler, command, shortcut, description, + arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()): + cmd = command.lower() if not case_sensitive else command + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, + arg_type, hidden, case_sensitive) + for alias in aliases: + cmd = alias.lower() if not case_sensitive else alias + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, + arg_type, case_sensitive=case_sensitive, + hidden=True) + +@export +def execute(cur, sql): + """Execute a special command and return the results. If the special command + is not supported a KeyError will be raised. + """ + command, verbose, arg = parse_special_command(sql) + + if (command not in COMMANDS) and (command.lower() not in COMMANDS): + raise CommandNotFound + + try: + special_cmd = COMMANDS[command] + except KeyError: + special_cmd = COMMANDS[command.lower()] + if special_cmd.case_sensitive: + raise CommandNotFound('Command not found: %s' % command) + + # "help is a special case. We want built-in help, not + # mycli help here. + if command == 'help' and arg: + return show_keyword_help(cur=cur, arg=arg) + + if special_cmd.arg_type == NO_QUERY: + return special_cmd.handler() + elif special_cmd.arg_type == PARSED_QUERY: + return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + elif special_cmd.arg_type == RAW_QUERY: + return special_cmd.handler(cur=cur, query=sql) + +@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?')) +def show_help(): # All the parameters are ignored. + headers = ['Command', 'Shortcut', 'Description'] + result = [] + + for _, value in sorted(COMMANDS.items()): + if not value.hidden: + result.append((value.command, value.shortcut, value.description)) + return [(None, result, headers, None)] + +def show_keyword_help(cur, arg): + """ + Call the built-in "show ", to display help for an SQL keyword. + :param cur: cursor + :param arg: string + :return: list + """ + keyword = arg.strip('"').strip("'") + query = "help '{0}'".format(keyword) + log.debug(query) + cur.execute(query) + if cur.description and cur.rowcount > 0: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, '')] + else: + return [(None, None, None, 'No help found for {0}.'.format(keyword))] + + +@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', )) +@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY) +def quit(*_args): + raise EOFError + + +@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', + arg_type=NO_QUERY, case_sensitive=True) +@special_command('\\G', '\\G', 'Display current query results vertically.', + arg_type=NO_QUERY, case_sensitive=True) +def stub(): + raise NotImplementedError diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py new file mode 100644 index 0000000..ef96093 --- /dev/null +++ b/mycli/packages/special/utils.py @@ -0,0 +1,46 @@ +import os +import subprocess + +def handle_cd_command(arg): + """Handles a `cd` shell command by calling python's os.chdir.""" + CD_CMD = 'cd' + tokens = arg.split(CD_CMD + ' ') + directory = tokens[-1] if len(tokens) > 1 else None + if not directory: + return False, "No folder name was provided." + try: + os.chdir(directory) + subprocess.call(['pwd']) + return True, None + except OSError as e: + return False, e.strerror + +def format_uptime(uptime_in_seconds): + """Format number of seconds into human-readable string. + + :param uptime_in_seconds: The server uptime in seconds. + :returns: A human-readable string representing the uptime. + + >>> uptime = format_uptime('56892') + >>> print(uptime) + 15 hours 48 min 12 sec + """ + + m, s = divmod(int(uptime_in_seconds), 60) + h, m = divmod(m, 60) + d, h = divmod(h, 24) + + uptime_values = [] + + for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')): + if value == 0 and not uptime_values: + # Don't include a value/unit if the unit isn't applicable to + # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. + continue + elif value == 1 and unit.endswith('s'): + # Remove the "s" if the unit is singular. + unit = unit[:-1] + uptime_values.append('{0} {1}'.format(value, unit)) + + uptime = ' '.join(uptime_values) + return uptime diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py new file mode 100644 index 0000000..730e633 --- /dev/null +++ b/mycli/packages/tabular_output/sql_format.py @@ -0,0 +1,63 @@ +"""Format adapter for sql.""" + +from cli_helpers.utils import filter_dict_by_key +from mycli.packages.parseutils import extract_tables + +supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', + 'sql-update-2', ) + +preprocessors = () + + +def escape_for_sql_statement(value): + if isinstance(value, bytes): + return f"X'{value.hex()}'" + else: + return formatter.mycli.sqlexecute.conn.escape(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "`DUAL`" + if table_format == 'sql-insert': + h = "`, `".join(headers) + yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) + for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith('sql-update'): + s = table_format.split('-') + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield "UPDATE {} SET".format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v)) + if prefix == " ": + prefix = ", " + f = "`{}` = {}" + where = (f.format(headers[i], escape_for_sql_statement( + d[i])) for i in range(keys)) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {'table_format': sql_format}) -- cgit v1.2.3