summaryrefslogtreecommitdiffstats
path: root/mycli
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 11:28:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 11:28:14 +0000
commitb678a621c57a6d3fdfac14bdbbef0ed743ab1742 (patch)
tree5481c14ce75dfda9c55721de033992b45ab0e1dc /mycli
parentInitial commit. (diff)
downloadmycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.tar.xz
mycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.zip
Adding upstream version 1.22.2.upstream/1.22.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'mycli')
-rw-r--r--mycli/AUTHORS79
-rw-r--r--mycli/SPONSORS31
-rw-r--r--mycli/__init__.py1
-rw-r--r--mycli/clibuffer.py54
-rw-r--r--mycli/clistyle.py118
-rw-r--r--mycli/clitoolbar.py52
-rw-r--r--mycli/compat.py6
-rw-r--r--mycli/completion_refresher.py123
-rw-r--r--mycli/config.py286
-rw-r--r--mycli/key_bindings.py85
-rw-r--r--mycli/lexer.py12
-rw-r--r--mycli/magic.py54
-rwxr-xr-xmycli/main.py1326
-rw-r--r--mycli/myclirc121
-rw-r--r--mycli/packages/__init__.py0
-rw-r--r--mycli/packages/completion_engine.py295
-rw-r--r--mycli/packages/filepaths.py106
-rw-r--r--mycli/packages/paramiko_stub/__init__.py28
-rw-r--r--mycli/packages/parseutils.py267
-rw-r--r--mycli/packages/prompt_utils.py54
-rw-r--r--mycli/packages/special/__init__.py10
-rw-r--r--mycli/packages/special/dbcommands.py157
-rw-r--r--mycli/packages/special/delimitercommand.py80
-rw-r--r--mycli/packages/special/favoritequeries.py63
-rw-r--r--mycli/packages/special/iocommands.py453
-rw-r--r--mycli/packages/special/main.py118
-rw-r--r--mycli/packages/special/utils.py46
-rw-r--r--mycli/packages/tabular_output/__init__.py0
-rw-r--r--mycli/packages/tabular_output/sql_format.py63
-rw-r--r--mycli/sqlcompleter.py435
-rw-r--r--mycli/sqlexecute.py313
31 files changed, 4836 insertions, 0 deletions
diff --git a/mycli/AUTHORS b/mycli/AUTHORS
new file mode 100644
index 0000000..b3636d9
--- /dev/null
+++ b/mycli/AUTHORS
@@ -0,0 +1,79 @@
+Project Lead:
+-------------
+ * Thomas Roten
+
+
+Core Developers:
+----------------
+
+ * Irina Truong
+ * Matheus Rosa
+ * Darik Gamble
+ * Dick Marinus
+ * Amjith Ramanujam
+
+Contributors:
+-------------
+
+ * Steve Robbins
+ * Shoma Suzuki
+ * Daniel West
+ * Scrappy Soft
+ * Daniel Black
+ * Jonathan Bruno
+ * Casper Langemeijer
+ * Jonathan Slenders
+ * Artem Bezsmertnyi
+ * Mikhail Borisov
+ * Heath Naylor
+ * Phil Cohen
+ * spacewander
+ * Adam Chainz
+ * Johannes Hoff
+ * Kacper Kwapisz
+ * Lennart Weller
+ * Martijn Engler
+ * Terseus
+ * Tyler Kuipers
+ * William GARCIA
+ * Yasuhiro Matsumoto
+ * bjarnagin
+ * jbruno
+ * mrdeathless
+ * Abirami P
+ * John Sterling
+ * Jialong Liu
+ * Zhidong
+ * Daniël van Eeden
+ * zer09
+ * cxbig
+ * chainkite
+ * Michał Górny
+ * Terje Røsten
+ * Ryan Smith
+ * Klaus Wünschel
+ * François Pietka
+ * Colin Caine
+ * Frederic Aoustin
+ * caitinggui
+ * ushuz
+ * Zhaolong Zhu
+ * Zhongyang Guan
+ * Huachao Mao
+ * QiaoHou Peng
+ * Yang Zou
+ * Angelo Lupo
+ * Aljosha Papsch
+ * Zane C. Bowers-Hadley
+ * Mike Palandra
+ * Georgy Frolov
+ * Jonathan Lloyd
+ * Nathan Huang
+ * Jakub Boukal
+ * Takeshi D. Itoh
+ * laixintao
+
+Creator:
+--------
+
+Amjith Ramanujam
diff --git a/mycli/SPONSORS b/mycli/SPONSORS
new file mode 100644
index 0000000..81b0904
--- /dev/null
+++ b/mycli/SPONSORS
@@ -0,0 +1,31 @@
+Many thanks to the following Kickstarter backers.
+
+* Tech Blue Software
+* jweiland.net
+
+# Silver Sponsors
+
+* Whitane Tech
+* Open Query Pty Ltd
+* Prathap Ramamurthy
+* Lincoln Loop
+
+# Sponsors
+
+* Nathan Taggart
+* Iryna Cherniavska
+* Sudaraka Wijesinghe
+* www.mysqlfanboy.com
+* Steve Robbins
+* Norbert Spichtig
+* orpharion bestheneme
+* Daniel Black
+* Anonymous
+* Magnus udd
+* Anonymous
+* Lewis Peckover
+* Cyrille Tabary
+* Heath Naylor
+* Ted Pennings
+* Chris Anderton
+* Jonathan Slenders
diff --git a/mycli/__init__.py b/mycli/__init__.py
new file mode 100644
index 0000000..53bfe2e
--- /dev/null
+++ b/mycli/__init__.py
@@ -0,0 +1 @@
+__version__ = '1.22.2'
diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py
new file mode 100644
index 0000000..c9d29d1
--- /dev/null
+++ b/mycli/clibuffer.py
@@ -0,0 +1,54 @@
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+from .packages.parseutils import is_open_quote
+from .packages import special
+
+
+def cli_is_multiline(mycli):
+ @Condition
+ def cond():
+ doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
+
+ if not mycli.multi_line:
+ return False
+ else:
+ return not _multiline_exception(doc.text)
+ return cond
+
+
+def _multiline_exception(text):
+ orig = text
+ text = text.strip()
+
+ # Multi-statement favorite query is a special case. Because there will
+ # be a semicolon separating statements, we can't consider semicolon an
+ # EOL. Let's consider an empty line an EOL instead.
+ if text.startswith('\\fs'):
+ return orig.endswith('\n')
+
+ return (
+ # Special Command
+ text.startswith('\\') or
+
+ # Delimiter declaration
+ text.lower().startswith('delimiter') or
+
+ # Ended with the current delimiter (usually a semicolon)
+ text.endswith(special.get_current_delimiter()) or
+
+ text.endswith('\\g') or
+ text.endswith('\\G') or
+
+ # Exit doesn't need a semicolon
+ (text == 'exit') or
+
+ # Quit doesn't need a semicolon
+ (text == 'quit') or
+
+ # To all the vim fans out there
+ (text == ':q') or
+
+ # just a plain enter without any text
+ (text == '')
+ )
diff --git a/mycli/clistyle.py b/mycli/clistyle.py
new file mode 100644
index 0000000..c94f793
--- /dev/null
+++ b/mycli/clistyle.py
@@ -0,0 +1,118 @@
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+ Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
+ Token.Menu.Completions.Completion: 'completion-menu.completion',
+ Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
+ Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
+ Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
+ Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
+ Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
+ Token.SelectedText: 'selected',
+ Token.SearchMatch: 'search',
+ Token.SearchMatch.Current: 'search.current',
+ Token.Toolbar: 'bottom-toolbar',
+ Token.Toolbar.Off: 'bottom-toolbar.off',
+ Token.Toolbar.On: 'bottom-toolbar.on',
+ Token.Toolbar.Search: 'search-toolbar',
+ Token.Toolbar.Search.Text: 'search-toolbar.text',
+ Token.Toolbar.System: 'system-toolbar',
+ Token.Toolbar.Arg: 'arg-toolbar',
+ Token.Toolbar.Arg.Text: 'arg-toolbar.text',
+ Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
+ Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
+ Token.Output.Header: 'output.header',
+ Token.Output.OddRow: 'output.odd-row',
+ Token.Output.EvenRow: 'output.even-row',
+ Token.Prompt: 'prompt',
+ Token.Continuation: 'continuation',
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {
+ v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
+}
+
+
+def parse_pygments_style(token_name, style_object, style_dict):
+ """Parse token type and style string.
+
+ :param token_name: str name of Pygments token. Example: "Token.String"
+ :param style_object: pygments.style.Style instance to use as base
+ :param style_dict: dict of token names and their styles, customized to this cli
+
+ """
+ token_type = string_to_tokentype(token_name)
+ try:
+ other_token_type = string_to_tokentype(style_dict[token_name])
+ return token_type, style_object.styles[other_token_type]
+ except AttributeError as err:
+ return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name)
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name('native')
+
+ prompt_styles = []
+ # prompt-toolkit used pygments tokens for styling before, switched to style
+ # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+ for token in cli_style:
+ if token.startswith('Token.'):
+ # treat as pygments token (1.0)
+ token_type, style_value = parse_pygments_style(
+ token, style, cli_style)
+ if token_type in TOKEN_TO_PROMPT_STYLE:
+ prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+ prompt_styles.append((prompt_style, style_value))
+ else:
+ # we don't want to support tokens anymore
+ logger.error('Unhandled style / class name: %s', token)
+ else:
+ # treat as prompt style name (2.0). See default style names here:
+ # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+ prompt_styles.append((token, cli_style[token]))
+
+ override_style = Style([('bottom-toolbar', 'noreverse')])
+ return merge_styles([
+ style_from_pygments_cls(style),
+ override_style,
+ Style(prompt_styles)
+ ])
+
+
+def style_factory_output(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name).styles
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name('native').styles
+
+ for token in cli_style:
+ if token.startswith('Token.'):
+ token_type, style_value = parse_pygments_style(
+ token, style, cli_style)
+ style.update({token_type: style_value})
+ elif token in PROMPT_STYLE_TO_TOKEN:
+ token_type = PROMPT_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ else:
+ # TODO: cli helpers will have to switch to ptk.Style
+ logger.error('Unhandled style / class name: %s', token)
+
+ class OutputStyle(PygmentsStyle):
+ default_style = ""
+ styles = style
+
+ return OutputStyle
diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py
new file mode 100644
index 0000000..e03e182
--- /dev/null
+++ b/mycli/clitoolbar.py
@@ -0,0 +1,52 @@
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.application import get_app
+from prompt_toolkit.enums import EditingMode
+from .packages import special
+
+
+def create_toolbar_tokens_func(mycli, show_fish_help):
+ """Return a function that generates the toolbar tokens."""
+ def get_toolbar_tokens():
+ result = []
+ result.append(('class:bottom-toolbar', ' '))
+
+ if mycli.multi_line:
+ delimiter = special.get_current_delimiter()
+ result.append(
+ (
+ 'class:bottom-toolbar',
+ ' ({} [{}] will end the line) '.format(
+ 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter)
+ ))
+
+ if mycli.multi_line:
+ result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
+ else:
+ result.append(('class:bottom-toolbar.off',
+ '[F3] Multiline: OFF '))
+ if mycli.prompt_app.editing_mode == EditingMode.VI:
+ result.append((
+ 'class:botton-toolbar.on',
+ 'Vi-mode ({})'.format(_get_vi_mode())
+ ))
+
+ if show_fish_help():
+ result.append(
+ ('class:bottom-toolbar', ' Right-arrow to complete suggestion'))
+
+ if mycli.completion_refresher.is_refreshing():
+ result.append(
+ ('class:bottom-toolbar', ' Refreshing completions...'))
+
+ return result
+ return get_toolbar_tokens
+
+
+def _get_vi_mode():
+ """Get the current vi mode for display."""
+ return {
+ InputMode.INSERT: 'I',
+ InputMode.NAVIGATION: 'N',
+ InputMode.REPLACE: 'R',
+ InputMode.INSERT_MULTIPLE: 'M',
+ }[get_app().vi_state.input_mode]
diff --git a/mycli/compat.py b/mycli/compat.py
new file mode 100644
index 0000000..2ebfe07
--- /dev/null
+++ b/mycli/compat.py
@@ -0,0 +1,6 @@
+"""Platform and Python version compatibility support."""
+
+import sys
+
+
+WIN = sys.platform in ('win32', 'cygwin')
diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py
new file mode 100644
index 0000000..e6c8dd0
--- /dev/null
+++ b/mycli/completion_refresher.py
@@ -0,0 +1,123 @@
+import threading
+from .packages.special.main import COMMANDS
+from collections import OrderedDict
+
+from .sqlcompleter import SQLCompleter
+from .sqlexecute import SQLExecute
+
+class CompletionRefresher(object):
+
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, callbacks, completer_options=None):
+ """Creates a SQLCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - SQLExecute object, used to extract the credentials to connect
+ to the database.
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ completer_options - dict of options to pass to SQLCompleter.
+
+ """
+ if completer_options is None:
+ completer_options = {}
+
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, 'Auto-completion refresh restarted.')]
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, callbacks, completer_options),
+ name='completion_refresh')
+ self._completer_thread.setDaemon(True)
+ self._completer_thread.start()
+ return [(None, None, None,
+ 'Auto-completion refresh started in the background.')]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
+ completer = SQLCompleter(**completer_options)
+
+ # Create a new SQLExecute instance to populate the completions.
+ e = sqlexecute
+ executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
+ e.socket, e.charset, e.local_infile, e.ssl,
+ e.ssh_user, e.ssh_host, e.ssh_port,
+ e.ssh_password, e.ssh_key_filename)
+
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes naturally
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ for callback in callbacks:
+ callback(completer)
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+ """Decorator to add the decorated function to the dictionary of
+ refreshers. Any function decorated with a @refresher will be executed as
+ part of the completion refresh routine."""
+ def wrapper(wrapped):
+ refreshers[name] = wrapped
+ return wrapped
+ return wrapper
+
+@refresher('databases')
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+@refresher('schemata')
+def refresh_schemata(completer, executor):
+ # schemata - In MySQL Schema is the same as database. But for mycli
+ # schemata will be the name of the current database.
+ completer.extend_schemata(executor.dbname)
+ completer.set_dbname(executor.dbname)
+
+@refresher('tables')
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind='tables')
+ completer.extend_columns(executor.table_columns(), kind='tables')
+
+@refresher('users')
+def refresh_users(completer, executor):
+ completer.extend_users(executor.users())
+
+# @refresher('views')
+# def refresh_views(completer, executor):
+# completer.extend_relations(executor.views(), kind='views')
+# completer.extend_columns(executor.view_columns(), kind='views')
+
+@refresher('functions')
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
+
+@refresher('special_commands')
+def refresh_special(completer, executor):
+ completer.extend_special_commands(COMMANDS.keys())
+
+@refresher('show_commands')
+def refresh_show_commands(completer, executor):
+ completer.extend_show_items(executor.show_candidates())
diff --git a/mycli/config.py b/mycli/config.py
new file mode 100644
index 0000000..e0f2d1f
--- /dev/null
+++ b/mycli/config.py
@@ -0,0 +1,286 @@
+import io
+import shutil
+from copy import copy
+from io import BytesIO, TextIOWrapper
+import logging
+import os
+from os.path import exists
+import struct
+import sys
+from typing import Union
+
+from configobj import ConfigObj, ConfigObjError
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.backends import default_backend
+
+try:
+ basestring
+except NameError:
+ basestring = str
+
+
+logger = logging.getLogger(__name__)
+
+
+def log(logger, level, message):
+ """Logs message to stderr if logging isn't initialized."""
+
+ if logger.parent.name != 'root':
+ logger.log(level, message)
+ else:
+ print(message, file=sys.stderr)
+
+
+def read_config_file(f, list_values=True):
+ """Read a config file.
+
+ *list_values* set to `True` is the default behavior of ConfigObj.
+ Disabling it causes values to not be parsed for lists,
+ (e.g. 'a,b,c' -> ['a', 'b', 'c']). Additionally, the config values are
+ not unquoted. We are disabling list_values when reading MySQL config files
+ so we can correctly interpret commas in passwords.
+
+ """
+
+ if isinstance(f, basestring):
+ f = os.path.expanduser(f)
+
+ try:
+ config = ConfigObj(f, interpolation=False, encoding='utf8',
+ list_values=list_values)
+ except ConfigObjError as e:
+ log(logger, logging.ERROR, "Unable to parse line {0} of config file "
+ "'{1}'.".format(e.line_number, f))
+ log(logger, logging.ERROR, "Using successfully parsed config values.")
+ return e.config
+ except (IOError, OSError) as e:
+ log(logger, logging.WARNING, "You don't have permission to read "
+ "config file '{0}'.".format(e.filename))
+ return None
+
+ return config
+
+
+def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
+ """Get a list of configuration files that are included into config_path
+ with !includedir directive.
+
+ "Normal" configs should be passed as file paths. The only exception
+ is .mylogin which is decoded into a stream. However, it never
+ contains include directives and so will be ignored by this
+ function.
+
+ """
+ if not isinstance(config_file, str) or not os.path.isfile(config_file):
+ return []
+ included_configs = []
+
+ try:
+ with open(config_file) as f:
+ include_directives = filter(
+ lambda s: s.startswith('!includedir'),
+ f
+ )
+ dirs = map(lambda s: s.strip().split()[-1], include_directives)
+ dirs = filter(os.path.isdir, dirs)
+ for dir in dirs:
+ for filename in os.listdir(dir):
+ if filename.endswith('.cnf'):
+ included_configs.append(os.path.join(dir, filename))
+ except (PermissionError, UnicodeDecodeError):
+ pass
+ return included_configs
+
+
+def read_config_files(files, list_values=True):
+ """Read and merge a list of config files."""
+
+ config = ConfigObj(list_values=list_values)
+ _files = copy(files)
+ while _files:
+ _file = _files.pop(0)
+ _config = read_config_file(_file, list_values=list_values)
+
+ # expand includes only if we were able to parse config
+ # (otherwise we'll just encounter the same errors again)
+ if config is not None:
+ _files = get_included_configs(_file) + _files
+ if bool(_config) is True:
+ config.merge(_config)
+ config.filename = _config.filename
+
+ return config
+
+
+def write_default_config(source, destination, overwrite=False):
+ destination = os.path.expanduser(destination)
+ if not overwrite and exists(destination):
+ return
+
+ shutil.copyfile(source, destination)
+
+
+def get_mylogin_cnf_path():
+ """Return the path to the login path file or None if it doesn't exist."""
+ mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE')
+
+ if mylogin_cnf_path is None:
+ app_data = os.getenv('APPDATA')
+ default_dir = os.path.join(app_data, 'MySQL') if app_data else '~'
+ mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf')
+
+ mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path)
+
+ if exists(mylogin_cnf_path):
+ logger.debug("Found login path file at '{0}'".format(mylogin_cnf_path))
+ return mylogin_cnf_path
+ return None
+
+
+def open_mylogin_cnf(name):
+ """Open a readable version of .mylogin.cnf.
+
+ Returns the file contents as a TextIOWrapper object.
+
+ :param str name: The pathname of the file to be opened.
+ :return: the login path file or None
+ """
+
+ try:
+ with open(name, 'rb') as f:
+ plaintext = read_and_decrypt_mylogin_cnf(f)
+ except (OSError, IOError, ValueError):
+ logger.error('Unable to open login path file.')
+ return None
+
+ if not isinstance(plaintext, BytesIO):
+ logger.error('Unable to read login path file.')
+ return None
+
+ return TextIOWrapper(plaintext)
+
+
+def read_and_decrypt_mylogin_cnf(f):
+ """Read and decrypt the contents of .mylogin.cnf.
+
+ This decryption algorithm mimics the code in MySQL's
+ mysql_config_editor.cc.
+
+ The login key is 20-bytes of random non-printable ASCII.
+ It is written to the actual login path file. It is used
+ to generate the real key used in the AES cipher.
+
+ :param f: an I/O object opened in binary mode
+ :return: the decrypted login path file
+ :rtype: io.BytesIO or None
+ """
+
+ # Number of bytes used to store the length of ciphertext.
+ MAX_CIPHER_STORE_LEN = 4
+
+ LOGIN_KEY_LEN = 20
+
+ # Move past the unused buffer.
+ buf = f.read(4)
+
+ if not buf or len(buf) != 4:
+ logger.error('Login path file is blank or incomplete.')
+ return None
+
+ # Read the login key.
+ key = f.read(LOGIN_KEY_LEN)
+
+ # Generate the real key.
+ rkey = [0] * 16
+ for i in range(LOGIN_KEY_LEN):
+ try:
+ rkey[i % 16] ^= ord(key[i:i+1])
+ except TypeError:
+ # ord() was unable to get the value of the byte.
+ logger.error('Unable to generate login path AES key.')
+ return None
+ rkey = struct.pack('16B', *rkey)
+
+ # Create a decryptor object using the key.
+ decryptor = _get_decryptor(rkey)
+
+ # Create a bytes buffer to hold the plaintext.
+ plaintext = BytesIO()
+
+ while True:
+ # Read the length of the ciphertext.
+ len_buf = f.read(MAX_CIPHER_STORE_LEN)
+ if len(len_buf) < MAX_CIPHER_STORE_LEN:
+ break
+ cipher_len, = struct.unpack("<i", len_buf)
+
+ # Read cipher_len bytes from the file and decrypt.
+ cipher = f.read(cipher_len)
+ plain = _remove_pad(decryptor.update(cipher))
+ if plain is False:
+ continue
+ plaintext.write(plain)
+
+ if plaintext.tell() == 0:
+ logger.error('No data successfully decrypted from login path file.')
+ return None
+
+ plaintext.seek(0)
+ return plaintext
+
+
+def str_to_bool(s):
+ """Convert a string value to its corresponding boolean value."""
+ if isinstance(s, bool):
+ return s
+ elif not isinstance(s, basestring):
+ raise TypeError('argument must be a string')
+
+ true_values = ('true', 'on', '1')
+ false_values = ('false', 'off', '0')
+
+ if s.lower() in true_values:
+ return True
+ elif s.lower() in false_values:
+ return False
+ else:
+ raise ValueError('not a recognized boolean value: %s'.format(s))
+
+
+def strip_matching_quotes(s):
+ """Remove matching, surrounding quotes from a string.
+
+ This is the same logic that ConfigObj uses when parsing config
+ values.
+
+ """
+ if (isinstance(s, basestring) and len(s) >= 2 and
+ s[0] == s[-1] and s[0] in ('"', "'")):
+ s = s[1:-1]
+ return s
+
+
+def _get_decryptor(key):
+ """Get the AES decryptor."""
+ c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
+ return c.decryptor()
+
+
+def _remove_pad(line):
+ """Remove the pad from the *line*."""
+ pad_length = ord(line[-1:])
+ try:
+ # Determine pad length.
+ pad_length = ord(line[-1:])
+ except TypeError:
+ # ord() was unable to get the value of the byte.
+ logger.warning('Unable to remove pad.')
+ return False
+
+ if pad_length > len(line) or len(set(line[-pad_length:])) != 1:
+ # Pad length should be less than or equal to the length of the
+ # plaintext. The pad should have a single unique byte.
+ logger.warning('Invalid pad found in login path file.')
+ return False
+
+ return line[:-pad_length]
diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py
new file mode 100644
index 0000000..57b917b
--- /dev/null
+++ b/mycli/key_bindings.py
@@ -0,0 +1,85 @@
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.filters import completion_is_selected
+from prompt_toolkit.key_binding import KeyBindings
+
+_logger = logging.getLogger(__name__)
+
+
+def mycli_bindings(mycli):
+ """Custom key bindings for mycli."""
+ kb = KeyBindings()
+
+ @kb.add('f2')
+ def _(event):
+ """Enable/Disable SmartCompletion Mode."""
+ _logger.debug('Detected F2 key.')
+ mycli.completer.smart_completion = not mycli.completer.smart_completion
+
+ @kb.add('f3')
+ def _(event):
+ """Enable/Disable Multiline Mode."""
+ _logger.debug('Detected F3 key.')
+ mycli.multi_line = not mycli.multi_line
+
+ @kb.add('f4')
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug('Detected F4 key.')
+ if mycli.key_bindings == "vi":
+ event.app.editing_mode = EditingMode.EMACS
+ mycli.key_bindings = "emacs"
+ else:
+ event.app.editing_mode = EditingMode.VI
+ mycli.key_bindings = "vi"
+
+ @kb.add('tab')
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug('Detected <Tab> key.')
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=True)
+
+ @kb.add('c-space')
+ def _(event):
+ """
+ Initialize autocompletion at cursor.
+
+ If the autocompletion menu is not showing, display it with the
+ appropriate completions for the context.
+
+ If the menu is showing, select the next completion.
+ """
+ _logger.debug('Detected <C-Space> key.')
+
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=False)
+
+ @kb.add('enter', filter=completion_is_selected)
+ def _(event):
+ """Makes the enter key work as the tab key only when showing the menu.
+
+ In other words, don't execute query when enter is pressed in
+ the completion dropdown menu, instead close the dropdown menu
+ (accept current selection).
+
+ """
+ _logger.debug('Detected enter key.')
+
+ event.current_buffer.complete_state = None
+ b = event.app.current_buffer
+ b.complete_state = None
+
+ @kb.add('escape', 'enter')
+ def _(event):
+ """Introduces a line break regardless of multi-line mode or not."""
+ _logger.debug('Detected alt-enter key.')
+ event.app.current_buffer.insert_text('\n')
+
+ return kb
diff --git a/mycli/lexer.py b/mycli/lexer.py
new file mode 100644
index 0000000..4b14d72
--- /dev/null
+++ b/mycli/lexer.py
@@ -0,0 +1,12 @@
+from pygments.lexer import inherit
+from pygments.lexers.sql import MySqlLexer
+from pygments.token import Keyword
+
+
+class MyCliLexer(MySqlLexer):
+ """Extends MySQL lexer to add keywords."""
+
+ tokens = {
+ 'root': [(r'\brepair\b', Keyword),
+ (r'\boffset\b', Keyword), inherit],
+ }
diff --git a/mycli/magic.py b/mycli/magic.py
new file mode 100644
index 0000000..5527f72
--- /dev/null
+++ b/mycli/magic.py
@@ -0,0 +1,54 @@
+from .main import MyCli
+import sql.parse
+import sql.connection
+import logging
+
+_logger = logging.getLogger(__name__)
+
+def load_ipython_extension(ipython):
+
+ # This is called via the ipython command '%load_ext mycli.magic'.
+
+ # First, load the sql magic if it isn't already loaded.
+ if not ipython.find_line_magic('sql'):
+ ipython.run_line_magic('load_ext', 'sql')
+
+ # Register our own magic.
+ ipython.register_magic_function(mycli_line_magic, 'line', 'mycli')
+
+def mycli_line_magic(line):
+ _logger.debug('mycli magic called: %r', line)
+ parsed = sql.parse.parse(line, {})
+ conn = sql.connection.Connection.get(parsed['connection'])
+
+ try:
+ # A corresponding mycli object already exists
+ mycli = conn._mycli
+ _logger.debug('Reusing existing mycli')
+ except AttributeError:
+ mycli = MyCli()
+ u = conn.session.engine.url
+ _logger.debug('New mycli: %r', str(u))
+
+ mycli.connect(u.database, u.host, u.username, u.port, u.password)
+ conn._mycli = mycli
+
+ # For convenience, print the connection alias
+ print('Connected: {}'.format(conn.name))
+
+ try:
+ mycli.run_cli()
+ except SystemExit:
+ pass
+
+ if not mycli.query_history:
+ return
+
+ q = mycli.query_history[-1]
+ if q.mutating:
+ _logger.debug('Mutating query detected -- ignoring')
+ return
+
+ if q.successful:
+ ipython = get_ipython()
+ return ipython.run_cell_magic('sql', line, q.query)
diff --git a/mycli/main.py b/mycli/main.py
new file mode 100755
index 0000000..03797a0
--- /dev/null
+++ b/mycli/main.py
@@ -0,0 +1,1326 @@
+import os
+import sys
+import traceback
+import logging
+import threading
+import re
+import fileinput
+from collections import namedtuple
+try:
+ from pwd import getpwuid
+except ImportError:
+ pass
+from time import time
+from datetime import datetime
+from random import choice
+from io import open
+
+from pymysql import OperationalError
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output import preprocessors
+from cli_helpers.utils import strip_ansi
+import click
+import sqlparse
+from mycli.packages.parseutils import is_dropping_database
+from prompt_toolkit.completion import DynamicCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor,
+ ConditionalProcessor)
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+
+from .packages.special.main import NO_QUERY
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages.tabular_output import sql_format
+from .packages import special
+from .packages.special.favoritequeries import FavoriteQueries
+from .sqlcompleter import SQLCompleter
+from .clitoolbar import create_toolbar_tokens_func
+from .clistyle import style_factory, style_factory_output
+from .sqlexecute import FIELD_TYPES, SQLExecute
+from .clibuffer import cli_is_multiline
+from .completion_refresher import CompletionRefresher
+from .config import (write_default_config, get_mylogin_cnf_path,
+ open_mylogin_cnf, read_config_files, str_to_bool,
+ strip_matching_quotes)
+from .key_bindings import mycli_bindings
+from .lexer import MyCliLexer
+from .__init__ import __version__
+from .compat import WIN
+from .packages.filepaths import dir_path_exists, guess_socket_location
+
+import itertools
+
+click.disable_unicode_literals_warning = True
+
+try:
+ from urlparse import urlparse
+ from urlparse import unquote
+except ImportError:
+ from urllib.parse import urlparse
+ from urllib.parse import unquote
+
+
+try:
+ import paramiko
+except ImportError:
+ from mycli.packages.paramiko_stub import paramiko
+
# Query tuples are used for maintaining history
# query: the SQL text; successful: did execution finish without error;
# mutating: whether any component statement modified data.
Query = namedtuple('Query', ['query', 'successful', 'mutating'])

# Directory containing this package; used to locate bundled data files
# (AUTHORS, SPONSORS and the default 'myclirc' config).
PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
+
+
class MyCli(object):
    """Interactive MySQL command-line client.

    Owns the prompt-toolkit session, completion machinery, configuration
    and the SQLExecute connection wrapper.
    """

    default_prompt = '\\t \\u@\\h:\\d> '
    # Prompts rendered longer than this fall back to the short '\\d> ' form.
    max_len_prompt = 45
    defaults_suffix = None

    # In order of being loaded. Files lower in list override earlier ones.
    cnf_files = [
        '/etc/my.cnf',
        '/etc/mysql/my.cnf',
        '/usr/local/etc/my.cnf',
        '~/.my.cnf'
    ]

    # check XDG_CONFIG_HOME exists and not an empty string
    if os.environ.get("XDG_CONFIG_HOME"):
        xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
    else:
        xdg_config_home = "~/.config"
    system_config_files = [
        '/etc/myclirc',
        os.path.join(xdg_config_home, "mycli", "myclirc")
    ]

    # Config bundled with the package, and a per-directory override.
    default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
    pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
+
    def __init__(self, sqlexecute=None, prompt=None,
                 logfile=None, defaults_suffix=None, defaults_file=None,
                 login_path=None, auto_vertical_output=False, warn=None,
                 myclirc="~/.myclirc"):
        """Load configuration and build the client's collaborators.

        :param sqlexecute: optional pre-built SQLExecute connection wrapper.
        :param prompt: prompt format string; overrides all config files.
        :param logfile: open file object used as the query audit log.
        :param defaults_suffix: MySQL defaults group suffix to also read.
        :param defaults_file: if given, read MySQL options only from it.
        :param login_path: section of ~/.mylogin.cnf to use.
        :param auto_vertical_output: force auto-vertical output mode.
        :param warn: override for the destructive-query warning
            (None means "use the config file value").
        :param myclirc: path of the user's myclirc configuration file.
        """
        self.sqlexecute = sqlexecute
        self.logfile = logfile
        self.defaults_suffix = defaults_suffix
        self.login_path = login_path

        # self.cnf_files is a class variable that stores the list of mysql
        # config files to read in at launch.
        # If defaults_file is specified then override the class variable with
        # defaults_file.
        if defaults_file:
            self.cnf_files = [defaults_file]

        # Load config.
        config_files = ([self.default_config_file] + self.system_config_files +
                        [myclirc] + [self.pwd_config_file])
        c = self.config = read_config_files(config_files)
        self.multi_line = c['main'].as_bool('multi_line')
        self.key_bindings = c['main']['key_bindings']
        special.set_timing_enabled(c['main'].as_bool('timing'))

        FavoriteQueries.instance = FavoriteQueries.from_config(self.config)

        self.dsn_alias = None
        self.formatter = TabularOutputFormatter(
            format_name=c['main']['table_format'])
        sql_format.register_new_formatter(self.formatter)
        self.formatter.mycli = self
        self.syntax_style = c['main']['syntax_style']
        self.less_chatty = c['main'].as_bool('less_chatty')
        self.cli_style = c['colors']
        self.output_style = style_factory_output(
            self.syntax_style,
            self.cli_style
        )
        self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
        c_dest_warning = c['main'].as_bool('destructive_warning')
        self.destructive_warning = c_dest_warning if warn is None else warn
        self.login_path_as_host = c['main'].as_bool('login_path_as_host')

        # read from cli argument or user config file
        self.auto_vertical_output = auto_vertical_output or \
            c['main'].as_bool('auto_vertical_output')

        # Write user config if system config wasn't the last config loaded.
        if c.filename not in self.system_config_files:
            write_default_config(self.default_config_file, myclirc)

        # audit log
        if self.logfile is None and 'audit_log' in c['main']:
            try:
                self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
            except (IOError, OSError) as e:
                self.echo('Error: Unable to open the audit log file. Your queries will not be logged.',
                          err=True, fg='red')
                # False (not None) marks "logging requested but unavailable";
                # run_cli warns per-query when it sees this value.
                self.logfile = False

        self.completion_refresher = CompletionRefresher()

        self.logger = logging.getLogger(__name__)
        self.initialize_logging()

        prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
        self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
            self.default_prompt
        self.multiline_continuation_char = c['main']['prompt_continuation']
        keyword_casing = c['main'].get('keyword_casing', 'auto')

        self.query_history = []

        # Initialize completer.
        self.smart_completion = c['main'].as_bool('smart_completion')
        self.completer = SQLCompleter(
            self.smart_completion,
            supported_formats=self.formatter.supported_formats,
            keyword_casing=keyword_casing)
        self._completer_lock = threading.Lock()

        # Register custom special commands.
        self.register_special_commands()

        # Load .mylogin.cnf if it exists.
        mylogin_cnf_path = get_mylogin_cnf_path()
        if mylogin_cnf_path:
            mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
        if mylogin_cnf_path and mylogin_cnf:
            # .mylogin.cnf gets read last, even if defaults_file is specified.
            self.cnf_files.append(mylogin_cnf)
        elif mylogin_cnf_path and not mylogin_cnf:
            # There was an error reading the login path file.
            print('Error: Unable to read login path file.')

        self.prompt_app = None
+
    def register_special_commands(self):
        """Register mycli-specific backslash commands (use, connect,
        rehash, tableformat, source, prompt) with the special-command
        registry so the REPL can dispatch them."""
        special.register_special_command(self.change_db, 'use',
                '\\u', 'Change to a new database.', aliases=('\\u',))
        special.register_special_command(self.change_db, 'connect',
                '\\r', 'Reconnect to the database. Optional database argument.',
                aliases=('\\r', ), case_sensitive=True)
        special.register_special_command(self.refresh_completions, 'rehash',
                '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
        special.register_special_command(
            self.change_table_format, 'tableformat', '\\T',
            'Change the table format used to output results.',
            aliases=('\\T',), case_sensitive=True)
        special.register_special_command(self.execute_from_file, 'source', '\\. filename',
                'Execute commands from file.', aliases=('\\.',))
        special.register_special_command(self.change_prompt_format, 'prompt',
                '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)
+
+ def change_table_format(self, arg, **_):
+ try:
+ self.formatter.format_name = arg
+ yield (None, None, None,
+ 'Changed table format to {}'.format(arg))
+ except ValueError:
+ msg = 'Table format {} not recognized. Allowed formats:'.format(
+ arg)
+ for table_type in self.formatter.supported_formats:
+ msg += "\n\t{}".format(table_type)
+ yield (None, None, None, msg)
+
+ def change_db(self, arg, **_):
+ if not arg:
+ click.secho(
+ "No database selected",
+ err=True, fg="red"
+ )
+ return
+
+ self.sqlexecute.change_db(arg)
+
+ yield (None, None, None, 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))
+
+ def execute_from_file(self, arg, **_):
+ if not arg:
+ message = 'Missing required argument, filename.'
+ return [(None, None, None, message)]
+ try:
+ with open(os.path.expanduser(arg)) as f:
+ query = f.read()
+ except IOError as e:
+ return [(None, None, None, str(e))]
+
+ if (self.destructive_warning and
+ confirm_destructive_query(query) is False):
+ message = 'Wise choice. Command execution stopped.'
+ return [(None, None, None, message)]
+
+ return self.sqlexecute.run(query)
+
+ def change_prompt_format(self, arg, **_):
+ """
+ Change the prompt format.
+ """
+ if not arg:
+ message = 'Missing required argument, format.'
+ return [(None, None, None, message)]
+
+ self.prompt_format = self.get_prompt(arg)
+ return [(None, None, None, "Changed prompt format to %s" % arg)]
+
    def initialize_logging(self):
        """Configure the 'mycli' logger from log_file/log_level config."""
        log_file = os.path.expanduser(self.config['main']['log_file'])
        log_level = self.config['main']['log_level']

        level_map = {'CRITICAL': logging.CRITICAL,
                     'ERROR': logging.ERROR,
                     'WARNING': logging.WARNING,
                     'INFO': logging.INFO,
                     'DEBUG': logging.DEBUG
                     }

        # Disable logging if value is NONE by switching to a no-op handler
        # Set log level to a high value so it doesn't even waste cycles getting called.
        if log_level.upper() == "NONE":
            handler = logging.NullHandler()
            log_level = "CRITICAL"
        elif dir_path_exists(log_file):
            handler = logging.FileHandler(log_file)
        else:
            # Unwritable destination: warn and leave logging unconfigured.
            self.echo(
                'Error: Unable to open the log file "{}".'.format(log_file),
                err=True, fg='red')
            return

        formatter = logging.Formatter(
            '%(asctime)s (%(process)d/%(threadName)s) '
            '%(name)s %(levelname)s - %(message)s')

        handler.setFormatter(formatter)

        # Attach to the package logger so every mycli.* module inherits it.
        root_logger = logging.getLogger('mycli')
        root_logger.addHandler(handler)
        root_logger.setLevel(level_map[log_level.upper()])

        logging.captureWarnings(True)

        root_logger.debug('Initializing mycli logging.')
        root_logger.debug('Log file %r.', log_file)
+
+
+ def read_my_cnf_files(self, files, keys):
+ """
+ Reads a list of config files and merges them. The last one will win.
+ :param files: list of files to read
+ :param keys: list of keys to retrieve
+ :returns: tuple, with None for missing keys.
+ """
+ cnf = read_config_files(files, list_values=False)
+
+ sections = ['client', 'mysqld']
+ if self.login_path and self.login_path != 'client':
+ sections.append(self.login_path)
+
+ if self.defaults_suffix:
+ sections.extend([sect + self.defaults_suffix for sect in sections])
+
+ def get(key):
+ result = None
+ for sect in cnf:
+ if sect in sections and key in cnf[sect]:
+ result = strip_matching_quotes(cnf[sect][key])
+ return result
+
+ return {x: get(x) for x in keys}
+
+ def merge_ssl_with_cnf(self, ssl, cnf):
+ """Merge SSL configuration dict with cnf dict"""
+
+ merged = {}
+ merged.update(ssl)
+ prefix = 'ssl-'
+ for k, v in cnf.items():
+ # skip unrelated options
+ if not k.startswith(prefix):
+ continue
+ if v is None:
+ continue
+ # special case because PyMySQL argument is significantly different
+ # from commandline
+ if k == 'ssl-verify-server-cert':
+ merged['check_hostname'] = v
+ else:
+ # use argument name just strip "ssl-" prefix
+ arg = k[len(prefix):]
+ merged[arg] = v
+
+ return merged
+
+ def connect(self, database='', user='', passwd='', host='', port='',
+ socket='', charset='', local_infile='', ssl='',
+ ssh_user='', ssh_host='', ssh_port='',
+ ssh_password='', ssh_key_filename=''):
+
+ cnf = {'database': None,
+ 'user': None,
+ 'password': None,
+ 'host': None,
+ 'port': None,
+ 'socket': None,
+ 'default-character-set': None,
+ 'local-infile': None,
+ 'loose-local-infile': None,
+ 'ssl-ca': None,
+ 'ssl-cert': None,
+ 'ssl-key': None,
+ 'ssl-cipher': None,
+ 'ssl-verify-serer-cert': None,
+ }
+
+ cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
+
+ # Fall back to config values only if user did not specify a value.
+
+ database = database or cnf['database']
+ # Socket interface not supported for SSH connections
+ if port or host or ssh_host or ssh_port:
+ socket = ''
+ else:
+ socket = socket or cnf['socket'] or guess_socket_location()
+ user = user or cnf['user'] or os.getenv('USER')
+ host = host or cnf['host']
+ port = port or cnf['port']
+ ssl = ssl or {}
+
+ passwd = passwd or cnf['password']
+ charset = charset or cnf['default-character-set'] or 'utf8'
+
+ # Favor whichever local_infile option is set.
+ for local_infile_option in (local_infile, cnf['local-infile'],
+ cnf['loose-local-infile'], False):
+ try:
+ local_infile = str_to_bool(local_infile_option)
+ break
+ except (TypeError, ValueError):
+ pass
+
+ ssl = self.merge_ssl_with_cnf(ssl, cnf)
+ # prune lone check_hostname=False
+ if not any(v for v in ssl.values()):
+ ssl = None
+
+ # Connect to the database.
+
+ def _connect():
+ try:
+ self.sqlexecute = SQLExecute(
+ database, user, passwd, host, port, socket, charset,
+ local_infile, ssl, ssh_user, ssh_host, ssh_port,
+ ssh_password, ssh_key_filename
+ )
+ except OperationalError as e:
+ if ('Access denied for user' in e.args[1]):
+ new_passwd = click.prompt('Password', hide_input=True,
+ show_default=False, type=str, err=True)
+ self.sqlexecute = SQLExecute(
+ database, user, new_passwd, host, port, socket,
+ charset, local_infile, ssl, ssh_user, ssh_host,
+ ssh_port, ssh_password, ssh_key_filename
+ )
+ else:
+ raise e
+
+ try:
+ if not WIN and socket:
+ socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
+ self.echo(
+ f"Connecting to socket {socket}, owned by user {socket_owner}")
+ try:
+ _connect()
+ except OperationalError as e:
+ # These are "Can't open socket" and 2x "Can't connect"
+ if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
+ self.logger.debug('Database connection failed: %r.', e)
+ self.logger.error(
+ "traceback: %r", traceback.format_exc())
+ self.logger.debug('Retrying over TCP/IP')
+ self.echo(
+ "Failed to connect to local MySQL server through socket '{}':".format(socket))
+ self.echo(str(e), err=True)
+ self.echo(
+ 'Retrying over TCP/IP', err=True)
+
+ # Else fall back to TCP/IP localhost
+ socket = ""
+ host = 'localhost'
+ port = 3306
+ _connect()
+ else:
+ raise e
+ else:
+ host = host or 'localhost'
+ port = port or 3306
+
+ # Bad ports give particularly daft error messages
+ try:
+ port = int(port)
+ except ValueError as e:
+ self.echo("Error: Invalid port number: '{0}'.".format(port),
+ err=True, fg='red')
+ exit(1)
+
+ _connect()
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug('Database connection failed: %r.', e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ exit(1)
+
    def handle_editor_command(self, text):
        """Editor command is any query that is prefixed or suffixed by a '\e'.
        The reason for a while loop is because a user might edit a query
        multiple times. For eg:

        "select * from \e"<enter> to edit it in vim, then come
        back to the prompt with the edited query "select * from
        blah where q = 'abc'\e" to edit it again.
        :param text: Document
        :return: Document

        """

        while special.editor_command(text):
            filename = special.get_filename(text)
            # Edit either the SQL typed alongside '\e' or, if none, the
            # last query from history.
            query = (special.get_editor_query(text) or
                     self.get_last_query())
            sql, message = special.open_external_editor(filename, sql=query)
            if message:
                # Something went wrong. Raise an exception and bail.
                raise RuntimeError(message)
            while True:
                try:
                    # Re-prompt with the edited SQL pre-filled; Ctrl-C
                    # clears the default and prompts again.
                    text = self.prompt_app.prompt(default=sql)
                    break
                except KeyboardInterrupt:
                    sql = ""

            continue
        return text
+
    def run_cli(self):
        """Run the interactive read-eval-print loop until EOF (Ctrl-D)."""
        iterations = 0
        sqlexecute = self.sqlexecute
        logger = self.logger
        self.configure_pager()

        if self.smart_completion:
            self.refresh_completions()

        author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
        sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')

        history_file = os.path.expanduser(
            os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
        if dir_path_exists(history_file):
            history = FileHistory(history_file)
        else:
            history = None
            self.echo(
                'Error: Unable to open the history file "{}". '
                'Your query history will not be saved.'.format(history_file),
                err=True, fg='red')

        key_bindings = mycli_bindings(self)

        if not self.less_chatty:
            print(' '.join(sqlexecute.server_type()))
            print('mycli', __version__)
            print('Chat: https://gitter.im/dbcli/mycli')
            print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
            print('Home: http://mycli.net')
            print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))

        def get_message():
            # Prompt text; falls back to the short form if the expanded
            # default prompt would exceed max_len_prompt.
            prompt = self.get_prompt(self.prompt_format)
            if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
                prompt = self.get_prompt('\\d> ')
            return [('class:prompt', prompt)]

        def get_continuation(width, *_):
            # Continuation prompt for multi-line statements, right-aligned
            # to the width of the first-line prompt.
            if self.multiline_continuation_char:
                left_padding = width - len(self.multiline_continuation_char)
                continuation = " " * \
                    max((left_padding - 1), 0) + \
                    self.multiline_continuation_char + " "
            else:
                continuation = " "
            return [('class:continuation', continuation)]

        def show_suggestion_tip():
            # Show the toolbar tip only during the first two iterations.
            return iterations < 2

        def one_iteration(text=None):
            # Read (or receive) one statement, execute it and print results.
            if text is None:
                try:
                    text = self.prompt_app.prompt()
                except KeyboardInterrupt:
                    return

            special.set_expanded_output(False)

            try:
                text = self.handle_editor_command(text)
            except RuntimeError as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
                return

            if not text.strip():
                return

            if self.destructive_warning:
                destroy = confirm_destructive_query(text)
                if destroy is None:
                    pass  # Query was not destructive. Nothing to do here.
                elif destroy is True:
                    self.echo('Your call!')
                else:
                    self.echo('Wise choice!')
                    return
            else:
                destroy = True

            # Keep track of whether or not the query is mutating. In case
            # of a multi-statement query, the overall query is considered
            # mutating if any one of the component statements is mutating
            mutating = False

            try:
                logger.debug('sql: %r', text)

                special.write_tee(self.get_prompt(self.prompt_format) + text)
                if self.logfile:
                    self.logfile.write('\n# %s\n' % datetime.now())
                    self.logfile.write(text)
                    self.logfile.write('\n')

                successful = False
                start = time()
                res = sqlexecute.run(text)
                self.formatter.query = text
                successful = True
                result_count = 0
                for title, cur, headers, status in res:
                    logger.debug("headers: %r", headers)
                    logger.debug("rows: %r", cur)
                    logger.debug("status: %r", status)
                    # Confirm before dumping very large SELECT results.
                    threshold = 1000
                    if (is_select(status) and
                            cur and cur.rowcount > threshold):
                        self.echo('The result set has more than {} rows.'.format(
                            threshold), fg='red')
                        if not confirm('Do you want to continue?'):
                            self.echo("Aborted!", err=True, fg='red')
                            break

                    if self.auto_vertical_output:
                        max_width = self.prompt_app.output.get_size().columns
                    else:
                        max_width = None

                    formatted = self.format_output(
                        title, cur, headers, special.is_expanded_output(),
                        max_width)

                    t = time() - start
                    try:
                        if result_count > 0:
                            self.echo('')
                        try:
                            self.output(formatted, status)
                        except KeyboardInterrupt:
                            pass
                        if special.is_timing_enabled():
                            self.echo('Time: %0.03fs' % t)
                    except KeyboardInterrupt:
                        pass

                    start = time()
                    result_count += 1
                    mutating = mutating or destroy or is_mutating(status)
                special.unset_once_if_written()
            except EOFError as e:
                raise e
            except KeyboardInterrupt:
                # Ctrl-C while a query runs: kill it server-side over a
                # fresh connection, then keep the REPL alive.
                # get last connection id
                connection_id_to_kill = sqlexecute.connection_id
                logger.debug("connection id to kill: %r", connection_id_to_kill)
                # Restart connection to the database
                sqlexecute.connect()
                try:
                    for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill):
                        status_str = str(status).lower()
                        if status_str.find('ok') > -1:
                            logger.debug("cancelled query, connection id: %r, sql: %r",
                                         connection_id_to_kill, text)
                            self.echo("cancelled query", err=True, fg='red')
                except Exception as e:
                    self.echo('Encountered error while cancelling query: {}'.format(e),
                              err=True, fg='red')
            except NotImplementedError:
                self.echo('Not Yet Implemented.', fg="yellow")
            except OperationalError as e:
                logger.debug("Exception: %r", e)
                # 2003/2006/2013: server gone away -- reconnect and retry once.
                if (e.args[0] in (2003, 2006, 2013)):
                    logger.debug('Attempting to reconnect.')
                    self.echo('Reconnecting...', fg='yellow')
                    try:
                        sqlexecute.connect()
                        logger.debug('Reconnected successfully.')
                        one_iteration(text)
                        return  # OK to just return, cuz the recursion call runs to the end.
                    except OperationalError as e:
                        logger.debug('Reconnect failed. e: %r', e)
                        self.echo(str(e), err=True, fg='red')
                        # If reconnection failed, don't proceed further.
                        return
                else:
                    logger.error("sql: %r, error: %r", text, e)
                    logger.error("traceback: %r", traceback.format_exc())
                    self.echo(str(e), err=True, fg='red')
            except Exception as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                self.echo(str(e), err=True, fg='red')
            else:
                if is_dropping_database(text, self.sqlexecute.dbname):
                    self.sqlexecute.dbname = None
                    self.sqlexecute.connect()

                # Refresh the table names and column names if necessary.
                if need_completion_refresh(text):
                    self.refresh_completions(
                        reset=need_completion_reset(text))
            finally:
                if self.logfile is False:
                    self.echo("Warning: This query was not logged.",
                              err=True, fg='red')
                query = Query(text, successful, mutating)
                self.query_history.append(query)

        get_toolbar_tokens = create_toolbar_tokens_func(
            self, show_suggestion_tip)
        if self.wider_completion_menu:
            complete_style = CompleteStyle.MULTI_COLUMN
        else:
            complete_style = CompleteStyle.COLUMN

        with self._completer_lock:

            if self.key_bindings == 'vi':
                editing_mode = EditingMode.VI
            else:
                editing_mode = EditingMode.EMACS

            self.prompt_app = PromptSession(
                lexer=PygmentsLexer(MyCliLexer),
                reserve_space_for_menu=self.get_reserved_space(),
                message=get_message,
                prompt_continuation=get_continuation,
                bottom_toolbar=get_toolbar_tokens,
                complete_style=complete_style,
                input_processors=[ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(
                        chars='[](){}'),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
                )],
                tempfile_suffix='.sql',
                completer=DynamicCompleter(lambda: self.completer),
                history=history,
                auto_suggest=AutoSuggestFromHistory(),
                complete_while_typing=True,
                multiline=cli_is_multiline(self),
                style=style_factory(self.syntax_style, self.cli_style),
                include_default_pygments_style=False,
                key_bindings=key_bindings,
                enable_open_in_editor=True,
                enable_system_prompt=True,
                enable_suspend=True,
                editing_mode=editing_mode,
                search_ignore_case=True
            )

        try:
            while True:
                one_iteration()
                iterations += 1
        except EOFError:
            special.close_tee()
            if not self.less_chatty:
                self.echo('Goodbye!')
+
+ def log_output(self, output):
+ """Log the output in the audit log, if it's enabled."""
+ if self.logfile:
+ click.echo(output, file=self.logfile)
+
+ def echo(self, s, **kwargs):
+ """Print a message to stdout.
+
+ The message will be logged in the audit log, if enabled.
+
+ All keyword arguments are passed to click.echo().
+
+ """
+ self.log_output(s)
+ click.secho(s, **kwargs)
+
+ def get_output_margin(self, status=None):
+ """Get the output margin (number of rows for the prompt, footer and
+ timing message."""
+ margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1
+ if special.is_timing_enabled():
+ margin += 1
+ if status:
+ margin += 1 + status.count('\n')
+
+ return margin
+
+
    def output(self, output, status=None):
        """Output text to stdout or a pager command.

        The status text is not outputted to pager or files.

        The message will be logged in the audit log, if enabled. The
        message will be written to the tee file, if enabled. The
        message will be written to the output file, if enabled.

        """
        if output:
            size = self.prompt_app.output.get_size()

            margin = self.get_output_margin(status)

            # Buffer lines while they still fit on screen; once they do
            # not, either hand everything to the pager or flush directly.
            fits = True
            buf = []
            output_via_pager = self.explicit_pager and special.is_pager_enabled()
            for i, line in enumerate(output, 1):
                self.log_output(line)
                special.write_tee(line)
                special.write_once(line)

                if fits or output_via_pager:
                    # buffering
                    buf.append(line)
                    if len(line) > size.columns or i > (size.rows - margin):
                        fits = False
                        if not self.explicit_pager and special.is_pager_enabled():
                            # doesn't fit, use pager
                            output_via_pager = True

                        if not output_via_pager:
                            # doesn't fit, flush buffer
                            for line in buf:
                                click.secho(line)
                            buf = []
                else:
                    click.secho(line)

            if buf:
                if output_via_pager:
                    # The pager expects newline-terminated input.
                    def newlinewrapper(text):
                        for line in text:
                            yield line + "\n"
                    click.echo_via_pager(newlinewrapper(buf))
                else:
                    for line in buf:
                        click.secho(line)

        if status:
            self.log_output(status)
            click.secho(status)
+
+ def configure_pager(self):
+ # Provide sane defaults for less if they are empty.
+ if not os.environ.get('LESS'):
+ os.environ['LESS'] = '-RXF'
+
+ cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
+ if cnf['pager']:
+ special.set_pager(cnf['pager'])
+ self.explicit_pager = True
+ else:
+ self.explicit_pager = False
+
+ if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'):
+ special.disable_pager()
+
+ def refresh_completions(self, reset=False):
+ if reset:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.completion_refresher.refresh(
+ self.sqlexecute, self._on_completions_refreshed,
+ {'smart_completion': self.smart_completion,
+ 'supported_formats': self.formatter.supported_formats,
+ 'keyword_casing': self.completer.keyword_casing})
+
+ return [(None, None, None,
+ 'Auto-completion refresh started in the background.')]
+
+ def _on_completions_refreshed(self, new_completer):
+ """Swap the completer object in cli with the newly created completer.
+ """
+ with self._completer_lock:
+ self.completer = new_completer
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None)
+
+ def get_prompt(self, string):
+ sqlexecute = self.sqlexecute
+ host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
+ now = datetime.now()
+ string = string.replace('\\u', sqlexecute.user or '(none)')
+ string = string.replace('\\h', host or '(none)')
+ string = string.replace('\\d', sqlexecute.dbname or '(none)')
+ string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
+ string = string.replace('\\n', "\n")
+ string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
+ string = string.replace('\\m', now.strftime('%M'))
+ string = string.replace('\\P', now.strftime('%p'))
+ string = string.replace('\\R', now.strftime('%H'))
+ string = string.replace('\\r', now.strftime('%I'))
+ string = string.replace('\\s', now.strftime('%S'))
+ string = string.replace('\\p', str(sqlexecute.port))
+ string = string.replace('\\A', self.dsn_alias or '(none)')
+ string = string.replace('\\_', ' ')
+ return string
+
+ def run_query(self, query, new_line=True):
+ """Runs *query*."""
+ results = self.sqlexecute.run(query)
+ for result in results:
+ title, cur, headers, status = result
+ self.formatter.query = query
+ output = self.format_output(title, cur, headers)
+ for line in output:
+ click.echo(line, nl=new_line)
+
    def format_output(self, title, cur, headers, expanded=False,
                      max_width=None):
        """Format one result set into an iterator of output lines.

        :param title: optional title emitted before the table.
        :param cur: cursor/iterable of rows; falsy means no rows.
        :param headers: column headers.
        :param expanded: force vertical (\\G-style) output.
        :param max_width: when given and the first rendered line would be
            wider, re-format the whole result vertically.
        :return: iterator over formatted output lines.
        """
        expanded = expanded or self.formatter.format_name == 'vertical'
        output = []

        output_kwargs = {
            'dialect': 'unix',
            'disable_numparse': True,
            'preserve_whitespace': True,
            'style': self.output_style
        }

        if not self.formatter.format_name in sql_format.supported_formats:
            output_kwargs["preprocessors"] = (preprocessors.align_decimals, )

        if title:  # Only print the title if it's not None.
            output = itertools.chain(output, [title])

        if cur:
            column_types = None
            if hasattr(cur, 'description'):
                def get_col_type(col):
                    # Map the DB-API type code through FIELD_TYPES; anything
                    # that does not resolve to a real type falls back to str.
                    col_type = FIELD_TYPES.get(col[1], str)
                    return col_type if type(col_type) is type else str
                column_types = [get_col_type(col) for col in cur.description]

            if max_width is not None:
                # May need to format a second time (vertical fallback), so
                # materialize the cursor -- it can only be consumed once.
                cur = list(cur)

            formatted = self.formatter.format_output(
                cur, headers, format_name='vertical' if expanded else None,
                column_types=column_types,
                **output_kwargs)

            if isinstance(formatted, str):
                formatted = formatted.splitlines()
            formatted = iter(formatted)

            if (not expanded and max_width and headers and cur):
                # Peek at the first line to decide whether the table fits.
                first_line = next(formatted)
                if len(strip_ansi(first_line)) > max_width:
                    formatted = self.formatter.format_output(
                        cur, headers, format_name='vertical', column_types=column_types, **output_kwargs)
                    if isinstance(formatted, str):
                        formatted = iter(formatted.splitlines())
                else:
                    formatted = itertools.chain([first_line], formatted)

            output = itertools.chain(output, formatted)


        return output
+
+ def get_reserved_space(self):
+ """Get the number of lines to reserve for the completion menu."""
+ reserved_space_ratio = .45
+ max_reserved_space = 8
+ _, height = click.get_terminal_size()
+ return min(int(round(height * reserved_space_ratio)), max_reserved_space)
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+
@click.command()
@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.')
@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors '
              '$MYSQL_TCP_PORT.')
@click.option('-u', '--user', help='User name to connect to the database.')
@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.')
@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str,
              help='Password to connect to the database.')
@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str,
              help='Password to connect to the database.')
@click.option('--ssh-user', help='User name to connect to ssh server.')
@click.option('--ssh-host', help='Host name to connect to ssh server.')
@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
@click.option('--ssh-password', help='Password to connect to ssh server.')
@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
@click.option('--ssh-config-path', help='Path to ssh configuration.',
              default=os.path.expanduser('~') + '/.ssh/config')
@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.')
@click.option('--ssl-ca', help='CA file in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-capath', help='CA directory.')
@click.option('--ssl-cert', help='X509 cert in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-key', help='X509 key in PEM format.',
              type=click.Path(exists=True))
@click.option('--ssl-cipher', help='SSL cipher to use.')
@click.option('--ssl-verify-server-cert', is_flag=True,
              help=('Verify server\'s "Common Name" in its cert against '
                    'hostname used when connecting. This option is disabled '
                    'by default.'))
# as of 2016-02-15 revocation list is not supported by underlying PyMySQL
# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client)
@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.')
@click.option('-v', '--verbose', is_flag=True, help='Verbose output.')
@click.option('-D', '--database', 'dbname', help='Database to use.')
@click.option('-d', '--dsn', default='', envvar='DSN',
              help='Use DSN configured into the [alias_dsn] section of myclirc file.')
@click.option('--list-dsn', 'list_dsn', is_flag=True,
              help='list of DSN configured into the [alias_dsn] section of myclirc file.')
@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True,
              help='list ssh configurations in the ssh config (requires paramiko).')
@click.option('-R', '--prompt', 'prompt',
              help='Prompt format (Default: "{0}").'.format(
                  MyCli.default_prompt))
@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'),
              help='Log every query and its results to a file.')
@click.option('--defaults-group-suffix', type=str,
              help='Read MySQL config groups with the specified suffix.')
@click.option('--defaults-file', type=click.Path(),
              help='Only read MySQL options from the given file.')
@click.option('--myclirc', type=click.Path(), default="~/.myclirc",
              help='Location of myclirc file.')
@click.option('--auto-vertical-output', is_flag=True,
              help='Automatically switch to vertical output mode if the result is wider than the terminal width.')
@click.option('-t', '--table', is_flag=True,
              help='Display batch output in table format.')
@click.option('--csv', is_flag=True,
              help='Display batch output in CSV format.')
@click.option('--warn/--no-warn', default=None,
              help='Warn before running a destructive query.')
@click.option('--local-infile', type=bool,
              help='Enable/disable LOAD DATA LOCAL INFILE.')
@click.option('--login-path', type=str,
              help='Read this path from the login file.')
@click.option('-e', '--execute', type=str,
              help='Execute command and quit.')
@click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname,
        version, verbose, prompt, logfile, defaults_group_suffix,
        defaults_file, login_path, auto_vertical_output, local_infile,
        ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
        ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
        list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
        ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host):
    """A MySQL terminal client with auto-completion and syntax highlighting.

    \b
    Examples:
      - mycli my_database
      - mycli -u my_user -h my_host.com my_database
      - mycli mysql://my_user@my_host.com:3306/my_database

    """

    if version:
        print('Version:', __version__)
        sys.exit(0)

    mycli = MyCli(prompt=prompt, logfile=logfile,
                  defaults_suffix=defaults_group_suffix,
                  defaults_file=defaults_file, login_path=login_path,
                  auto_vertical_output=auto_vertical_output, warn=warn,
                  myclirc=myclirc)

    # --list-dsn: print the configured DSN aliases and exit.
    if list_dsn:
        try:
            alias_dsn = mycli.config['alias_dsn']
        except KeyError:
            click.secho('Invalid DSNs found in the config file. '
                        'Please check the "[alias_dsn]" section in myclirc.',
                        err=True, fg='red')
            exit(1)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            exit(1)
        for alias, value in alias_dsn.items():
            if verbose:
                click.secho("{} : {}".format(alias, value))
            else:
                click.secho(alias)
        sys.exit(0)

    # --list-ssh-config: print hosts known to the ssh config file and exit.
    # Use a dedicated loop variable; the original shadowed the `host` option.
    if list_ssh_config:
        ssh_config = read_ssh_config(ssh_config_path)
        for config_host in ssh_config.get_hostnames():
            if verbose:
                host_config = ssh_config.lookup(config_host)
                click.secho("{} : {}".format(
                    config_host, host_config.get('hostname')))
            else:
                click.secho(config_host)
        sys.exit(0)

    # Choose whichever one has a valid value.
    database = dbname or database

    ssl = {
        'ca': ssl_ca and os.path.expanduser(ssl_ca),
        'cert': ssl_cert and os.path.expanduser(ssl_cert),
        'key': ssl_key and os.path.expanduser(ssl_key),
        'capath': ssl_capath,
        'cipher': ssl_cipher,
        'check_hostname': ssl_verify_server_cert,
    }

    # remove empty ssl options
    ssl = {k: v for k, v in ssl.items() if v is not None}

    dsn_uri = None

    # Treat the database argument as a DSN alias if we're missing
    # other connection information.
    if (mycli.config['alias_dsn'] and database and '://' not in database
            and not any([user, password, host, port, login_path])):
        dsn, database = database, ''

    if database and '://' in database:
        dsn_uri, database = database, ''

    if dsn:
        try:
            dsn_uri = mycli.config['alias_dsn'][dsn]
        except KeyError:
            click.secho('Could not find the specified DSN in the config file. '
                        'Please check the "[alias_dsn]" section in your '
                        'myclirc.', err=True, fg='red')
            exit(1)
        else:
            mycli.dsn_alias = dsn

    # Fill in any connection fields the URI supplies but the CLI didn't.
    if dsn_uri:
        uri = urlparse(dsn_uri)
        if not database:
            database = uri.path[1:]  # ignore the leading fwd slash
        if not user and uri.username is not None:
            # unquote(None) raises TypeError; only decode a present username.
            user = unquote(uri.username)
        if not password and uri.password is not None:
            password = unquote(uri.password)
        if not host:
            host = uri.hostname
        if not port:
            port = uri.port

    # Merge ssh settings from the ssh config file, CLI flags taking priority.
    if ssh_config_host:
        ssh_config = read_ssh_config(
            ssh_config_path
        ).lookup(ssh_config_host)
        ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
        ssh_user = ssh_user if ssh_user else ssh_config.get('user')
        if ssh_config.get('port') and ssh_port == 22:
            # port has a default value, overwrite it if it's in the config
            ssh_port = int(ssh_config.get('port'))
        ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
            'identityfile', [None])[0]

    ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)

    mycli.connect(
        database=database,
        user=user,
        passwd=password,
        host=host,
        port=port,
        socket=socket,
        local_infile=local_infile,
        ssl=ssl,
        ssh_user=ssh_user,
        ssh_host=ssh_host,
        ssh_port=ssh_port,
        ssh_password=ssh_password,
        ssh_key_filename=ssh_key_filename
    )

    mycli.logger.debug('Launch Params: \n'
                       '\tdatabase: %r'
                       '\tuser: %r'
                       '\thost: %r'
                       '\tport: %r', database, user, host, port)

    # --execute argument
    if execute:
        try:
            if csv:
                mycli.formatter.format_name = 'csv'
            elif not table:
                mycli.formatter.format_name = 'tsv'

            mycli.run_query(execute)
            exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            exit(1)

    if sys.stdin.isatty():
        # Interactive session.
        mycli.run_cli()
    else:
        # Batch mode: read all queries from the piped-in stdin.
        stdin = click.get_text_stream('stdin')
        try:
            stdin_text = stdin.read()
        except MemoryError:
            click.secho('Failed! Ran out of memory.', err=True, fg='red')
            click.secho('You might want to try the official mysql client.', err=True, fg='red')
            click.secho('Sorry... :(', err=True, fg='red')
            exit(1)

        # Re-attach stdin to the terminal so prompts (e.g. the destructive
        # query confirmation) still work after the pipe is consumed.
        try:
            sys.stdin = open('/dev/tty')
        except (IOError, OSError):
            mycli.logger.warning('Unable to open TTY as stdin.')

        if (mycli.destructive_warning and
                confirm_destructive_query(stdin_text) is False):
            exit(0)
        try:
            new_line = True

            if csv:
                mycli.formatter.format_name = 'csv'
            elif not table:
                mycli.formatter.format_name = 'tsv'

            mycli.run_query(stdin_text, new_line=new_line)
            exit(0)
        except Exception as e:
            click.secho(str(e), err=True, fg='red')
            exit(1)
+
+
def need_completion_refresh(queries):
    """Determine whether the completions need a refresh.

    Returns True when any statement in *queries* is an alter, create,
    drop, rename or database switch — operations that change the metadata
    the completer relies on.
    """
    for query in sqlparse.split(queries):
        try:
            first_token = query.split()[0]
        except IndexError:
            # An empty statement (e.g. a stray semicolon) has no tokens.
            # The original code aborted the whole scan here; skip it and
            # keep checking the remaining statements instead.
            continue
        if first_token.lower() in ('alter', 'create', 'use', '\\r',
                                   '\\u', 'connect', 'drop', 'rename'):
            return True
    return False
+
+
def need_completion_reset(queries):
    """Determines if the statement is a database switch such as 'use' or '\\u'.
    When a database is changed the existing completions must be reset before we
    start the completion refresh for the new database.
    """
    for query in sqlparse.split(queries):
        try:
            first_token = query.split()[0]
        except IndexError:
            # Empty statement — skip it rather than abort the whole scan
            # (the original returned False here, missing later statements).
            continue
        if first_token.lower() in ('use', '\\u'):
            return True
    return False
+
+
def is_mutating(status):
    """Return whether *status* starts with a data/schema-mutating keyword."""
    if not status:
        return False
    first_word = status.split(None, 1)[0].lower()
    return first_word in {'insert', 'update', 'delete', 'alter', 'create',
                          'drop', 'replace', 'truncate', 'load', 'rename'}
+
+
def is_select(status):
    """Returns true if the first word in status is 'select'."""
    return bool(status) and status.split(None, 1)[0].lower() == 'select'
+
+
def thanks_picker(files=()):
    """Return a random contributor name from AUTHORS/SPONSORS-style files.

    Names are lines of the form " * Some Name"; the leading bullet is
    stripped. Raw string for the regex: the original '^ *\\* (.*)' relied
    on the invalid escape sequence '\\*' surviving in a non-raw literal.
    """
    contents = [
        m.group(1)
        for line in fileinput.input(files=files)
        for m in (re.match(r'^ *\* (.*)', line),)
        if m
    ]
    return choice(contents)
+
+
@prompt_register('edit-and-execute-command')
def edit_and_execute(event):
    """Open the current buffer in $EDITOR without auto-executing it.

    Unlike the prompt-toolkit default binding, the user keeps the choice
    not to run the query after editing (validate_and_handle=False).
    """
    event.current_buffer.open_in_editor(validate_and_handle=False)
+
+
def read_ssh_config(ssh_config_path):
    """Parse and return the SSH configuration at *ssh_config_path*.

    Exits the process with status 1 when the file is missing or cannot
    be parsed.
    """
    ssh_config = paramiko.config.SSHConfig()
    try:
        with open(ssh_config_path) as f:
            ssh_config.parse(f)
    # FileNotFoundError must be handled BEFORE the generic Exception
    # clause: it is a subclass of Exception, so the original ordering
    # made this branch unreachable dead code.
    except FileNotFoundError as e:
        click.secho(str(e), err=True, fg='red')
        sys.exit(1)
    # Paramiko prior to version 2.7 raises Exception on parse errors.
    # In 2.7 it has become paramiko.ssh_exception.SSHException,
    # but let's catch everything for compatibility
    except Exception as err:
        click.secho(
            f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
            err=True, fg='red'
        )
        sys.exit(1)
    else:
        return ssh_config
+
+
# Allow running this module directly (e.g. `python -m mycli.main`).
if __name__ == "__main__":
    cli()
diff --git a/mycli/myclirc b/mycli/myclirc
new file mode 100644
index 0000000..534b201
--- /dev/null
+++ b/mycli/myclirc
@@ -0,0 +1,121 @@
+# vi: ft=dosini
+[main]
+
# Enables context-sensitive auto-completion. If this is disabled, all
# possible completions will be listed.
+smart_completion = True
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+log_file = ~/.mycli.log
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.mycli-audit.log
+
# Timing of sql statements and table rendering.
+timing = True
+
+# Table format. Possible values: ascii, double, github,
+# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html,
+# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# MySQL prompt
+# \D - The full current date
+# \d - Database name
+# \h - Hostname of the server
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \p - Port
+# \R - The current time, in 24-hour military time (0–23)
+# \r - The current time, standard 12-hour time (1–12)
+# \s - Seconds of the current time
+# \t - Product type (Percona, MySQL, MariaDB)
+# \A - DSN alias name (from the [alias_dsn] section)
+# \u - Username
+prompt = '\t \u@\h:\d> '
+prompt_continuation = '->'
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
# Enable the pager on startup. Set to False to disable paging.
+enable_pager = True
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+
+# Favorite queries.
+[favorite_queries]
+
+# Use the -d option to reference a DSN.
+# Special characters in passwords and other strings can be escaped with URL encoding.
+[alias_dsn]
+# example_dsn = mysql://[user[:password]@][host][:port][/dbname]
diff --git a/mycli/packages/__init__.py b/mycli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/__init__.py
diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py
new file mode 100644
index 0000000..2b19c32
--- /dev/null
+++ b/mycli/packages/completion_engine.py
@@ -0,0 +1,295 @@
+import os
+import sys
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from sqlparse.compat import text_type
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
def suggest_type(full_text, text_before_cursor):
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.
    """

    word_before_cursor = last_word(text_before_cursor,
                                   include='many_punctuations')

    # Schema-qualified identifier being typed (e.g. "schema.tab"), if any.
    identifier = None

    # TODO: the try/except guard below works around sqlparse crashes on
    # malformed input; remove it once sqlparse has been fixed upstream.
    try:
        # If we've partially typed a word then word_before_cursor won't be an empty
        # string. In that case we want to remove the partially typed string before
        # sending it to the sqlparser. Otherwise the last token will always be the
        # partially typed string which renders the smart completion useless because
        # it will always return the list of keywords as completion.
        if word_before_cursor:
            if word_before_cursor.endswith(
                    '(') or word_before_cursor.startswith('\\'):
                parsed = sqlparse.parse(text_before_cursor)
            else:
                parsed = sqlparse.parse(
                    text_before_cursor[:-len(word_before_cursor)])

            # word_before_cursor may include a schema qualification, like
            # "schema_name.partial_name" or "schema_name.", so parse it
            # separately
            p = sqlparse.parse(word_before_cursor)[0]

            if p.tokens and isinstance(p.tokens[0], Identifier):
                identifier = p.tokens[0]
        else:
            parsed = sqlparse.parse(text_before_cursor)
    except (TypeError, AttributeError):
        # Parsing blew up — fall back to plain keyword suggestions.
        return [{'type': 'keyword'}]

    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds the
        # current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(text_type(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        statement = None

    # Check for special commands and handle those separately
    if statement:
        # Be careful here because trivial whitespace is parsed as a statement,
        # but the statement won't have a first token
        tok1 = statement.token_first()
        if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
            return suggest_special(text_before_cursor)

    # token_prev() yields (index, token); default to '' so the callee
    # always receives something it can lower-case.
    last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''

    return suggest_based_on_last_token(last_token, text_before_cursor,
                                       full_text, identifier)
+
+
def suggest_special(text):
    """Map a special (backslash) command prefix to completion suggestions."""
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return [{'type': 'special'}]

    if cmd in ('\\u', '\\r'):
        return [{'type': 'database'}]

    # BUGFIX: the original used `cmd in ('\\T')`, i.e. substring membership
    # in the STRING '\T' (so cmd == '\\' also matched). A one-element tuple
    # restores the intended exact comparison.
    if cmd in ('\\T',):
        return [{'type': 'table_format'}]

    if cmd in ('\\f', '\\fs', '\\fd'):
        return [{'type': 'favoritequery'}]

    if cmd in ('\\dt', '\\dt+'):
        return [
            {'type': 'table', 'schema': []},
            {'type': 'view', 'schema': []},
            {'type': 'schema'},
        ]
    elif cmd in ('\\.', 'source'):
        return [{'type': 'file_name'}]

    return [{'type': 'keyword'}, {'type': 'special'}]
+
+
def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
    """Return completion suggestions appropriate after *token*.

    :param token: last significant token before the cursor (str or a
        sqlparse token/group)
    :param text_before_cursor: raw text up to the cursor
    :param full_text: the full statement typed so far
    :param identifier: schema-qualified Identifier being typed, if any
    :return: list of suggestion dicts, e.g. {'type': 'column', 'tables': [...]}
    """
    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        return suggest_based_on_last_token(prev_keyword, text_before_cursor,
                                           full_text, identifier)
    else:
        token_v = token.value.lower()

    # True when the text ends with an arithmetic operator.
    is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])

    if not token:
        return [{'type': 'keyword'}, {'type': 'special'}]
    elif token_v.endswith('('):
        p = sqlparse.parse(text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #      Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #      Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #      Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #      Suggest columns/functions AND keywords. (If we wanted to be
            #      really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token(
                'where', text_before_cursor, full_text, identifier)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            idx, prev_tok = where.token_prev(len(where.tokens) - 1)

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == 'exists':
                return [{'type': 'keyword'}]
            else:
                return column_suggestions

        # Get the token before the parens
        idx, prev_tok = p.token_prev(len(p.tokens) - 1)
        if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = extract_tables(full_text)

            # suggest columns that are present in more than one table
            return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
        elif p.token_first().value.lower() == 'select':
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(text_before_cursor,
                         'all_punctuations').startswith('('):
                return [{'type': 'keyword'}]
        elif p.token_first().value.lower() == 'show':
            return [{'type': 'show'}]

        # We're probably in a function argument list
        return [{'type': 'column', 'tables': extract_tables(full_text)}]
    elif token_v in ('set', 'order by', 'distinct'):
        return [{'type': 'column', 'tables': extract_tables(full_text)}]
    elif token_v == 'as':
        # Don't suggest anything for an alias
        return []
    elif token_v in ('show',):
        # BUGFIX: was `token_v in ('show')` -- substring membership in the
        # STRING 'show', which also matched 's', 'h', 'sh', 'ho', etc.
        return [{'type': 'show'}]
    elif token_v in ('to',):
        p = sqlparse.parse(text_before_cursor)[0]
        if p.token_first().value.lower() == 'change':
            return [{'type': 'change'}]
        else:
            return [{'type': 'user'}]
    elif token_v in ('user', 'for'):
        return [{'type': 'user'}]
    elif token_v in ('select', 'where', 'having'):
        # Check for a table alias or schema qualification
        parent = (identifier and identifier.get_parent_name()) or []

        tables = extract_tables(full_text)
        if parent:
            tables = [t for t in tables if identifies(parent, *t)]
            return [{'type': 'column', 'tables': tables},
                    {'type': 'table', 'schema': parent},
                    {'type': 'view', 'schema': parent},
                    {'type': 'function', 'schema': parent}]
        else:
            aliases = [alias or table for (schema, table, alias) in tables]
            return [{'type': 'column', 'tables': tables},
                    {'type': 'function', 'schema': []},
                    {'type': 'alias', 'aliases': aliases},
                    {'type': 'keyword'}]
    elif (token_v.endswith('join') and token.is_keyword) or (token_v in
            ('copy', 'from', 'update', 'into', 'describe', 'truncate',
             'desc', 'explain')):
        schema = (identifier and identifier.get_parent_name()) or []

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = [{'type': 'table', 'schema': schema}]

        if not schema:
            # Suggest schemas
            suggest.insert(0, {'type': 'schema'})

        # Only tables can be TRUNCATED, otherwise suggest views
        if token_v != 'truncate':
            suggest.append({'type': 'view', 'schema': schema})

        return suggest

    elif token_v in ('table', 'view', 'function'):
        # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
        rel_type = token_v
        schema = (identifier and identifier.get_parent_name()) or []
        if schema:
            return [{'type': rel_type, 'schema': schema}]
        else:
            return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
    elif token_v == 'on':
        tables = extract_tables(full_text)  # [(schema, table, alias), ...]
        parent = (identifier and identifier.get_parent_name()) or []
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            tables = [t for t in tables if identifies(parent, *t)]
            return [{'type': 'column', 'tables': tables},
                    {'type': 'table', 'schema': parent},
                    {'type': 'view', 'schema': parent},
                    {'type': 'function', 'schema': parent}]
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = [alias or table for (schema, table, alias) in tables]
            suggest = [{'type': 'alias', 'aliases': aliases}]

            # The lists of 'aliases' could be empty if we're trying to complete
            # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
            # In that case we just suggest all tables.
            if not aliases:
                suggest.append({'type': 'table', 'schema': parent})
            return suggest

    elif token_v in ('use', 'database', 'template', 'connect'):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return [{'type': 'database'}]
    elif token_v == 'tableformat':
        return [{'type': 'table_format'}]
    elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        if prev_keyword:
            return suggest_based_on_last_token(
                prev_keyword, text_before_cursor, full_text, identifier)
        else:
            return []
    else:
        return [{'type': 'keyword'}]
+
+
def identifies(id, schema, table, alias):
    """Return whether *id* names the given table — by alias, by bare table
    name, or by its schema-qualified form."""
    if id == alias or id == table:
        return True
    return schema and id == schema + '.' + table
diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py
new file mode 100644
index 0000000..79fe26d
--- /dev/null
+++ b/mycli/packages/filepaths.py
@@ -0,0 +1,106 @@
+import os
+import platform
+
+
# Directories searched by guess_socket_location() for a MySQL unix socket.
if os.name == "posix":
    if platform.system() == "Darwin":
        # macOS: mysqld conventionally places its socket in /tmp.
        DEFAULT_SOCKET_DIRS = ("/tmp",)
    else:
        # Other unixes (typically Linux distributions).
        DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib")
else:
    # Non-POSIX platforms (e.g. Windows) have no unix sockets to find.
    DEFAULT_SOCKET_DIRS = ()
+
+
def list_path(root_dir):
    """Return the entries of *root_dir* if it is a directory, else [].

    :param root_dir: str
    :return: list

    """
    if not os.path.isdir(root_dir):
        return []
    return list(os.listdir(root_dir))
+
+
def complete_path(curr_dir, last_dir):
    """Return the path to complete that matches the last entered component.

    If the last entered component is ~, expanded path would not
    match, so return all of the available paths.

    :param curr_dir: str
    :param last_dir: str
    :return: str

    """
    if not last_dir or curr_dir.startswith(last_dir):
        return curr_dir
    if last_dir == '~':
        return os.path.join(last_dir, curr_dir)
    # No match: signal "nothing to complete" (implicit None in the original).
    return None
+
+
def parse_path(root_dir):
    """Split path into head and last component for the completer.

    Also return position where last component starts.

    :param root_dir: str path
    :return: tuple of (string, string, int)

    """
    if not root_dir:
        return '', '', 0
    head, tail = os.path.split(root_dir)
    return head, tail, (-len(tail) if tail else 0)
+
+
def suggest_path(root_dir):
    """List all files and subdirectories in a directory.

    With no directory given, offer the filesystem root, the home
    shortcut, and the current/parent directories.

    :param root_dir: string: directory to list
    :return: list

    """
    if not root_dir:
        return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]

    expanded = os.path.expanduser(root_dir) if '~' in root_dir else root_dir
    if os.path.exists(expanded):
        return list_path(expanded)
    # Path being typed doesn't exist yet: complete from its parent.
    head, _ = os.path.split(expanded)
    return list_path(head)
+
+
def dir_path_exists(path):
    """Check if the directory path exists for a given file.

    For example, for a file /home/user/.cache/mycli/log, check if
    /home/user/.cache/mycli exists.

    :param str path: The file path.
    :return: Whether or not the directory path exists.

    """
    parent = os.path.dirname(path)
    return os.path.exists(parent)
+
+
def guess_socket_location():
    """Try to guess the location of the default mysql socket file.

    Walks each existing directory in DEFAULT_SOCKET_DIRS and returns the
    first file named mysql*.sock / mysql*.socket, or None if none found.
    """
    socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS)
    for directory in socket_dirs:
        for r, dirs, files in os.walk(directory, topdown=True):
            for filename in files:
                name, ext = os.path.splitext(filename)
                if name.startswith("mysql") and ext in ('.socket', '.sock'):
                    # First plausible socket wins.
                    return os.path.join(r, filename)
            # Prune the walk: only descend into mysql-related subdirs.
            # In-place mutation of `dirs` works because topdown=True.
            dirs[:] = [d for d in dirs if d.startswith("mysql")]
    # Nothing found in any default location.
    return None
diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py
new file mode 100644
index 0000000..045b00e
--- /dev/null
+++ b/mycli/packages/paramiko_stub/__init__.py
@@ -0,0 +1,28 @@
"""A module to import instead of paramiko when it is not available (to avoid
checking for paramiko all over the place).

When paramiko is first invoked, it simply shuts down mycli, telling
user they either have to install paramiko or should not use SSH
features.

"""


class Paramiko:
    # Any attribute access on the stub (e.g. paramiko.config) triggers
    # this message and terminates the process.
    def __getattr__(self, name):
        import sys
        from textwrap import dedent
        print(dedent("""
            To enable certain SSH features you need to install paramiko:

               pip install paramiko

            It is required for the following configuration options:
                --list-ssh-config
                --ssh-config-host
                --ssh-host
            """))
        sys.exit(1)


# Module-level stand-in; `paramiko.<anything>` exits with the hint above.
paramiko = Paramiko()
diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py
new file mode 100644
index 0000000..e3b383e
--- /dev/null
+++ b/mycli/packages/parseutils.py
@@ -0,0 +1,267 @@
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
# Regexes used by last_word() to grab the trailing token of a line; each
# anchors at end-of-string. All patterns are raw strings so the \w / \s
# classes are not treated as (deprecated) string escapes.
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    'all_punctuations': re.compile(r'([^\s]+)$'),
    }
+
def last_word(text, include='alphanum_underscore'):
    """Return the final word of *text*.

    *include* names one of the ``cleanup_regex`` patterns, which decide
    which characters may belong to a "word".

    >>> last_word('abc def')
    'def'
    >>> last_word('abc ')
    ''
    >>> last_word('')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    """
    # No text, or trailing whitespace, means no word is in progress.
    if not text or text[-1].isspace():
        return ''

    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ''
+
+
+# This code is borrowed from sqlparse example script.
+# https://github.com/andialbrecht/sqlparse (examples/extract_table_names.py)
def is_subselect(parsed):
    """Return True if *parsed* is a grouped token containing a DML keyword."""
    if not parsed.is_group:
        return False
    return any(
        token.ttype is DML and token.value.upper() in
        ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
        for token in parsed.tokens
    )
+
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that follow a FROM/JOIN/INTO/UPDATE/TABLE/COPY
    keyword, for consumption by extract_table_identifiers().

    :param parsed: a single sqlparse statement.
    :param bool stop_at_punctuation: stop at the first punctuation token
        (needed for INSERT statements -- see extract_tables()).
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                # Recurse into sub-selects so their tables surface too.
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # StopIteration. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                return
            else:
                yield item
        elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
                item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
            tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
+
def extract_table_identifiers(token_stream):
    """yields tuples of (schema_name, table_name, table_alias)"""
    for token in token_stream:
        if isinstance(token, IdentifierList):
            for ident in token.get_identifiers():
                # Keywords (such as FROM) are sometimes classified as
                # identifiers but lack the get_real_name() method.
                try:
                    schema = ident.get_parent_name()
                    table = ident.get_real_name()
                except AttributeError:
                    continue
                if table:
                    yield (schema, table, ident.get_alias())
        elif isinstance(token, Identifier):
            table = token.get_real_name()
            schema = token.get_parent_name()
            if table:
                yield (schema, table, token.get_alias())
            else:
                name = token.get_name()
                yield (None, name, token.get_alias() or name)
        elif isinstance(token, Function):
            yield (None, token.get_name(), token.get_name())
+
+# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    Returns a list of (schema, table, alias) tuples

    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return []

    # For INSERT statements, stop at the first punctuation token:
    # in "INSERT INTO abc (col1, col2) VALUES (1, 2)" only abc is a
    # table name; without the early stop col1/col2 would be picked up
    # as tables too.
    is_insert = parsed[0].token_first().value.lower() == 'insert'
    token_stream = extract_from_part(parsed[0], stop_at_punctuation=is_insert)
    return list(extract_table_identifiers(token_stream))
+
def find_prev_keyword(sql):
    """Find the last SQL keyword in an SQL statement.

    Returns the last keyword token, and the text of the query with
    everything after that keyword stripped.
    """
    if not sql.strip():
        return None, ''

    flattened = list(sqlparse.parse(sql)[0].flatten())
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    for tok in reversed(flattened):
        is_candidate = tok.value == '(' or (
            tok.is_keyword and tok.value.upper() not in logical_operators)
        if not is_candidate:
            continue
        # Locate tok within the flattened list rather than via
        # parsed.token_index(tok): tok may be a child token nested in a
        # TokenList, in which case token_index raises ValueError.
        # Minimal example:
        #   p = sqlparse.parse('select * from foo where bar')
        #   t = list(p.flatten())[-3]  # The "Where" token
        #   p.token_index(t)           # Throws ValueError: not in list
        idx = flattened.index(tok)
        # Join the string values of every token up to and including the
        # keyword, producing the query text with the tail removed.
        text = ''.join(t.value for t in flattened[:idx + 1])
        return tok, text

    return None, ''
+
+
def query_starts_with(query, prefixes):
    """Check if the query starts with any item from *prefixes*."""
    normalized_prefixes = [prefix.lower() for prefix in prefixes]
    # Strip comments so a leading comment doesn't hide the first word.
    formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
    if not formatted_sql:
        return False
    return formatted_sql.split()[0] in normalized_prefixes
+
+
def queries_start_with(queries, prefixes):
    """Check if any queries start with any item from *prefixes*."""
    return any(
        query_starts_with(query, prefixes)
        for query in sqlparse.split(queries)
        if query
    )
+
+
def query_has_where_clause(query):
    """Check if the query contains a where-clause."""
    for statement in sqlparse.parse(query):
        for token in statement:
            if isinstance(token, sqlparse.sql.Where):
                return True
    return False
+
+
def is_destructive(queries):
    """Return True if any statement in *queries* is destructive.

    Destructive means DROP/SHUTDOWN/DELETE/TRUNCATE/ALTER, or an UPDATE
    without a WHERE clause.
    """
    destructive_keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
    for query in sqlparse.split(queries):
        if not query:
            continue
        if query_starts_with(query, destructive_keywords):
            return True
        if (query_starts_with(query, ['update'])
                and not query_has_where_clause(query)):
            return True
    return False
+
+
def is_open_quote(sql):
    """Returns true if the query contains an unclosed quote."""

    # parsed can contain one or more semi-colon separated commands
    parsed = sqlparse.parse(sql)
    # NOTE(review): _parsed_is_open_quote is not defined in this module as
    # shown here -- presumably it is defined elsewhere; confirm it is in
    # scope, otherwise this raises NameError at call time.
    return any(_parsed_is_open_quote(p) for p in parsed)
+
+
if __name__ == '__main__':
    # Quick manual smoke test for extract_tables().
    sql = 'select * from (select t. from tabl t'
    print(extract_tables(sql))
+
+
def is_dropping_database(queries, dbname):
    """Determine if the query is dropping a specific database.

    :param str queries: SQL text (may contain several statements).
    :param str dbname: database name to look for; None disables the check.
    :return: True if *queries* drops database *dbname* (and does not
        subsequently recreate it), otherwise False.
    """
    if dbname is None:
        return False

    def normalize_db_name(db):
        # Database names may be quoted with backticks or double quotes.
        return db.lower().strip('`"')

    dbname = normalize_db_name(dbname)

    result = False
    for query in sqlparse.parse(queries):
        keywords = [t for t in query.tokens if t.is_keyword]
        if len(keywords) < 2:
            continue
        if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in (
            "database",
            "schema",
        ):
            database_token = next(
                (t for t in query.tokens if isinstance(t, Identifier)), None
            )
            if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
                # A later CREATE of the same database cancels the drop.
                result = keywords[0].normalized == "DROP"
        else:
            return result
    # Bug fix: previously execution fell off the end (returning None)
    # whenever the final statement matched the DROP/CREATE branch;
    # always return the accumulated result instead.
    return result
diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py
new file mode 100644
index 0000000..fb1e431
--- /dev/null
+++ b/mycli/packages/prompt_utils.py
@@ -0,0 +1,54 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
class ConfirmBoolParamType(click.ParamType):
    """Click parameter type that accepts yes/y/no/n answers as booleans."""
    name = 'confirmation'

    def convert(self, value, param, ctx):
        """Convert user input to True/False, failing on anything else."""
        if isinstance(value, bool):
            return bool(value)
        answer = value.lower()
        if answer in ('yes', 'y'):
            return True
        if answer in ('no', 'n'):
            return False
        self.fail('%s is not a valid boolean' % value, param, ctx)

    def __repr__(self):
        return 'BOOL'


# Shared instance used as the `type=` for destructive-query prompts.
BOOLEAN_TYPE = ConfirmBoolParamType()
+
+
def confirm_destructive_query(queries):
    """Check if the query is destructive and prompts the user to confirm.

    Returns:
    * None if the query is non-destructive or we can't prompt the user.
    * True if the query is destructive and the user wants to proceed.
    * False if the query is destructive and the user doesn't want to proceed.

    """
    # Only prompt when stdin is a TTY; otherwise we cannot ask.
    if is_destructive(queries) and sys.stdin.isatty():
        prompt_text = ("You're about to run a destructive command.\n"
                       "Do you want to proceed? (y/n)")
        return prompt(prompt_text, type=BOOLEAN_TYPE)
+
+
def confirm(*args, **kwargs):
    """Prompt for confirmation (yes/no) and handle any abort exceptions."""
    try:
        response = click.confirm(*args, **kwargs)
    except click.Abort:
        return False
    return response
+
+
def prompt(*args, **kwargs):
    """Prompt the user for input and handle any abort exceptions."""
    try:
        response = click.prompt(*args, **kwargs)
    except click.Abort:
        return False
    return response
diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py
new file mode 100644
index 0000000..92bcca6
--- /dev/null
+++ b/mycli/packages/special/__init__.py
@@ -0,0 +1,10 @@
# Names re-exported from submodules via the @export decorator below.
__all__ = []

def export(defn):
    """Decorator to explicitly mark functions that are exposed in a lib."""
    globals()[defn.__name__] = defn
    __all__.append(defn.__name__)
    return defn

# These imports must run after export() is defined: both submodules use
# @export at import time to publish their public functions here.
from . import dbcommands
from . import iocommands
diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py
new file mode 100644
index 0000000..ed90e4c
--- /dev/null
+++ b/mycli/packages/special/dbcommands.py
@@ -0,0 +1,157 @@
+import logging
+import os
+import platform
+from mycli import __version__
+from mycli.packages.special import iocommands
+from mycli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY
+from pymysql import ProgrammingError
+
+log = logging.getLogger(__name__)
+
+
@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
                 arg_type=PARSED_QUERY, case_sensitive=True)
def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
    """List all tables, or describe the columns of table *arg*.

    With *verbose* (``\\dt+``) and a table name, the status field holds
    the table's CREATE statement.
    Returns a list of one (title, rows, headers, status) tuple.
    """
    query = 'SHOW FIELDS FROM {0}'.format(arg) if arg else 'SHOW TABLES'
    log.debug(query)
    cur.execute(query)
    tables = cur.fetchall()
    if not cur.description:
        return [(None, None, None, '')]
    headers = [column[0] for column in cur.description]

    status = ''
    if verbose and arg:
        query = 'SHOW CREATE TABLE {0}'.format(arg)
        log.debug(query)
        cur.execute(query)
        status = cur.fetchone()[1]

    return [(None, tables, headers, status)]
+
@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
def list_databases(cur, **_):
    """List all databases on the server.

    Returns a list of one (title, rows, headers, status) tuple.
    """
    query = 'SHOW DATABASES'
    log.debug(query)
    cur.execute(query)
    if not cur.description:
        return [(None, None, None, '')]
    headers = [column[0] for column in cur.description]
    return [(None, cur, headers, '')]
+
@special_command('status', '\\s', 'Get status information from the server.',
                 arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
def status(cur, **_):
    """Build a mysql-client-style status report from server variables.

    Returns a list of one (title, rows, headers, status) tuple whose
    rows are (label, value) pairs.
    """
    query = 'SHOW GLOBAL STATUS;'
    log.debug(query)
    try:
        cur.execute(query)
    except ProgrammingError:
        # Fallback in case query fail, as it does with Mysql 4
        query = 'SHOW STATUS;'
        log.debug(query)
        cur.execute(query)
    status = dict(cur.fetchall())

    query = 'SHOW GLOBAL VARIABLES;'
    log.debug(query)
    cur.execute(query)
    variables = dict(cur.fetchall())

    # prepare in case keys are bytes, as with Python 3 and Mysql 4
    if (isinstance(list(variables)[0], bytes) and
            isinstance(list(status)[0], bytes)):
        variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
                     in variables.items()}
        status = {k.decode('utf-8'): v.decode('utf-8') for k, v
                  in status.items()}

    # Create output buffers.
    title = []
    output = []
    footer = []

    title.append('--------------')

    # Output the mycli client information.
    implementation = platform.python_implementation()
    version = platform.python_version()
    client_info = []
    client_info.append('mycli {0},'.format(__version__))
    client_info.append('running on {0} {1}'.format(implementation, version))
    title.append(' '.join(client_info) + '\n')

    # Build the output that will be displayed as a table.
    output.append(('Connection id:', cur.connection.thread_id()))

    query = 'SELECT DATABASE(), USER();'
    log.debug(query)
    cur.execute(query)
    db, user = cur.fetchone()
    if db is None:
        db = ''

    output.append(('Current database:', db))
    output.append(('Current user:', user))

    if iocommands.is_pager_enabled():
        if 'PAGER' in os.environ:
            pager = os.environ['PAGER']
        else:
            pager = 'System default'
    else:
        pager = 'stdout'
    output.append(('Current pager:', pager))

    output.append(('Server version:', '{0} {1}'.format(
        variables['version'], variables['version_comment'])))
    output.append(('Protocol version:', variables['protocol_version']))

    # host_info mentions the socket path for unix-socket connections.
    if 'unix' in cur.connection.host_info.lower():
        host_info = cur.connection.host_info
    else:
        host_info = '{0} via TCP/IP'.format(cur.connection.host)

    output.append(('Connection:', host_info))

    query = ('SELECT @@character_set_server, @@character_set_database, '
             '@@character_set_client, @@character_set_connection LIMIT 1;')
    log.debug(query)
    cur.execute(query)
    charset = cur.fetchone()
    output.append(('Server characterset:', charset[0]))
    output.append(('Db characterset:', charset[1]))
    output.append(('Client characterset:', charset[2]))
    output.append(('Conn. characterset:', charset[3]))

    if 'TCP/IP' in host_info:
        output.append(('TCP port:', cur.connection.port))
    else:
        output.append(('UNIX socket:', variables['socket']))

    output.append(('Uptime:', format_uptime(status['Uptime'])))

    # Print the current server statistics.
    stats = []
    stats.append('Connections: {0}'.format(status['Threads_connected']))
    if 'Queries' in status:
        stats.append('Queries: {0}'.format(status['Queries']))
    stats.append('Slow queries: {0}'.format(status['Slow_queries']))
    stats.append('Opens: {0}'.format(status['Opened_tables']))
    stats.append('Flush tables: {0}'.format(status['Flush_commands']))
    stats.append('Open tables: {0}'.format(status['Open_tables']))
    if 'Queries' in status:
        queries_per_second = int(status['Queries']) / int(status['Uptime'])
        stats.append('Queries per second avg: {:.3f}'.format(
            queries_per_second))
    stats = ' '.join(stats)
    footer.append('\n' + stats)

    footer.append('--------------')
    return [('\n'.join(title), output, '', '\n'.join(footer))]
diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py
new file mode 100644
index 0000000..994b134
--- /dev/null
+++ b/mycli/packages/special/delimitercommand.py
@@ -0,0 +1,80 @@
+import re
+import sqlparse
+
+
class DelimiterCommand(object):
    """Tracks the session's statement delimiter and splits input by it."""

    def __init__(self):
        # The delimiter currently in effect; mysql's default is ';'.
        self._delimiter = ';'

    def _split(self, sql):
        """Temporary workaround until sqlparse.split() learns about custom
        delimiters."""

        placeholder = "\ufffc"  # unicode object replacement character

        if self._delimiter == ';':
            return sqlparse.split(sql)

        # We must find a string that original sql does not contain.
        # Most likely, our placeholder is enough, but if not, keep looking
        while placeholder in sql:
            placeholder += placeholder[0]
        # Swap real ';' out of the way, map the custom delimiter to ';'
        # so sqlparse can split, then reverse both substitutions.
        sql = sql.replace(';', placeholder)
        sql = sql.replace(self._delimiter, ';')

        split = sqlparse.split(sql)

        return [
            stmt.replace(';', self._delimiter).replace(placeholder, ';')
            for stmt in split
        ]

    def queries_iter(self, input):
        """Iterate over queries in the input string."""

        queries = self._split(input)
        while queries:
            # NOTE: queries is consumed via pop(0) inside the for loop so
            # that the remainder can be re-split if a DELIMITER command
            # changes the delimiter mid-stream.
            for sql in queries:
                delimiter = self._delimiter
                sql = queries.pop(0)
                if sql.endswith(delimiter):
                    trailing_delimiter = True
                    sql = sql.strip(delimiter)
                else:
                    trailing_delimiter = False

                yield sql

                # if the delimiter was changed by the last command,
                # re-split everything, and if we previously stripped
                # the delimiter, append it to the end
                if self._delimiter != delimiter:
                    combined_statement = ' '.join([sql] + queries)
                    if trailing_delimiter:
                        combined_statement += delimiter
                    queries = self._split(combined_statement)[1:]

    def set(self, arg, **_):
        """Change delimiter.

        Since `arg` is everything that follows the DELIMITER token
        after sqlparse (it may include other statements separated by
        the new delimiter), we want to set the delimiter to the first
        word of it.

        """
        match = arg and re.search(r'[^\s]+', arg)
        if not match:
            message = 'Missing required argument, delimiter'
            return [(None, None, None, message)]

        delimiter = match.group()
        # 'delimiter' itself would make future DELIMITER commands unparseable.
        if delimiter.lower() == 'delimiter':
            return [(None, None, None, 'Invalid delimiter "delimiter"')]

        self._delimiter = delimiter
        return [(None, None, None, "Changed delimiter to {}".format(delimiter))]

    @property
    def current(self):
        # The delimiter currently in effect.
        return self._delimiter
diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..0b91400
--- /dev/null
+++ b/mycli/packages/special/favoritequeries.py
@@ -0,0 +1,63 @@
class FavoriteQueries(object):
    """Named queries stored in the user's config, for the \\f commands."""

    # Config-file section that holds the saved queries.
    section_name = 'favorite_queries'

    usage = '''
Favorite Queries are a way to save frequently used queries
with a short name.
Examples:

    # Save a new favorite query.
    > \\fs simple select * from abc where a is not Null;

    # List all favorite queries.
    > \\f
    ╒════════╤═══════════════════════════════════════╕
    │ Name   │ Query                                 │
    ╞════════╪═══════════════════════════════════════╡
    │ simple │ SELECT * FROM abc where a is not NULL │
    ╘════════╧═══════════════════════════════════════╛

    # Run a favorite query.
    > \\f simple
    ╒════════╤════════╕
    │ a      │ b      │
    ╞════════╪════════╡
    │ 日本語 │ 日本語 │
    ╘════════╧════════╛

    # Delete a favorite query.
    > \\fd simple
    simple: Deleted
'''

    # Class-level variable, for convenience to use as a singleton.
    instance = None

    def __init__(self, config):
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Alternate constructor from a config object."""
        return FavoriteQueries(config)

    def list(self):
        """Return the names of all saved queries."""
        return self.config.get(self.section_name, [])

    def get(self, name):
        """Return the query saved under *name*, or None."""
        section = self.config.get(self.section_name, {})
        return section.get(name, None)

    def save(self, name, query):
        """Persist *query* under *name* in the config file."""
        self.config.encoding = 'utf-8'
        if self.section_name not in self.config:
            self.config[self.section_name] = {}
        self.config[self.section_name][name] = query
        self.config.write()

    def delete(self, name):
        """Remove the query saved under *name*; return a status message."""
        try:
            del self.config[self.section_name][name]
        except KeyError:
            return '%s: Not Found.' % name
        self.config.write()
        return '%s: Deleted' % name
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
new file mode 100644
index 0000000..11dca8d
--- /dev/null
+++ b/mycli/packages/special/iocommands.py
@@ -0,0 +1,453 @@
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .delimitercommand import DelimiterCommand
+from .utils import handle_cd_command
+from mycli.packages.prompt_utils import confirm_destructive_query
+
# Whether query timing (\timing) is enabled.
TIMING_ENABLED = False
# Whether vertical/expanded output is active.
use_expanded_output = False
# Whether results are routed through the pager.
PAGER_ENABLED = True
# Open file object for `tee` output, or None.
tee_file = None
# open() kwargs for the armed \once output file, or None.
once_file = None
# True once the \once file has received output.
written_to_once_file = False
# Tracks the session's current statement delimiter.
delimiter_command = DelimiterCommand()
+
+
@export
def set_timing_enabled(val):
    """Enable or disable timing of executed queries."""
    global TIMING_ENABLED
    TIMING_ENABLED = val
+
@export
def set_pager_enabled(val):
    """Enable or disable paging of query results."""
    global PAGER_ENABLED
    PAGER_ENABLED = val
+
+
@export
def is_pager_enabled():
    """Return whether query results are currently paged."""
    return PAGER_ENABLED
+
@export
@special_command('pager', '\\P [command]',
                 'Set PAGER. Print the query results via PAGER.',
                 arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
def set_pager(arg, **_):
    """Set the PAGER command (or report the current one) and enable paging."""
    if arg:
        os.environ['PAGER'] = arg
        msg = 'PAGER set to %s.' % arg
    elif 'PAGER' in os.environ:
        msg = 'PAGER set to %s.' % os.environ['PAGER']
    else:
        # This uses click's default per echo_via_pager.
        msg = 'Pager enabled.'
    # Every branch (re-)enables paging.
    set_pager_enabled(True)

    return [(None, None, None, msg)]
+
@export
@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
                 arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
def disable_pager():
    """Turn paging off; results are printed directly to stdout."""
    set_pager_enabled(False)
    return [(None, None, None, 'Pager disabled.')]
+
@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
def toggle_timing():
    """Flip the global timing flag and report its new state."""
    global TIMING_ENABLED
    TIMING_ENABLED = not TIMING_ENABLED
    message = 'Timing is on.' if TIMING_ENABLED else 'Timing is off.'
    return [(None, None, None, message)]
+
@export
def is_timing_enabled():
    """Return whether query timing is currently enabled."""
    return TIMING_ENABLED
+
@export
def set_expanded_output(val):
    """Enable or disable expanded (vertical) output."""
    global use_expanded_output
    use_expanded_output = val
+
@export
def is_expanded_output():
    """Return whether expanded (vertical) output is active."""
    return use_expanded_output
+
+_logger = logging.getLogger(__name__)
+
@export
def editor_command(command):
    """
    Is this an external editor command?
    :param command: string
    """
    # Both `\e filename` and `SELECT * FROM \e` are valid, so look for
    # the marker at either end.
    stripped = command.strip()
    return stripped.endswith('\\e') or stripped.startswith('\\e')
+
@export
def get_filename(sql):
    """Return the filename argument of a leading ``\\e`` command, if any."""
    if not sql.strip().startswith('\\e'):
        return None
    _command, _, filename = sql.partition(' ')
    return filename.strip() or None
+
+
@export
def get_editor_query(sql):
    """Get the query part of an editor command.

    Strips the leading/trailing ``\\e`` markers from *sql*.
    """
    sql = sql.strip()

    # The reason we can't simply do .strip('\e') is that it strips characters,
    # not a substring. So it'll strip "e" in the end of the sql also!
    # Ex: "select * from style\e" -> "select * from styl".
    pattern = re.compile('(^\\\e|\\\e$)')
    while pattern.search(sql):
        sql = pattern.sub('', sql)

    return sql
+
+
@export
def open_external_editor(filename=None, sql=None):
    """Open external editor, wait for the user to type in their query, return
    the query.

    :param filename: optional file whose contents are read back as the
        query after the editor closes.
    :param sql: optional partial query used to pre-fill the buffer.
    :return: (query, message) tuple; message is an error string or None.

    """

    message = None
    # Only the first whitespace-separated token of filename is used.
    filename = filename.strip().split(' ', 1)[0] if filename else None

    sql = sql or ''
    MARKER = '# Type your query above this line.\n'

    # Populate the editor buffer with the partial sql (if available) and a
    # placeholder comment.
    query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
                       filename=filename, extension='.sql')

    if filename:
        try:
            with open(filename) as f:
                query = f.read()
        except IOError:
            message = 'Error reading file: %s.' % filename

    if query is not None:
        # Drop the marker comment and trailing newlines.
        query = query.split(MARKER, 1)[0].rstrip('\n')
    else:
        # Don't return None for the caller to deal with.
        # Empty string is ok.
        query = sql

    return (query, message)
+
+
@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
def execute_favorite_query(cur, arg, **_):
    """List favorite queries, or execute one by name.

    Yields (title, rows, headers, status) tuples.
    """
    if arg == '':
        for result in list_favorite_queries():
            yield result
        # Bug fix: without this return, the code below would go on to look
        # up the empty name and also yield a spurious
        # "No favorite query: " message after the listing.
        return

    # Parse out favorite name and optional substitution parameters.
    name, _, arg_str = arg.partition(' ')
    args = shlex.split(arg_str)

    query = FavoriteQueries.instance.get(name)
    if query is None:
        message = "No favorite query: %s" % (name)
        yield (None, None, None, message)
        return

    query, arg_error = subst_favorite_query_args(query, args)
    if arg_error:
        yield (None, None, None, arg_error)
        return

    for sql in sqlparse.split(query):
        sql = sql.rstrip(';')
        title = '> %s' % (sql)
        cur.execute(sql)
        if cur.description:
            headers = [x[0] for x in cur.description]
            yield (title, cur, headers, None)
        else:
            yield (title, None, None, None)
+
def list_favorite_queries():
    """List of all favorite queries.
    Returns (title, rows, headers, status)"""
    headers = ["Name", "Query"]
    favorites = FavoriteQueries.instance
    rows = [(name, favorites.get(name)) for name in favorites.list()]

    if rows:
        status = ''
    else:
        status = '\nNo favorite queries found.' + favorites.usage
    return [('', rows, headers, status)]
+
+
def subst_favorite_query_args(query, args):
    """Replace positional parameters ($1...$N) in *query*.

    :param str query: The saved query template.
    :param list args: Positional values to substitute.
    :return: ``[substituted_query, None]`` on success, or
        ``[None, error_message]`` when substitution fails.
    """
    for idx, val in enumerate(args):
        subst_var = '$' + str(idx + 1)
        if subst_var not in query:
            return [None, 'query does not have substitution parameter ' + subst_var + ':\n  ' + query]

        query = query.replace(subst_var, val)

    # Raw string so \d is a regex digit class, not a string escape.
    match = re.search(r'\$\d+', query)
    if match:
        return [None, 'missing substitution for ' + match.group(0) + ' in query:\n  ' + query]

    return [query, None]
+
@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
def save_favorite_query(arg, **_):
    """Save a new favorite query.
    Returns (title, rows, headers, status)"""
    usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
    if not arg:
        return [(None, None, None, usage)]

    name, _, query = arg.partition(' ')

    if name and query:
        FavoriteQueries.instance.save(name, query)
        return [(None, None, None, "Saved.")]

    # Name or query missing: re-show the usage with an error appended.
    return [(None, None, None,
             usage + 'Err: Both name and query are required.')]
+
+
@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
def delete_favorite_query(arg, **_):
    """Delete an existing favorite query."""
    usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
    if not arg:
        return [(None, None, None, usage)]

    result = FavoriteQueries.instance.delete(arg)
    return [(None, None, None, result)]
+
+
@special_command('system', 'system [command]',
                 'Execute a system shell command.')
def execute_system_command(arg, **_):
    """Execute a system shell command.

    ``cd`` is handled in-process (a child process could not change our
    working directory); everything else runs via subprocess.

    :return: list of one (title, rows, headers, status) tuple.
    """
    # Bug fix: the command description above read "commmand".
    usage = "Syntax: system [command].\n"

    if not arg:
        return [(None, None, None, usage)]

    try:
        command = arg.strip()
        if command.startswith('cd'):
            ok, error_message = handle_cd_command(arg)
            if not ok:
                return [(None, None, None, error_message)]
            return [(None, None, None, '')]

        args = arg.split(' ')
        process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = process.communicate()
        # Prefer stderr when the command produced any.
        response = output if not error else error

        # Python 3 returns bytes. This needs to be decoded to a string.
        if isinstance(response, bytes):
            encoding = locale.getpreferredencoding(False)
            response = response.decode(encoding)

        return [(None, None, None, response)]
    except OSError as e:
        return [(None, None, None, 'OSError: %s' % e.strerror)]
+
+
def parseargfile(arg):
    """Parse a tee/once argument into kwargs for open().

    A leading ``-o `` selects overwrite ('w') mode; otherwise the file
    is appended to ('a'). Raises TypeError if no filename is given.
    """
    if arg.startswith('-o '):
        mode, filename = 'w', arg[3:]
    else:
        mode, filename = 'a', arg

    if not filename:
        raise TypeError('You must provide a filename.')

    return {'file': os.path.expanduser(filename), 'mode': mode}
+
+
@special_command('tee', 'tee [-o] filename',
                 'Append all results to an output file (overwrite using -o).')
def set_tee(arg, **_):
    """Open the tee output file; later results are copied into it."""
    global tee_file

    try:
        tee_file = open(**parseargfile(arg))
    except (IOError, OSError) as e:
        raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))

    return [(None, None, None, "")]
+
@export
def close_tee():
    """Close and clear the tee output file, if one is open."""
    global tee_file
    if tee_file:
        tee_file.close()
        tee_file = None
+
+
@special_command('notee', 'notee', 'Stop writing results to an output file.')
def no_tee(arg, **_):
    """Stop tee-ing results to a file."""
    close_tee()
    return [(None, None, None, "")]
+
@export
def write_tee(output):
    """Append *output* plus a newline to the tee file, if one is open."""
    global tee_file
    if tee_file:
        click.echo(output, file=tee_file, nl=False)
        click.echo(u'\n', file=tee_file, nl=False)
        tee_file.flush()
+
+
@special_command('\\once', '\\o [-o] filename',
                 'Append next result to an output file (overwrite using -o).',
                 aliases=('\\o', ))
def set_once(arg, **_):
    """Arm one-shot output: the next result is written to the given file."""
    global once_file, written_to_once_file

    # Store the open() kwargs; the file is opened lazily in write_once().
    once_file = parseargfile(arg)
    written_to_once_file = False

    return [(None, None, None, "")]
+
+
@export
def write_once(output):
    """Write *output* to the armed once-file, opening it on demand."""
    global once_file, written_to_once_file
    if output and once_file:
        try:
            f = open(**once_file)
        except (IOError, OSError) as e:
            # Disarm so the failure isn't retried on every result.
            once_file = None
            raise OSError("Cannot write to file '{}': {}".format(
                e.filename, e.strerror))
        with f:
            click.echo(output, file=f, nl=False)
            click.echo(u"\n", file=f, nl=False)
        written_to_once_file = True
+ written_to_once_file = True
+
+
@export
def unset_once_if_written():
    """Unset the once file, if it has been written to."""
    global once_file
    # Leaves the once file armed if nothing has been written yet.
    if written_to_once_file:
        once_file = None
+
+
@special_command(
    'watch',
    'watch [seconds] [-c] query',
    'Executes the query every [seconds] seconds (by default 5).'
)
def watch_query(arg, **kwargs):
    """Repeatedly run a query, yielding (title, rows, headers, status)
    tuples until interrupted with Ctrl-C.

    Leading arguments (in any order) may set the interval in seconds
    and the ``-c`` clear-screen flag; the rest is the query.
    """
    usage = """Syntax: watch [seconds] [-c] query.
    * seconds: The interval at the query will be repeated, in seconds.
               By default 5.
    * -c: Clears the screen between every iteration.
"""
    if not arg:
        yield (None, None, None, usage)
        return
    seconds = 5
    clear_screen = False
    statement = None
    # Consume leading option tokens until the query proper begins.
    while statement is None:
        arg = arg.strip()
        if not arg:
            # Oops, we parsed all the arguments without finding a statement
            yield (None, None, None, usage)
            return
        (current_arg, _, arg) = arg.partition(' ')
        try:
            seconds = float(current_arg)
            continue
        except ValueError:
            pass
        if current_arg == '-c':
            clear_screen = True
            continue
        statement = '{0!s} {1!s}'.format(current_arg, arg)
    # Confirm destructive statements once, before looping on them.
    destructive_prompt = confirm_destructive_query(statement)
    if destructive_prompt is False:
        click.secho("Wise choice!")
        return
    elif destructive_prompt is True:
        click.secho("Your call!")
    cur = kwargs['cur']
    sql_list = [
        (sql.rstrip(';'), "> {0!s}".format(sql))
        for sql in sqlparse.split(statement)
    ]
    old_pager_enabled = is_pager_enabled()
    while True:
        if clear_screen:
            click.clear()
        try:
            # Somewhere in the code the pager its activated after every yield,
            # so we disable it in every iteration
            set_pager_enabled(False)
            for (sql, title) in sql_list:
                cur.execute(sql)
                if cur.description:
                    headers = [x[0] for x in cur.description]
                    yield (title, cur, headers, None)
                else:
                    yield (title, None, None, None)
            sleep(seconds)
        except KeyboardInterrupt:
            # This prints the Ctrl-C character in its own line, which prevents
            # to print a line with the cursor positioned behind the prompt
            click.secho("", nl=True)
            return
        finally:
            set_pager_enabled(old_pager_enabled)
+
+
@export
@special_command('delimiter', None, 'Change SQL delimiter.')
def set_delimiter(arg, **_):
    """Change the statement delimiter to the first word of *arg*."""
    return delimiter_command.set(arg)
+
+
@export
def get_current_delimiter():
    """Return the statement delimiter currently in effect."""
    return delimiter_command.current
+
+
@export
def split_queries(input):
    """Yield each statement in *input*, split on the current delimiter."""
    for query in delimiter_command.queries_iter(input):
        yield query
diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py
new file mode 100644
index 0000000..dddba66
--- /dev/null
+++ b/mycli/packages/special/main.py
@@ -0,0 +1,118 @@
+import logging
+from collections import namedtuple
+
+from . import export
+
log = logging.getLogger(__name__)

# Handler argument styles: called with no arguments, with parsed
# (cur, arg, verbose) keywords, or with the raw SQL text.
NO_QUERY = 0
PARSED_QUERY = 1
RAW_QUERY = 2

# One registry record per special command; `hidden` entries are skipped by
# the help listing and `case_sensitive` controls lookup behavior.
SpecialCommand = namedtuple('SpecialCommand',
        ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
         'case_sensitive'])

# Maps command name (lower-cased unless case-sensitive) -> SpecialCommand.
COMMANDS = {}
+
@export
class CommandNotFound(Exception):
    """Raised when input does not match any registered special command."""
+
@export
def parse_special_command(sql):
    """Split *sql* into ``(command, verbose, argument)``.

    A ``+`` anywhere in the command word requests verbose output and is
    removed from the returned command name.
    """
    head, _, tail = sql.partition(' ')
    is_verbose = '+' in head
    name = head.strip().replace('+', '')
    return (name, is_verbose, tail.strip())
+
@export
def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
                    hidden=False, case_sensitive=False, aliases=()):
    """Decorator that registers the wrapped callable as a special command."""
    def decorate(func):
        register_special_command(func, command, shortcut, description,
                                 arg_type, hidden, case_sensitive, aliases)
        return func
    return decorate
+
@export
def register_special_command(handler, command, shortcut, description,
        arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
    """Add *handler* to COMMANDS under *command* and each alias.

    Names are lower-cased unless the command is case-sensitive.  Aliases
    are stored as hidden entries so the help listing shows each command
    only once.
    """
    key = command if case_sensitive else command.lower()
    COMMANDS[key] = SpecialCommand(handler, command, shortcut, description,
                                   arg_type, hidden, case_sensitive)
    for alias in aliases:
        alias_key = alias if case_sensitive else alias.lower()
        COMMANDS[alias_key] = SpecialCommand(handler, command, shortcut,
                                             description, arg_type,
                                             case_sensitive=case_sensitive,
                                             hidden=True)
+
@export
def execute(cur, sql):
    """Execute a special command and return the results.

    :param cur: database cursor passed through to the handler
    :param sql: full input line, e.g. ``'\\d+ mytable'``
    :raises CommandNotFound: if no matching command is registered
    """
    command, verbose, arg = parse_special_command(sql)

    # Unknown under both the exact and the case-folded name: not special.
    if (command not in COMMANDS) and (command.lower() not in COMMANDS):
        raise CommandNotFound

    try:
        special_cmd = COMMANDS[command]
    except KeyError:
        # Exact lookup failed; fall back to the case-insensitive entry,
        # but reject it if that entry demands exact case.
        special_cmd = COMMANDS[command.lower()]
        if special_cmd.case_sensitive:
            raise CommandNotFound('Command not found: %s' % command)

    # "help <SQL KEYWORD> is a special case. We want built-in help, not
    # mycli help here.
    if command == 'help' and arg:
        return show_keyword_help(cur=cur, arg=arg)

    # Dispatch according to the argument style the handler registered.
    if special_cmd.arg_type == NO_QUERY:
        return special_cmd.handler()
    elif special_cmd.arg_type == PARSED_QUERY:
        return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
    elif special_cmd.arg_type == RAW_QUERY:
        return special_cmd.handler(cur=cur, query=sql)
+
@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
def show_help():
    """Build the result rows listing every non-hidden special command."""
    headers = ['Command', 'Shortcut', 'Description']
    rows = [(cmd.command, cmd.shortcut, cmd.description)
            for _, cmd in sorted(COMMANDS.items())
            if not cmd.hidden]
    return [(None, rows, headers, None)]
+
def show_keyword_help(cur, arg):
    """Delegate to the server's built-in ``help '<keyword>'`` command.

    :param cur: database cursor
    :param arg: keyword to look up; surrounding quotes are stripped
    :return: list with a single (title, cur, headers, status) tuple
    """
    keyword = arg.strip('"').strip("'")
    query = "help '{0}'".format(keyword)
    log.debug(query)
    cur.execute(query)
    if cur.description and cur.rowcount > 0:
        return [(None, cur, [col[0] for col in cur.description], '')]
    return [(None, None, None, 'No help found for {0}.'.format(keyword))]
+
+
@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
def quit(*_args):
    """Terminate the REPL by signalling end-of-input."""
    raise EOFError
+
+
@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
                 arg_type=NO_QUERY, case_sensitive=True)
@special_command('\\G', '\\G', 'Display current query results vertically.',
                 arg_type=NO_QUERY, case_sensitive=True)
def stub():
    """Placeholder handler: registers \\e and \\G so they show up in help;
    invoking it directly is an error."""
    raise NotImplementedError
diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py
new file mode 100644
index 0000000..ef96093
--- /dev/null
+++ b/mycli/packages/special/utils.py
@@ -0,0 +1,46 @@
+import os
+import subprocess
+
def handle_cd_command(arg):
    """Handle a ``cd`` shell command via :func:`os.chdir`.

    :param arg: full shell command text, e.g. ``'cd /tmp'``
    :return: ``(success, message)`` where *message* is an error string or
        ``None`` on success
    """
    CD_CMD = 'cd'
    parts = arg.split(CD_CMD + ' ')
    target = parts[-1] if len(parts) > 1 else None
    if not target:
        return False, "No folder name was provided."
    try:
        os.chdir(target)
        # Echo the new working directory, mirroring shell behavior.
        subprocess.call(['pwd'])
        return True, None
    except OSError as exc:
        return False, exc.strerror
+
def format_uptime(uptime_in_seconds):
    """Format a number of seconds as a human-readable uptime string.

    :param uptime_in_seconds: server uptime in seconds (int or numeric str)
    :returns: string such as ``'15 hours 48 min 12 sec'``

    >>> uptime = format_uptime('56892')
    >>> print(uptime)
    15 hours 48 min 12 sec
    """
    minutes, seconds = divmod(int(uptime_in_seconds), 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)

    parts = []
    for amount, unit in ((days, 'days'), (hours, 'hours'),
                         (minutes, 'min'), (seconds, 'sec')):
        if amount == 0 and not parts:
            # Skip leading zero units, e.g. no "0 days 0 hours 1 min".
            continue
        if amount == 1 and unit.endswith('s'):
            # Singular value: drop the trailing "s".
            unit = unit[:-1]
        parts.append('{0} {1}'.format(amount, unit))

    return ' '.join(parts)
diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/tabular_output/__init__.py
diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py
new file mode 100644
index 0000000..730e633
--- /dev/null
+++ b/mycli/packages/tabular_output/sql_format.py
@@ -0,0 +1,63 @@
+"""Format adapter for sql."""
+
+from cli_helpers.utils import filter_dict_by_key
+from mycli.packages.parseutils import extract_tables
+
# Output format names this adapter provides; 'sql-update-N' uses the first
# N columns as the WHERE key.
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
                     'sql-update-2', )

# No row preprocessing is needed before rendering SQL statements.
preprocessors = ()
+
+
def escape_for_sql_statement(value):
    """Render *value* as a SQL literal.

    Bytes become a hex literal (``X'..'``); anything else is escaped by
    the live connection so quoting matches the server's settings.
    """
    if isinstance(value, bytes):
        return "X'{}'".format(value.hex())
    return formatter.mycli.sqlexecute.conn.escape(value)
+
+
def adapter(data, headers, table_format=None, **kwargs):
    """Yield lines of INSERT or UPDATE statements for the current result set.

    :param data: iterable of result rows
    :param headers: column names for the rows
    :param table_format: one of `supported_formats`; 'sql-update-N' uses
        the first N columns as the WHERE key (default 1)
    """
    # Derive the target table from the query that produced the results;
    # fall back to `DUAL` when no table can be extracted.
    tables = extract_tables(formatter.query)
    if len(tables) > 0:
        table = tables[0]
        if table[0]:
            table_name = "{}.{}".format(*table[:2])
        else:
            table_name = table[1]
    else:
        table_name = "`DUAL`"
    if table_format == 'sql-insert':
        h = "`, `".join(headers)
        yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
        prefix = "  "
        for d in data:
            # The loop index was unused here; iterate values directly.
            values = ", ".join(escape_for_sql_statement(v) for v in d)
            yield "{}({})".format(prefix, values)
            if prefix == "  ":
                prefix = ", "
        yield ";"
    elif table_format.startswith('sql-update'):
        # The formats are mutually exclusive, so elif makes that explicit.
        s = table_format.split('-')
        # Number of leading key columns used in the WHERE clause.
        keys = 1
        if len(s) > 2:
            keys = int(s[-1])
        for d in data:
            yield "UPDATE {} SET".format(table_name)
            prefix = "  "
            for i, v in enumerate(d[keys:], keys):
                yield "{}`{}` = {}".format(prefix, headers[i],
                                           escape_for_sql_statement(v))
                if prefix == "  ":
                    prefix = ", "
            f = "`{}` = {}"
            where = (f.format(headers[i], escape_for_sql_statement(d[i]))
                     for i in range(keys))
            yield "WHERE {};".format(" AND ".join(where))
+
+
def register_new_formatter(TabularOutputFormatter):
    """Register every sql-* output format with the cli_helpers formatter.

    The formatter is also stored globally so `adapter` and
    `escape_for_sql_statement` can reach the active query and connection.
    """
    global formatter
    formatter = TabularOutputFormatter
    for fmt in supported_formats:
        TabularOutputFormatter.register_new_formatter(
            fmt, adapter, preprocessors, {'table_format': fmt})
diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py
new file mode 100644
index 0000000..20611be
--- /dev/null
+++ b/mycli/sqlcompleter.py
@@ -0,0 +1,435 @@
+import logging
+from re import compile, escape
+from collections import Counter
+
+from prompt_toolkit.completion import Completer, Completion
+
+from .packages.completion_engine import suggest_type
+from .packages.parseutils import last_word
+from .packages.filepaths import parse_path, complete_path, suggest_path
+from .packages.special.favoritequeries import FavoriteQueries
+
# Module-level logger for completion diagnostics.
_logger = logging.getLogger(__name__)
+
+
class SQLCompleter(Completer):
    # SQL keywords offered as completions; multi-word entries ('ALTER
    # TABLE') complete as a whole.  Note: the original list contained
    # 'SMALLINT' twice, which made it appear twice in keyword matches;
    # the duplicate is removed here.
    keywords = ['ACCESS', 'ADD', 'ALL', 'ALTER TABLE', 'AND', 'ANY', 'AS',
                'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN',
                'BIGINT', 'BINARY', 'BY', 'CASE', 'CHANGE MASTER TO', 'CHAR',
                'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT',
                'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT',
                'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT',
                'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP',
                'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT',
                'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION',
                'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN',
                'INCREMENT', 'INDEX', 'INSERT INTO', 'INT', 'INTEGER',
                'INTERVAL', 'INTO', 'IS', 'JOIN', 'KEY', 'LEFT', 'LEVEL',
                'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER',
                'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER',
                'OFFSET', 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER',
                'PASSWORD', 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST',
                'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET',
                'REVOKE', 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT',
                'SAVEPOINT', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SHOW',
                'SLAVE', 'SMALLINT', 'START', 'STOP', 'TABLE',
                'THEN', 'TINYINT', 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE',
                'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', 'USE', 'USER',
                'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', 'WITH']

    # Built-in function names; completed with startswith matching.
    functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT',
                 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID',
                 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', 'UNIX_TIMESTAMP']

    # SHOW <item> candidates; populated from the server at refresh time.
    show_items = []

    # Option names usable in CHANGE MASTER TO statements.
    change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER',
                    'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY',
                    'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE',
                    'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS',
                    'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH',
                    'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER',
                    'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS']

    # Known user accounts; populated from mysql.user at refresh time.
    users = []
+
+ def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'):
+ super(self.__class__, self).__init__()
+ self.smart_completion = smart_completion
+ self.reserved_words = set()
+ for x in self.keywords:
+ self.reserved_words.update(x.split())
+ self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
+
+ self.special_commands = []
+ self.table_formats = supported_formats
+ if keyword_casing not in ('upper', 'lower', 'auto'):
+ keyword_casing = 'auto'
+ self.keyword_casing = keyword_casing
+ self.reset_completions()
+
+ def escape_name(self, name):
+ if name and ((not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)):
+ name = '`%s`' % name
+
+ return name
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
    def extend_special_commands(self, special_commands):
        """Add special command names (e.g. ``\\dt``) for completion."""
        # Special commands are not part of all_completions since they can only
        # be at the beginning of a line.
        self.special_commands.extend(special_commands)
+
    def extend_database_names(self, databases):
        """Add database names to the completion candidates."""
        self.databases.extend(databases)
+
    def extend_keywords(self, additional_keywords):
        """Add extra keywords to the keyword list and to all_completions."""
        self.keywords.extend(additional_keywords)
        self.all_completions.update(additional_keywords)
+
+ def extend_show_items(self, show_items):
+ for show_item in show_items:
+ self.show_items.extend(show_item)
+ self.all_completions.update(show_item)
+
+ def extend_change_items(self, change_items):
+ for change_item in change_items:
+ self.change_items.extend(change_item)
+ self.all_completions.update(change_item)
+
+ def extend_users(self, users):
+ for user in users:
+ self.users.extend(user)
+ self.all_completions.update(user)
+
+ def extend_schemata(self, schema):
+ if schema is None:
+ return
+ metadata = self.dbmetadata['tables']
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ metadata[schema] = {}
+ self.all_completions.update(schema)
+
    def extend_relations(self, data, kind):
        """Extend metadata for tables or views

        :param data: list of (rel_name, ) tuples
        :param kind: either 'tables' or 'views'
        :return:
        """
        # 'data' is a generator object. It can throw an exception while being
        # consumed. This could happen if the user has launched the app without
        # specifying a database name. This exception must be handled to prevent
        # crashing.
        try:
            data = [self.escaped_names(d) for d in data]
        except Exception:
            data = []

        # dbmetadata['tables'][$schema_name][$table_name] should be a list of
        # column names. Default to an asterisk
        metadata = self.dbmetadata[kind]
        for relname in data:
            try:
                metadata[self.dbname][relname[0]] = ['*']
            except KeyError:
                # Current database was never registered via extend_schemata;
                # log and keep going rather than crash completion refresh.
                _logger.error('%r %r listed in unrecognized schema %r',
                              kind, relname[0], self.dbname)
            self.all_completions.add(relname[0])
+
    def extend_columns(self, column_data, kind):
        """Extend column metadata

        :param column_data: list of (rel_name, column_name) tuples
        :param kind: either 'tables' or 'views'
        :return:
        """
        # 'column_data' is a generator object. It can throw an exception while
        # being consumed. This could happen if the user has launched the app
        # without specifying a database name. This exception must be handled to
        # prevent crashing.
        try:
            column_data = [self.escaped_names(d) for d in column_data]
        except Exception:
            column_data = []

        metadata = self.dbmetadata[kind]
        for relname, column in column_data:
            # NOTE(review): assumes every relname was already registered by
            # extend_relations; an unknown table would raise KeyError here.
            metadata[self.dbname][relname].append(column)
            self.all_completions.add(column)
+
    def extend_functions(self, func_data):
        """Register stored-function names for the current database."""
        # 'func_data' is a generator object. It can throw an exception while
        # being consumed. This could happen if the user has launched the app
        # without specifying a database name. This exception must be handled to
        # prevent crashing.
        try:
            func_data = [self.escaped_names(d) for d in func_data]
        except Exception:
            func_data = []

        # dbmetadata['functions'][$schema_name][$function_name] should return
        # function metadata.
        metadata = self.dbmetadata['functions']

        for func in func_data:
            metadata[self.dbname][func[0]] = None
            self.all_completions.add(func[0])
+
    def set_dbname(self, dbname):
        """Record the current database name used for metadata lookups."""
        self.dbname = dbname
+
    def reset_completions(self):
        """Drop all server-derived metadata, keeping only the static
        keyword and function completions."""
        self.databases = []
        self.users = []
        self.show_items = []
        self.dbname = ''
        self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}}
        self.all_completions = set(self.keywords + self.functions)
+
    @staticmethod
    def find_matches(text, collection, start_only=False, fuzzy=True, casing=None):
        """Find completion matches for the given text.

        Given the user's input text and a collection of available
        completions, find completions matching the last word of the
        text.

        If `start_only` is True, the text will match an available
        completion only at the beginning. Otherwise, a completion is
        considered a match if the text appears anywhere within it.

        If `fuzzy` is True, characters of the text may match
        non-contiguously (subsequence matching).  `casing` of 'upper',
        'lower' or 'auto' re-cases each returned keyword; 'auto' follows
        the case of the last character typed.

        yields prompt_toolkit Completion instances for any matches found
        in the collection of available completions.
        """
        last = last_word(text, include='most_punctuations')
        text = last.lower()

        # Each entry is (match length, match start, item) so sorting puts
        # tighter, earlier matches first.
        completions = []

        if fuzzy:
            regex = '.*?'.join(map(escape, text))
            pat = compile('(%s)' % regex)
            for item in sorted(collection):
                r = pat.search(item.lower())
                if r:
                    completions.append((len(r.group()), r.start(), item))
        else:
            match_end_limit = len(text) if start_only else None
            for item in sorted(collection):
                match_point = item.lower().find(text, 0, match_end_limit)
                if match_point >= 0:
                    completions.append((len(text), match_point, item))

        if casing == 'auto':
            casing = 'lower' if last and last[-1].islower() else 'upper'

        def apply_case(kw):
            if casing == 'upper':
                return kw.upper()
            return kw.lower()

        # Negative start position replaces the partial word already typed.
        return (Completion(z if casing is None else apply_case(z), -len(text))
                for x, y, z in sorted(completions))
+
    def get_completions(self, document, complete_event, smart_completion=None):
        """Yield completions for the current prompt_toolkit document.

        Asks the completion engine what kinds of objects fit at the cursor
        and collects matches for each suggested kind.
        """
        word_before_cursor = document.get_word_before_cursor(WORD=True)
        if smart_completion is None:
            smart_completion = self.smart_completion

        # If smart_completion is off then match any word that starts with
        # 'word_before_cursor'.
        if not smart_completion:
            return self.find_matches(word_before_cursor, self.all_completions,
                                     start_only=True, fuzzy=False)

        completions = []
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:

            _logger.debug('Suggestion type: %r', suggestion['type'])

            if suggestion['type'] == 'column':
                tables = suggestion['tables']
                _logger.debug("Completion column scope: %r", tables)
                scoped_cols = self.populate_scoped_cols(tables)
                if suggestion.get('drop_unique'):
                    # drop_unique is used for 'tb11 JOIN tbl2 USING (...'
                    # which should suggest only columns that appear in more than
                    # one table
                    scoped_cols = [
                        col for (col, count) in Counter(scoped_cols).items()
                        if count > 1 and col != '*'
                    ]

                cols = self.find_matches(word_before_cursor, scoped_cols)
                completions.extend(cols)

            elif suggestion['type'] == 'function':
                # suggest user-defined functions using substring matching
                funcs = self.populate_schema_objects(suggestion['schema'],
                                                     'functions')
                user_funcs = self.find_matches(word_before_cursor, funcs)
                completions.extend(user_funcs)

                # suggest hardcoded functions using startswith matching only if
                # there is no schema qualifier. If a schema qualifier is
                # present it probably denotes a table.
                # eg: SELECT * FROM users u WHERE u.
                if not suggestion['schema']:
                    predefined_funcs = self.find_matches(word_before_cursor,
                                                         self.functions,
                                                         start_only=True,
                                                         fuzzy=False,
                                                         casing=self.keyword_casing)
                    completions.extend(predefined_funcs)

            elif suggestion['type'] == 'table':
                tables = self.populate_schema_objects(suggestion['schema'],
                                                      'tables')
                tables = self.find_matches(word_before_cursor, tables)
                completions.extend(tables)

            elif suggestion['type'] == 'view':
                views = self.populate_schema_objects(suggestion['schema'],
                                                     'views')
                views = self.find_matches(word_before_cursor, views)
                completions.extend(views)

            elif suggestion['type'] == 'alias':
                aliases = suggestion['aliases']
                aliases = self.find_matches(word_before_cursor, aliases)
                completions.extend(aliases)

            elif suggestion['type'] == 'database':
                dbs = self.find_matches(word_before_cursor, self.databases)
                completions.extend(dbs)

            elif suggestion['type'] == 'keyword':
                keywords = self.find_matches(word_before_cursor, self.keywords,
                                             start_only=True,
                                             fuzzy=False,
                                             casing=self.keyword_casing)
                completions.extend(keywords)

            elif suggestion['type'] == 'show':
                show_items = self.find_matches(word_before_cursor,
                                               self.show_items,
                                               start_only=False,
                                               fuzzy=True,
                                               casing=self.keyword_casing)
                completions.extend(show_items)

            elif suggestion['type'] == 'change':
                change_items = self.find_matches(word_before_cursor,
                                                 self.change_items,
                                                 start_only=False,
                                                 fuzzy=True)
                completions.extend(change_items)
            elif suggestion['type'] == 'user':
                users = self.find_matches(word_before_cursor, self.users,
                                          start_only=False,
                                          fuzzy=True)
                completions.extend(users)

            elif suggestion['type'] == 'special':
                special = self.find_matches(word_before_cursor,
                                            self.special_commands,
                                            start_only=True,
                                            fuzzy=False)
                completions.extend(special)
            elif suggestion['type'] == 'favoritequery':
                queries = self.find_matches(word_before_cursor,
                                            FavoriteQueries.instance.list(),
                                            start_only=False, fuzzy=True)
                completions.extend(queries)
            elif suggestion['type'] == 'table_format':
                formats = self.find_matches(word_before_cursor,
                                            self.table_formats,
                                            start_only=True, fuzzy=False)
                completions.extend(formats)
            elif suggestion['type'] == 'file_name':
                file_names = self.find_files(word_before_cursor)
                completions.extend(file_names)

        return completions
+
+ def find_files(self, word):
+ """Yield matching directory or file names.
+
+ :param word:
+ :return: iterable
+
+ """
+ base_path, last_path, position = parse_path(word)
+ paths = suggest_path(word)
+ for name in sorted(paths):
+ suggestion = complete_path(name, last_path)
+ if suggestion:
+ yield Completion(suggestion, position)
+
    def populate_scoped_cols(self, scoped_tbls):
        """Find all columns in a set of scoped_tables
        :param scoped_tbls: list of (schema, table, alias) tuples
        :return: list of column names (may contain duplicates across tables)
        """
        columns = []
        meta = self.dbmetadata

        for tbl in scoped_tbls:
            # A fully qualified schema.relname reference or default_schema
            # DO NOT escape schema names.
            schema = tbl[0] or self.dbname
            relname = tbl[1]
            escaped_relname = self.escape_name(tbl[1])

            # We don't know if schema.relname is a table or view. Since
            # tables and views cannot share the same name, we can check one
            # at a time
            try:
                columns.extend(meta['tables'][schema][relname])

                # Table exists, so don't bother checking for a view
                continue
            except KeyError:
                # Retry with the backtick-escaped name before falling
                # through to the views map.
                try:
                    columns.extend(meta['tables'][schema][escaped_relname])
                    # Table exists, so don't bother checking for a view
                    continue
                except KeyError:
                    pass

            try:
                columns.extend(meta['views'][schema][relname])
            except KeyError:
                pass

        return columns
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns list of tables or functions for a (optional) schema"""
+ metadata = self.dbmetadata[obj_type]
+ schema = schema or self.dbname
+
+ try:
+ objects = metadata[schema].keys()
+ except KeyError:
+ # schema doesn't exist
+ objects = []
+
+ return objects
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
new file mode 100644
index 0000000..c68af0f
--- /dev/null
+++ b/mycli/sqlexecute.py
@@ -0,0 +1,313 @@
+import logging
+import pymysql
+import sqlparse
+from .packages import special
+from pymysql.constants import FIELD_TYPE
+from pymysql.converters import (convert_datetime,
+ convert_timedelta, convert_date, conversions,
+ decoders)
+try:
+ import paramiko
+except ImportError:
+ from mycli.packages.paramiko_stub import paramiko
+
_logger = logging.getLogger(__name__)

# pymysql's decoders, extended so SQL NULL maps to Python's NoneType.
FIELD_TYPES = decoders.copy()
FIELD_TYPES.update({
    FIELD_TYPE.NULL: type(None)
})
+
class SQLExecute(object):
    """Owns the pymysql connection and runs SQL / special commands."""

    databases_query = '''SHOW DATABASES'''

    tables_query = '''SHOW TABLES'''

    version_query = '''SELECT @@VERSION'''

    version_comment_query = '''SELECT @@VERSION_COMMENT'''
    # MySQL 4.x has no @@VERSION_COMMENT variable; use SHOW VARIABLES.
    version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''

    show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''

    users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''

    # %s below is filled with the current database name (trusted value
    # taken from the connection, not from user SQL input).
    functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
        WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''

    table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
                                    where table_schema = '%s'
                                    order by table_name,ordinal_position'''
+
    def __init__(self, database, user, password, host, port, socket, charset,
                 local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
                 ssh_key_filename):
        """Store connection parameters and open the initial connection.

        ssh_* parameters, when set, tunnel the MySQL connection through an
        SSH host (see connect()).
        """
        self.dbname = database
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.socket = socket
        self.charset = charset
        self.local_infile = local_infile
        self.ssl = ssl
        # Filled in lazily: detected server flavor and connection id.
        self._server_type = None
        self.connection_id = None
        self.ssh_user = ssh_user
        self.ssh_host = ssh_host
        self.ssh_port = ssh_port
        self.ssh_password = ssh_password
        self.ssh_key_filename = ssh_key_filename
        self.connect()
+
    def connect(self, database=None, user=None, password=None, host=None,
                port=None, socket=None, charset=None, local_infile=None,
                ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
                ssh_password=None, ssh_key_filename=None):
        """(Re)connect to the server, optionally overriding stored params.

        Any argument left falsy falls back to the value saved on self.
        On success the overrides are written back to self, the previous
        connection (if any) is closed, and the connection id is refreshed.
        """
        db = (database or self.dbname)
        user = (user or self.user)
        password = (password or self.password)
        host = (host or self.host)
        port = (port or self.port)
        socket = (socket or self.socket)
        charset = (charset or self.charset)
        local_infile = (local_infile or self.local_infile)
        ssl = (ssl or self.ssl)
        ssh_user = (ssh_user or self.ssh_user)
        ssh_host = (ssh_host or self.ssh_host)
        ssh_port = (ssh_port or self.ssh_port)
        ssh_password = (ssh_password or self.ssh_password)
        ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
        _logger.debug(
            'Connection DB Params: \n'
            '\tdatabase: %r'
            '\tuser: %r'
            '\thost: %r'
            '\tport: %r'
            '\tsocket: %r'
            '\tcharset: %r'
            '\tlocal_infile: %r'
            '\tssl: %r'
            '\tssh_user: %r'
            '\tssh_host: %r'
            '\tssh_port: %r'
            '\tssh_password: %r'
            '\tssh_key_filename: %r',
            db, user, host, port, socket, charset, local_infile, ssl,
            ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename
        )
        # Keep date/time values as-is when pymysql can parse them, but fall
        # back to the raw value (instead of None) when it cannot.
        conv = conversions.copy()
        conv.update({
            FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
            FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
            FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
            FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
        })

        # When tunnelling over SSH, create the pymysql object first but
        # delay the socket connection until the channel exists.
        defer_connect = False

        if ssh_host:
            defer_connect = True

        conn = pymysql.connect(
            database=db, user=user, password=password, host=host, port=port,
            unix_socket=socket, use_unicode=True, charset=charset,
            autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE,
            local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
            defer_connect=defer_connect
        )

        if ssh_host:
            # Open a direct-tcpip channel through the SSH host and hand it
            # to pymysql as the transport.
            client = paramiko.SSHClient()
            client.load_system_host_keys()
            client.set_missing_host_key_policy(paramiko.WarningPolicy())
            client.connect(
                ssh_host, ssh_port, ssh_user, ssh_password,
                key_filename=ssh_key_filename
            )
            chan = client.get_transport().open_channel(
                'direct-tcpip',
                (host, port),
                ('0.0.0.0', 0),
            )
            conn.connect(chan)

        if hasattr(self, 'conn'):
            self.conn.close()
        self.conn = conn
        # Update them after the connection is made to ensure that it was a
        # successful connection.
        self.dbname = db
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.socket = socket
        self.charset = charset
        self.ssl = ssl
        # retrieve connection id
        self.reset_connection_id()
+
    def run(self, statement):
        """Execute the sql in the database and return the results. The results
        are a list of tuples. Each tuple has 4 values
        (title, rows, headers, status).
        """

        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None)
            # (split_queries of '' yields nothing, so the loop below is a
            # no-op in this case.)

        # Split the sql into separate queries and run each one.
        # Unless it's saving a favorite query, in which case we
        # want to save them all together.
        if statement.startswith('\\fs'):
            components = [statement]
        else:
            components = special.split_queries(statement)

        for sql in components:
            # \G is treated specially since we have to set the expanded output.
            if sql.endswith('\\G'):
                special.set_expanded_output(True)
                sql = sql[:-2].strip()

            cur = self.conn.cursor()
            try:  # Special command
                _logger.debug('Trying a dbspecial command. sql: %r', sql)
                for result in special.execute(cur, sql):
                    yield result
            except special.CommandNotFound:  # Regular SQL
                _logger.debug('Regular sql statement. sql: %r', sql)
                cur.execute(sql)
                while True:
                    yield self.get_result(cur)

                    # PyMySQL returns an extra, empty result set with stored
                    # procedures. We skip it (rowcount is zero and no
                    # description).
                    if not cur.nextset() or (not cur.rowcount and cur.description is None):
                        break
+
+ def get_result(self, cursor):
+ """Get the current result's data from the cursor."""
+ title = headers = None
+
+ # cursor.description is not None for queries that return result sets,
+ # e.g. SELECT or SHOW.
+ if cursor.description is not None:
+ headers = [x[0] for x in cursor.description]
+ status = '{0} row{1} in set'
+ else:
+ _logger.debug('No rows in result.')
+ status = 'Query OK, {0} row{1} affected'
+ status = status.format(cursor.rowcount,
+ '' if cursor.rowcount == 1 else 's')
+
+ return (title, cursor if cursor.description else None, headers, status)
+
+ def tables(self):
+ """Yields table names"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Tables Query. sql: %r', self.tables_query)
+ cur.execute(self.tables_query)
+ for row in cur:
+ yield row
+
+ def table_columns(self):
+ """Yields (table name, column name) pairs"""
+ with self.conn.cursor() as cur:
+ _logger.debug('Columns Query. sql: %r', self.table_columns_query)
+ cur.execute(self.table_columns_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def databases(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Databases Query. sql: %r', self.databases_query)
+ cur.execute(self.databases_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def functions(self):
+ """Yields tuples of (schema_name, function_name)"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Functions Query. sql: %r', self.functions_query)
+ cur.execute(self.functions_query % self.dbname)
+ for row in cur:
+ yield row
+
    def show_candidates(self):
        """Yield SHOW-statement argument candidates from mysql.help_topic."""
        with self.conn.cursor() as cur:
            _logger.debug('Show Query. sql: %r', self.show_candidates_query)
            try:
                cur.execute(self.show_candidates_query)
            except pymysql.DatabaseError as e:
                # help_topic may be unreadable; degrade to no candidates
                # instead of crashing the completion refresh.
                _logger.error('No show completions due to %r', e)
                yield ''
            else:
                for row in cur:
                    # Strip the leading "SHOW " from each topic name.
                    yield (row[0].split(None, 1)[-1], )
+
+ def users(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Users Query. sql: %r', self.users_query)
+ try:
+ cur.execute(self.users_query)
+ except pymysql.DatabaseError as e:
+ _logger.error('No user completions due to %r', e)
+ yield ''
+ else:
+ for row in cur:
+ yield row
+
    def server_type(self):
        """Return (product_type, version) where product_type is one of
        'mysql', 'mariadb' or 'percona'.  The result is cached."""
        if self._server_type:
            return self._server_type
        with self.conn.cursor() as cur:
            _logger.debug('Version Query. sql: %r', self.version_query)
            cur.execute(self.version_query)
            version = cur.fetchone()[0]
            if version[0] == '4':
                # MySQL 4.x: read the comment via SHOW VARIABLES instead
                # of the (nonexistent) @@VERSION_COMMENT variable.
                _logger.debug('Version Comment. sql: %r',
                              self.version_comment_query_mysql4)
                cur.execute(self.version_comment_query_mysql4)
                version_comment = cur.fetchone()[1].lower()
                if isinstance(version_comment, bytes):
                    # with python3 this query returns bytes
                    version_comment = version_comment.decode('utf-8')
            else:
                _logger.debug('Version Comment. sql: %r',
                              self.version_comment_query)
                cur.execute(self.version_comment_query)
                version_comment = cur.fetchone()[0].lower()

        # Classify the server flavor from the version comment text.
        if 'mariadb' in version_comment:
            product_type = 'mariadb'
        elif 'percona' in version_comment:
            product_type = 'percona'
        else:
            product_type = 'mysql'

        self._server_type = (product_type, version)
        return self._server_type
+
+ def get_connection_id(self):
+ if not self.connection_id:
+ self.reset_connection_id()
+ return self.connection_id
+
+ def reset_connection_id(self):
+ # Remember current connection id
+ _logger.debug('Get current connection id')
+ res = self.run('select connection_id()')
+ for title, cur, headers, status in res:
+ self.connection_id = cur.fetchone()[0]
+ _logger.debug('Current connection id: %s', self.connection_id)
+
    def change_db(self, db):
        """Switch the connection's active database and remember its name."""
        self.conn.select_db(db)
        self.dbname = db