from collections import defaultdict from io import open import os import sys import shutil import traceback import logging import threading import re import stat import fileinput from collections import namedtuple try: from pwd import getpwuid except ImportError: pass from time import time from datetime import datetime from random import choice from pymysql import OperationalError from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output import preprocessors from cli_helpers.utils import strip_ansi import click import sqlparse import sqlglot from mycli.packages.parseutils import is_dropping_database, is_destructive from prompt_toolkit.completion import DynamicCompleter from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document from prompt_toolkit.filters import HasFocus, IsDone from prompt_toolkit.formatted_text import ANSI from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, ConditionalProcessor) from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.history import FileHistory from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from .packages.special.main import NO_QUERY from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.tabular_output import sql_format from .packages import special from .packages.special.favoritequeries import FavoriteQueries from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func from .clistyle import style_factory, style_factory_output from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher from .config import (write_default_config, get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes) from .key_bindings import mycli_bindings from .lexer import MyCliLexer from . import __version__ from .compat import WIN from .packages.filepaths import dir_path_exists, guess_socket_location import itertools click.disable_unicode_literals_warning = True try: from urlparse import urlparse from urlparse import unquote except ImportError: from urllib.parse import urlparse from urllib.parse import unquote try: import importlib.resources as resources except ImportError: # Python < 3.7 import importlib_resources as resources try: import paramiko except ImportError: from mycli.packages.paramiko_stub import paramiko # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) SUPPORT_INFO = ( 'Home: http://mycli.net\n' 'Bug tracker: https://github.com/dbcli/mycli/issues' ) class MyCli(object): default_prompt = '\\t \\u@\\h:\\d> ' default_prompt_splitln = '\\u@\\h\\n(\\t):\\d>' max_len_prompt = 45 defaults_suffix = None # In order of being loaded. Files lower in list override earlier ones. cnf_files = [ '/etc/my.cnf', '/etc/mysql/my.cnf', '/usr/local/etc/my.cnf', os.path.expanduser('~/.my.cnf'), ] # check XDG_CONFIG_HOME exists and not an empty string if os.environ.get("XDG_CONFIG_HOME"): xdg_config_home = os.environ.get("XDG_CONFIG_HOME") else: xdg_config_home = "~/.config" system_config_files = [ '/etc/myclirc', os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") ] pwd_config_file = os.path.join(os.getcwd(), ".myclirc") def __init__(self, sqlexecute=None, prompt=None, logfile=None, defaults_suffix=None, defaults_file=None, login_path=None, auto_vertical_output=False, warn=None, myclirc="~/.myclirc"): self.sqlexecute = sqlexecute self.logfile = logfile self.defaults_suffix = defaults_suffix self.login_path = login_path self.toolbar_error_message = None # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. # If defaults_file is specified then override the class variable with # defaults_file. if defaults_file: self.cnf_files = [defaults_file] # Load config. config_files = (self.system_config_files + [myclirc] + [self.pwd_config_file]) c = self.config = read_config_files(config_files) self.multi_line = c['main'].as_bool('multi_line') self.key_bindings = c['main']['key_bindings'] special.set_timing_enabled(c['main'].as_bool('timing')) self.beep_after_seconds = float(c['main']['beep_after_seconds'] or 0) FavoriteQueries.instance = FavoriteQueries.from_config(self.config) self.dsn_alias = None self.formatter = TabularOutputFormatter( format_name=c['main']['table_format']) sql_format.register_new_formatter(self.formatter) self.formatter.mycli = self self.syntax_style = c['main']['syntax_style'] self.less_chatty = c['main'].as_bool('less_chatty') self.cli_style = c['colors'] self.output_style = style_factory_output( self.syntax_style, self.cli_style ) self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') c_dest_warning = c['main'].as_bool('destructive_warning') self.destructive_warning = c_dest_warning if warn is None else warn self.login_path_as_host = c['main'].as_bool('login_path_as_host') # read from cli argument or user config file self.auto_vertical_output = auto_vertical_output or \ c['main'].as_bool('auto_vertical_output') # Write user config if system config wasn't the last config loaded. if c.filename not in self.system_config_files and not os.path.exists(myclirc): write_default_config(myclirc) # audit log if self.logfile is None and 'audit_log' in c['main']: try: self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') except (IOError, OSError) as e: self.echo('Error: Unable to open the audit log file. Your queries will not be logged.', err=True, fg='red') self.logfile = False self.completion_refresher = CompletionRefresher() self.logger = logging.getLogger(__name__) self.initialize_logging() prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ self.default_prompt self.multiline_continuation_char = c['main']['prompt_continuation'] keyword_casing = c['main'].get('keyword_casing', 'auto') self.query_history = [] # Initialize completer. self.smart_completion = c['main'].as_bool('smart_completion') self.completer = SQLCompleter( self.smart_completion, supported_formats=self.formatter.supported_formats, keyword_casing=keyword_casing) self._completer_lock = threading.Lock() # Register custom special commands. self.register_special_commands() # Load .mylogin.cnf if it exists. mylogin_cnf_path = get_mylogin_cnf_path() if mylogin_cnf_path: mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path) if mylogin_cnf_path and mylogin_cnf: # .mylogin.cnf gets read last, even if defaults_file is specified. self.cnf_files.append(mylogin_cnf) elif mylogin_cnf_path and not mylogin_cnf: # There was an error reading the login path file. print('Error: Unable to read login path file.') self.prompt_app = None def register_special_commands(self): special.register_special_command(self.change_db, 'use', '\\u', 'Change to a new database.', aliases=('\\u',)) special.register_special_command(self.change_db, 'connect', '\\r', 'Reconnect to the database. Optional database argument.', aliases=('\\r', ), case_sensitive=True) special.register_special_command(self.refresh_completions, 'rehash', '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',)) special.register_special_command( self.change_table_format, 'tableformat', '\\T', 'Change the table format used to output results.', aliases=('\\T',), case_sensitive=True) special.register_special_command(self.execute_from_file, 'source', '\\. filename', 'Execute commands from file.', aliases=('\\.',)) special.register_special_command(self.change_prompt_format, 'prompt', '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True) def change_table_format(self, arg, **_): try: self.formatter.format_name = arg yield (None, None, None, 'Changed table format to {}'.format(arg)) except ValueError: msg = 'Table format {} not recognized. Allowed formats:'.format( arg) for table_type in self.formatter.supported_formats: msg += "\n\t{}".format(table_type) yield (None, None, None, msg) def change_db(self, arg, **_): if not arg: click.secho( "No database selected", err=True, fg="red" ) return if arg.startswith('`') and arg.endswith('`'): arg = re.sub(r'^`(.*)`$', r'\1', arg) arg = re.sub(r'``', r'`', arg) self.sqlexecute.change_db(arg) yield (None, None, None, 'You are now connected to database "%s" as ' 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) def execute_from_file(self, arg, **_): if not arg: message = 'Missing required argument, filename.' return [(None, None, None, message)] try: with open(os.path.expanduser(arg)) as f: query = f.read() except IOError as e: return [(None, None, None, str(e))] if (self.destructive_warning and confirm_destructive_query(query) is False): message = 'Wise choice. Command execution stopped.' return [(None, None, None, message)] return self.sqlexecute.run(query) def change_prompt_format(self, arg, **_): """ Change the prompt format. """ if not arg: message = 'Missing required argument, format.' return [(None, None, None, message)] self.prompt_format = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] def initialize_logging(self): log_file = os.path.expanduser(self.config['main']['log_file']) log_level = self.config['main']['log_level'] level_map = {'CRITICAL': logging.CRITICAL, 'ERROR': logging.ERROR, 'WARNING': logging.WARNING, 'INFO': logging.INFO, 'DEBUG': logging.DEBUG } # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": handler = logging.NullHandler() log_level = "CRITICAL" elif dir_path_exists(log_file): handler = logging.FileHandler(log_file) else: self.echo( 'Error: Unable to open the log file "{}".'.format(log_file), err=True, fg='red') return formatter = logging.Formatter( '%(asctime)s (%(process)d/%(threadName)s) ' '%(name)s %(levelname)s - %(message)s') handler.setFormatter(formatter) root_logger = logging.getLogger('mycli') root_logger.addHandler(handler) root_logger.setLevel(level_map[log_level.upper()]) logging.captureWarnings(True) root_logger.debug('Initializing mycli logging.') root_logger.debug('Log file %r.', log_file) def read_my_cnf_files(self, files, keys): """ Reads a list of config files and merges them. The last one will win. :param files: list of files to read :param keys: list of keys to retrieve :returns: tuple, with None for missing keys. """ cnf = read_config_files(files, list_values=False) sections = ['client', 'mysqld'] key_transformations = { 'mysqld': { 'socket': 'default_socket', 'port': 'default_port', 'user': 'default_user', }, } if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) configuration = defaultdict(lambda: None) for key in keys: for section in cnf: if ( section not in sections or key not in cnf[section] ): continue new_key = key_transformations.get(section, {}).get(key) or key configuration[new_key] = strip_matching_quotes( cnf[section][key]) return configuration def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" merged = {} merged.update(ssl) prefix = 'ssl-' for k, v in cnf.items(): # skip unrelated options if not k.startswith(prefix): continue if v is None: continue # special case because PyMySQL argument is significantly different # from commandline if k == 'ssl-verify-server-cert': merged['check_hostname'] = v else: # use argument name just strip "ssl-" prefix arg = k[len(prefix):] merged[arg] = v return merged def connect(self, database='', user='', passwd='', host='', port='', socket='', charset='', local_infile='', ssl='', ssh_user='', ssh_host='', ssh_port='', ssh_password='', ssh_key_filename='', init_command='', password_file=''): cnf = {'database': None, 'user': None, 'password': None, 'host': None, 'port': None, 'socket': None, 'default_socket': None, 'default-character-set': None, 'local-infile': None, 'loose-local-infile': None, 'ssl-ca': None, 'ssl-cert': None, 'ssl-key': None, 'ssl-cipher': None, 'ssl-verify-serer-cert': None, } cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. database = database or cnf['database'] user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] port = port or cnf['port'] ssl = ssl or {} port = port and int(port) if not port: port = 3306 if not host or host == 'localhost': socket = ( socket or cnf['socket'] or cnf['default_socket'] or guess_socket_location() ) passwd = passwd if isinstance(passwd, str) else cnf['password'] charset = charset or cnf['default-character-set'] or 'utf8' # Favor whichever local_infile option is set. for local_infile_option in (local_infile, cnf['local-infile'], cnf['loose-local-infile'], False): try: local_infile = str_to_bool(local_infile_option) break except (TypeError, ValueError): pass ssl = self.merge_ssl_with_cnf(ssl, cnf) # prune lone check_hostname=False if not any(v for v in ssl.values()): ssl = None # if the passwd is not specified try to set it using the password_file option password_from_file = self.get_password_from_file(password_file) passwd = passwd or password_from_file # Connect to the database. def _connect(): try: self.sqlexecute = SQLExecute( database, user, passwd, host, port, socket, charset, local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, init_command ) except OperationalError as e: if e.args[0] == ERROR_CODE_ACCESS_DENIED: if password_from_file: new_passwd = password_from_file else: new_passwd = click.prompt('Password', hide_input=True, show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( database, user, new_passwd, host, port, socket, charset, local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, init_command ) else: raise e try: if not WIN and socket: socket_owner = getpwuid(os.stat(socket).st_uid).pw_name self.echo( f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) try: _connect() except OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" if [code for code in (2001, 2002, 2003) if code == e.args[0]]: self.logger.debug('Database connection failed: %r.', e) self.logger.error( "traceback: %r", traceback.format_exc()) self.logger.debug('Retrying over TCP/IP') self.echo( "Failed to connect to local MySQL server through socket '{}':".format(socket)) self.echo(str(e), err=True) self.echo( 'Retrying over TCP/IP', err=True) # Else fall back to TCP/IP localhost socket = "" host = 'localhost' port = 3306 _connect() else: raise e else: host = host or 'localhost' port = port or 3306 # Bad ports give particularly daft error messages try: port = int(port) except ValueError as e: self.echo("Error: Invalid port number: '{0}'.".format(port), err=True, fg='red') exit(1) _connect() except Exception as e: # Connecting to a database could fail. self.logger.debug('Database connection failed: %r.', e) self.logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') exit(1) def get_password_from_file(self, password_file): password_from_file = None if password_file: if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ and os.access(password_file, os.R_OK): with open(password_file) as fp: password_from_file = fp.readline() password_from_file = password_from_file.rstrip().lstrip() return password_from_file def handle_editor_command(self, text): r"""Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query multiple times. For eg: "select * from \e" to edit it in vim, then come back to the prompt with the edited query "select * from blah where q = 'abc'\e" to edit it again. :param text: Document :return: Document """ while special.editor_command(text): filename = special.get_filename(text) query = (special.get_editor_query(text) or self.get_last_query()) sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. raise RuntimeError(message) while True: try: text = self.prompt_app.prompt(default=sql) break except KeyboardInterrupt: sql = "" continue return text def handle_clip_command(self, text): r"""A clip command is any query that is prefixed or suffixed by a '\clip'. :param text: Document :return: Boolean """ if special.clip_command(text): query = (special.get_clip_query(text) or self.get_last_query()) message = special.copy_query_to_clipboard(sql=query) if message: raise RuntimeError(message) return True return False def handle_prettify_binding(self, text): try: statements = sqlglot.parse(text, read='mysql') except Exception as e: statements = [] if len(statements) == 1 and statements[0]: pretty_text = statements[0].sql(pretty=True, pad=4, dialect='mysql') else: pretty_text = '' self.toolbar_error_message = 'Prettify failed to parse statement' if len(pretty_text) > 0: pretty_text = pretty_text + ';' return pretty_text def handle_unprettify_binding(self, text): try: statements = sqlglot.parse(text, read='mysql') except Exception as e: statements = [] if len(statements) == 1 and statements[0]: unpretty_text = statements[0].sql(pretty=False, dialect='mysql') else: unpretty_text = '' self.toolbar_error_message = 'Unprettify failed to parse statement' if len(unpretty_text) > 0: unpretty_text = unpretty_text + ';' return unpretty_text def run_cli(self): iterations = 0 sqlexecute = self.sqlexecute logger = self.logger self.configure_pager() if self.smart_completion: self.refresh_completions() history_file = os.path.expanduser( os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) if dir_path_exists(history_file): history = FileHistory(history_file) else: history = None self.echo( 'Error: Unable to open the history file "{}". ' 'Your query history will not be saved.'.format(history_file), err=True, fg='red') key_bindings = mycli_bindings(self) if not self.less_chatty: print(sqlexecute.server_info) print('mycli', __version__) print(SUPPORT_INFO) print('Thanks to the contributor -', thanks_picker()) def get_message(): prompt = self.get_prompt(self.prompt_format) if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: prompt = self.get_prompt(self.default_prompt_splitln) prompt = prompt.replace("\\x1b", "\x1b") return ANSI(prompt) def get_continuation(width, *_): if self.multiline_continuation_char == '': continuation = '' elif self.multiline_continuation_char: left_padding = width - len(self.multiline_continuation_char) continuation = " " * \ max((left_padding - 1), 0) + \ self.multiline_continuation_char + " " else: continuation = " " return [('class:continuation', continuation)] def show_suggestion_tip(): return iterations < 2 def one_iteration(text=None): if text is None: try: text = self.prompt_app.prompt() except KeyboardInterrupt: return special.set_expanded_output(False) try: text = self.handle_editor_command(text) except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') return try: if self.handle_clip_command(text): return except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') return if not text.strip(): return if self.destructive_warning: destroy = confirm_destructive_query(text) if destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: self.echo('Your call!') else: self.echo('Wise choice!') return else: destroy = True # Keep track of whether or not the query is mutating. In case # of a multi-statement query, the overall query is considered # mutating if any one of the component statements is mutating mutating = False try: logger.debug('sql: %r', text) special.write_tee(self.get_prompt(self.prompt_format) + text) if self.logfile: self.logfile.write('\n# %s\n' % datetime.now()) self.logfile.write(text) self.logfile.write('\n') successful = False start = time() res = sqlexecute.run(text) self.formatter.query = text successful = True result_count = 0 for title, cur, headers, status in res: logger.debug("headers: %r", headers) logger.debug("rows: %r", cur) logger.debug("status: %r", status) threshold = 1000 if (is_select(status) and cur and cur.rowcount > threshold): self.echo('The result set has more than {} rows.'.format( threshold), fg='red') if not confirm('Do you want to continue?'): self.echo("Aborted!", err=True, fg='red') break if self.auto_vertical_output: max_width = self.prompt_app.output.get_size().columns else: max_width = None formatted = self.format_output( title, cur, headers, special.is_expanded_output(), max_width) t = time() - start try: if result_count > 0: self.echo('') try: self.output(formatted, status) except KeyboardInterrupt: pass if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: self.bell() if special.is_timing_enabled(): self.echo('Time: %0.03fs' % t) except KeyboardInterrupt: pass start = time() result_count += 1 mutating = mutating or destroy or is_mutating(status) special.unset_once_if_written() special.unset_pipe_once_if_written() except EOFError as e: raise e except KeyboardInterrupt: # get last connection id connection_id_to_kill = sqlexecute.connection_id # some mysql compatible databases may not implemente connection_id() if connection_id_to_kill > 0: logger.debug("connection id to kill: %r", connection_id_to_kill) # Restart connection to the database sqlexecute.connect() try: for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill): status_str = str(status).lower() if status_str.find('ok') > -1: logger.debug("cancelled query, connection id: %r, sql: %r", connection_id_to_kill, text) self.echo("cancelled query", err=True, fg='red') except Exception as e: self.echo('Encountered error while cancelling query: {}'.format(e), err=True, fg='red') else: logger.debug("Did not get a connection id, skip cancelling query") except NotImplementedError: self.echo('Not Yet Implemented.', fg="yellow") except OperationalError as e: logger.debug("Exception: %r", e) if (e.args[0] in (2003, 2006, 2013)): logger.debug('Attempting to reconnect.') self.echo('Reconnecting...', fg='yellow') try: sqlexecute.connect() logger.debug('Reconnected successfully.') one_iteration(text) return # OK to just return, cuz the recursion call runs to the end. except OperationalError as e: logger.debug('Reconnect failed. e: %r', e) self.echo(str(e), err=True, fg='red') # If reconnection failed, don't proceed further. return else: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') except Exception as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) self.echo(str(e), err=True, fg='red') else: if is_dropping_database(text, self.sqlexecute.dbname): self.sqlexecute.dbname = None self.sqlexecute.connect() # Refresh the table names and column names if necessary. if need_completion_refresh(text): self.refresh_completions( reset=need_completion_reset(text)) finally: if self.logfile is False: self.echo("Warning: This query was not logged.", err=True, fg='red') query = Query(text, successful, mutating) self.query_history.append(query) get_toolbar_tokens = create_toolbar_tokens_func( self, show_suggestion_tip) if self.wider_completion_menu: complete_style = CompleteStyle.MULTI_COLUMN else: complete_style = CompleteStyle.COLUMN with self._completer_lock: if self.key_bindings == 'vi': editing_mode = EditingMode.VI else: editing_mode = EditingMode.EMACS self.prompt_app = PromptSession( lexer=PygmentsLexer(MyCliLexer), reserve_space_for_menu=self.get_reserved_space(), message=get_message, prompt_continuation=get_continuation, bottom_toolbar=get_toolbar_tokens, complete_style=complete_style, input_processors=[ConditionalProcessor( processor=HighlightMatchingBracketProcessor( chars='[](){}'), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() )], tempfile_suffix='.sql', completer=DynamicCompleter(lambda: self.completer), history=history, auto_suggest=AutoSuggestFromHistory(), complete_while_typing=True, multiline=cli_is_multiline(self), style=style_factory(self.syntax_style, self.cli_style), include_default_pygments_style=False, key_bindings=key_bindings, enable_open_in_editor=True, enable_system_prompt=True, enable_suspend=True, editing_mode=editing_mode, search_ignore_case=True ) try: while True: one_iteration() iterations += 1 except EOFError: special.close_tee() if not self.less_chatty: self.echo('Goodbye!') def log_output(self, output): """Log the output in the audit log, if it's enabled.""" if self.logfile: click.echo(output, file=self.logfile) def echo(self, s, **kwargs): """Print a message to stdout. The message will be logged in the audit log, if enabled. All keyword arguments are passed to click.echo(). """ self.log_output(s) click.secho(s, **kwargs) def bell(self): """Print a bell on the stderr. """ click.secho('\a', err=True, nl=False) def get_output_margin(self, status=None): """Get the output margin (number of rows for the prompt, footer and timing message.""" margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1 if special.is_timing_enabled(): margin += 1 if status: margin += 1 + status.count('\n') return margin def output(self, output, status=None): """Output text to stdout or a pager command. The status text is not outputted to pager or files. The message will be logged in the audit log, if enabled. The message will be written to the tee file, if enabled. The message will be written to the output file, if enabled. """ if output: size = self.prompt_app.output.get_size() margin = self.get_output_margin(status) fits = True buf = [] output_via_pager = self.explicit_pager and special.is_pager_enabled() for i, line in enumerate(output, 1): self.log_output(line) special.write_tee(line) special.write_once(line) special.write_pipe_once(line) if fits or output_via_pager: # buffering buf.append(line) if len(line) > size.columns or i > (size.rows - margin): fits = False if not self.explicit_pager and special.is_pager_enabled(): # doesn't fit, use pager output_via_pager = True if not output_via_pager: # doesn't fit, flush buffer for buf_line in buf: click.secho(buf_line) buf = [] else: click.secho(line) if buf: if output_via_pager: def newlinewrapper(text): for line in text: yield line + "\n" click.echo_via_pager(newlinewrapper(buf)) else: for line in buf: click.secho(line) if status: self.log_output(status) click.secho(status) def configure_pager(self): # Provide sane defaults for less if they are empty. if not os.environ.get('LESS'): os.environ['LESS'] = '-RXF' cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) cnf_pager = cnf['pager'] or self.config['main']['pager'] if cnf_pager: special.set_pager(cnf_pager) self.explicit_pager = True else: self.explicit_pager = False if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'): special.disable_pager() def refresh_completions(self, reset=False): if reset: with self._completer_lock: self.completer.reset_completions() self.completion_refresher.refresh( self.sqlexecute, self._on_completions_refreshed, {'smart_completion': self.smart_completion, 'supported_formats': self.formatter.supported_formats, 'keyword_casing': self.completer.keyword_casing}) return [(None, None, None, 'Auto-completion refresh started in the background.')] def _on_completions_refreshed(self, new_completer): """Swap the completer object in cli with the newly created completer. """ with self._completer_lock: self.completer = new_completer if self.prompt_app: # After refreshing, redraw the CLI to clear the statusbar # "Refreshing completions..." indicator self.prompt_app.app.invalidate() def get_completions(self, text, cursor_positition): with self._completer_lock: return self.completer.get_completions( Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string): sqlexecute = self.sqlexecute host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host now = datetime.now() string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') string = string.replace('\\t', sqlexecute.server_info.species.name) string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) string = string.replace('\\P', now.strftime('%p')) string = string.replace('\\R', now.strftime('%H')) string = string.replace('\\r', now.strftime('%I')) string = string.replace('\\s', now.strftime('%S')) string = string.replace('\\p', str(sqlexecute.port)) string = string.replace('\\A', self.dsn_alias or '(none)') string = string.replace('\\_', ' ') return string def run_query(self, query, new_line=True): """Runs *query*.""" results = self.sqlexecute.run(query) for result in results: title, cur, headers, status = result self.formatter.query = query output = self.format_output(title, cur, headers, special.is_expanded_output()) for line in output: click.echo(line, nl=new_line) def format_output(self, title, cur, headers, expanded=False, max_width=None): expanded = expanded or self.formatter.format_name == 'vertical' output = [] output_kwargs = { 'dialect': 'unix', 'disable_numparse': True, 'preserve_whitespace': True, 'style': self.output_style } if not self.formatter.format_name in sql_format.supported_formats: output_kwargs["preprocessors"] = (preprocessors.align_decimals, ) if title: # Only print the title if it's not None. output = itertools.chain(output, [title]) if cur: column_types = None if hasattr(cur, 'description'): def get_col_type(col): col_type = FIELD_TYPES.get(col[1], str) return col_type if type(col_type) is type else str column_types = [get_col_type(col) for col in cur.description] if max_width is not None: cur = list(cur) formatted = self.formatter.format_output( cur, headers, format_name='vertical' if expanded else None, column_types=column_types, **output_kwargs) if isinstance(formatted, str): formatted = formatted.splitlines() formatted = iter(formatted) if (not expanded and max_width and headers and cur): first_line = next(formatted) if len(strip_ansi(first_line)) > max_width: formatted = self.formatter.format_output( cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) if isinstance(formatted, str): formatted = iter(formatted.splitlines()) else: formatted = itertools.chain([first_line], formatted) output = itertools.chain(output, formatted) return output def get_reserved_space(self): """Get the number of lines to reserve for the completion menu.""" reserved_space_ratio = .45 max_reserved_space = 8 _, height = shutil.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) def get_last_query(self): """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None @click.command() @click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.') @click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors ' '$MYSQL_TCP_PORT.') @click.option('-u', '--user', help='User name to connect to the database.') @click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.') @click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str, help='Password to connect to the database.') @click.option('--pass', 'password', envvar='MYSQL_PWD', type=str, help='Password to connect to the database.') @click.option('--ssh-user', help='User name to connect to ssh server.') @click.option('--ssh-host', help='Host name to connect to ssh server.') @click.option('--ssh-port', default=22, help='Port to connect to ssh server.') @click.option('--ssh-password', help='Password to connect to ssh server.') @click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') @click.option('--ssh-config-path', help='Path to ssh configuration.', default=os.path.expanduser('~') + '/.ssh/config') @click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') @click.option('--ssl', 'ssl_enable', is_flag=True, help='Enable SSL for connection (automatically enabled with other flags).') @click.option('--ssl-ca', help='CA file in PEM format.', type=click.Path(exists=True)) @click.option('--ssl-capath', help='CA directory.') @click.option('--ssl-cert', help='X509 cert in PEM format.', type=click.Path(exists=True)) @click.option('--ssl-key', help='X509 key in PEM format.', type=click.Path(exists=True)) @click.option('--ssl-cipher', help='SSL cipher to use.') @click.option('--tls-version', type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), help='TLS protocol version for secure connection.') @click.option('--ssl-verify-server-cert', is_flag=True, help=('Verify server\'s "Common Name" in its cert against ' 'hostname used when connecting. This option is disabled ' 'by default.')) # as of 2016-02-15 revocation list is not supported by underling PyMySQL # library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) @click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.') @click.option('-v', '--verbose', is_flag=True, help='Verbose output.') @click.option('-D', '--database', 'dbname', help='Database to use.') @click.option('-d', '--dsn', default='', envvar='DSN', help='Use DSN configured into the [alias_dsn] section of myclirc file.') @click.option('--list-dsn', 'list_dsn', is_flag=True, help='list of DSN configured into the [alias_dsn] section of myclirc file.') @click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, help='list ssh configurations in the ssh config (requires paramiko).') @click.option('-R', '--prompt', 'prompt', help='Prompt format (Default: "{0}").'.format( MyCli.default_prompt)) @click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'), help='Log every query and its results to a file.') @click.option('--defaults-group-suffix', type=str, help='Read MySQL config groups with the specified suffix.') @click.option('--defaults-file', type=click.Path(), help='Only read MySQL options from the given file.') @click.option('--myclirc', type=click.Path(), default="~/.myclirc", help='Location of myclirc file.') @click.option('--auto-vertical-output', is_flag=True, help='Automatically switch to vertical output mode if the result is wider than the terminal width.') @click.option('-t', '--table', is_flag=True, help='Display batch output in table format.') @click.option('--csv', is_flag=True, help='Display batch output in CSV format.') @click.option('--warn/--no-warn', default=None, help='Warn before running a destructive query.') @click.option('--local-infile', type=bool, help='Enable/disable LOAD DATA LOCAL INFILE.') @click.option('-g', '--login-path', type=str, help='Read this path from the login file.') @click.option('-e', '--execute', type=str, help='Execute command and quit.') @click.option('--init-command', type=str, help='SQL statement to execute after connecting.') @click.option('--charset', type=str, help='Character set for MySQL session.') @click.option('--password-file', type=click.Path(), help='File or FIFO path containing the password to connect to the db if not specified otherwise.') @click.argument('database', default='', nargs=1) def cli(database, user, host, port, socket, password, dbname, version, verbose, prompt, logfile, defaults_group_suffix, defaults_file, login_path, auto_vertical_output, local_infile, ssl_enable, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, tls_version, ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, init_command, charset, password_file): """A MySQL terminal client with auto-completion and syntax highlighting. \b Examples: - mycli my_database - mycli -u my_user -h my_host.com my_database - mycli mysql://my_user@my_host.com:3306/my_database """ if version: print('Version:', __version__) sys.exit(0) mycli = MyCli(prompt=prompt, logfile=logfile, defaults_suffix=defaults_group_suffix, defaults_file=defaults_file, login_path=login_path, auto_vertical_output=auto_vertical_output, warn=warn, myclirc=myclirc) if list_dsn: try: alias_dsn = mycli.config['alias_dsn'] except KeyError as err: click.secho('Invalid DSNs found in the config file. '\ 'Please check the "[alias_dsn]" section in myclirc.', err=True, fg='red') exit(1) except Exception as e: click.secho(str(e), err=True, fg='red') exit(1) for alias, value in alias_dsn.items(): if verbose: click.secho("{} : {}".format(alias, value)) else: click.secho(alias) sys.exit(0) if list_ssh_config: ssh_config = read_ssh_config(ssh_config_path) for host in ssh_config.get_hostnames(): if verbose: host_config = ssh_config.lookup(host) click.secho("{} : {}".format( host, host_config.get('hostname'))) else: click.secho(host) sys.exit(0) # Choose which ever one has a valid value. database = dbname or database ssl = { 'enable': ssl_enable, 'ca': ssl_ca and os.path.expanduser(ssl_ca), 'cert': ssl_cert and os.path.expanduser(ssl_cert), 'key': ssl_key and os.path.expanduser(ssl_key), 'capath': ssl_capath, 'cipher': ssl_cipher, 'tls_version': tls_version, 'check_hostname': ssl_verify_server_cert, } # remove empty ssl options ssl = {k: v for k, v in ssl.items() if v is not None} dsn_uri = None # Treat the database argument as a DSN alias if we're missing # other connection information. if (mycli.config['alias_dsn'] and database and '://' not in database and not any([user, password, host, port, login_path])): dsn, database = database, '' if database and '://' in database: dsn_uri, database = database, '' if dsn: try: dsn_uri = mycli.config['alias_dsn'][dsn] except KeyError: click.secho('Could not find the specified DSN in the config file. ' 'Please check the "[alias_dsn]" section in your ' 'myclirc.', err=True, fg='red') exit(1) else: mycli.dsn_alias = dsn if dsn_uri: uri = urlparse(dsn_uri) if not database: database = uri.path[1:] # ignore the leading fwd slash if not user and uri.username is not None: user = unquote(uri.username) if not password and uri.password is not None: password = unquote(uri.password) if not host: host = uri.hostname if not port: port = uri.port if ssh_config_host: ssh_config = read_ssh_config( ssh_config_path ).lookup(ssh_config_host) ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') ssh_user = ssh_user if ssh_user else ssh_config.get('user') if ssh_config.get('port') and ssh_port == 22: # port has a default value, overwrite it if it's in the config ssh_port = int(ssh_config.get('port')) ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( 'identityfile', [None])[0] ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) mycli.connect( database=database, user=user, passwd=password, host=host, port=port, socket=socket, local_infile=local_infile, ssl=ssl, ssh_user=ssh_user, ssh_host=ssh_host, ssh_port=ssh_port, ssh_password=ssh_password, ssh_key_filename=ssh_key_filename, init_command=init_command, charset=charset, password_file=password_file ) mycli.logger.debug('Launch Params: \n' '\tdatabase: %r' '\tuser: %r' '\thost: %r' '\tport: %r', database, user, host, port) # --execute argument if execute: try: if csv: mycli.formatter.format_name = 'csv' if execute.endswith(r'\G'): execute = execute[:-2] elif table: if execute.endswith(r'\G'): execute = execute[:-2] else: mycli.formatter.format_name = 'tsv' mycli.run_query(execute) exit(0) except Exception as e: click.secho(str(e), err=True, fg='red') exit(1) if sys.stdin.isatty(): mycli.run_cli() else: stdin = click.get_text_stream('stdin') try: stdin_text = stdin.read() except MemoryError: click.secho('Failed! Ran out of memory.', err=True, fg='red') click.secho('You might want to try the official mysql client.', err=True, fg='red') click.secho('Sorry... :(', err=True, fg='red') exit(1) if mycli.destructive_warning and is_destructive(stdin_text): try: sys.stdin = open('/dev/tty') warn_confirmed = confirm_destructive_query(stdin_text) except (IOError, OSError): mycli.logger.warning('Unable to open TTY as stdin.') if not warn_confirmed: exit(0) try: new_line = True if csv: mycli.formatter.format_name = 'csv' elif not table: mycli.formatter.format_name = 'tsv' mycli.run_query(stdin_text, new_line=new_line) exit(0) except Exception as e: click.secho(str(e), err=True, fg='red') exit(1) def need_completion_refresh(queries): """Determines if the completion needs a refresh by checking if the sql statement is an alter, create, drop or change db.""" for query in sqlparse.split(queries): try: first_token = query.split()[0] if first_token.lower() in ('alter', 'create', 'use', '\\r', '\\u', 'connect', 'drop', 'rename'): return True except Exception: return False def need_completion_reset(queries): """Determines if the statement is a database switch such as 'use' or '\\u'. When a database is changed the existing completions must be reset before we start the completion refresh for the new database. """ for query in sqlparse.split(queries): try: first_token = query.split()[0] if first_token.lower() in ('use', '\\u'): return True except Exception: return False def is_mutating(status): """Determines if the statement is mutating based on the status.""" if not status: return False mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', 'replace', 'truncate', 'load', 'rename']) return status.split(None, 1)[0].lower() in mutating def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: return False return status.split(None, 1)[0].lower() == 'select' def thanks_picker(): import mycli lines = ( resources.read_text(mycli, 'AUTHORS') + resources.read_text(mycli, 'SPONSORS') ).split('\n') contents = [] for line in lines: m = re.match(r'^ *\* (.*)', line) if m: contents.append(m.group(1)) return choice(contents) @prompt_register('edit-and-execute-command') def edit_and_execute(event): """Different from the prompt-toolkit default, we want to have a choice not to execute a query after editing, hence validate_and_handle=False.""" buff = event.current_buffer buff.open_in_editor(validate_and_handle=False) def read_ssh_config(ssh_config_path): ssh_config = paramiko.config.SSHConfig() try: with open(ssh_config_path) as f: ssh_config.parse(f) except FileNotFoundError as e: click.secho(str(e), err=True, fg='red') sys.exit(1) # Paramiko prior to version 2.7 raises Exception on parse errors. # In 2.7 it has become paramiko.ssh_exception.SSHException, # but let's catch everything for compatibility except Exception as err: click.secho( f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', err=True, fg='red' ) sys.exit(1) else: return ssh_config if __name__ == "__main__": cli()