summaryrefslogtreecommitdiffstats
path: root/mycli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/main.py')
-rwxr-xr-xmycli/main.py140
1 files changed, 95 insertions, 45 deletions
diff --git a/mycli/main.py b/mycli/main.py
index f2b2fd8..3f08e9c 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -1,9 +1,12 @@
+from collections import defaultdict
+from io import open
import os
import sys
import traceback
import logging
import threading
import re
+import stat
import fileinput
from collections import namedtuple
try:
@@ -13,7 +16,6 @@ except ImportError:
from time import time
from datetime import datetime
from random import choice
-from io import open
from pymysql import OperationalError
from cli_helpers.tabular_output import TabularOutputFormatter
@@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
from .clistyle import style_factory, style_factory_output
-from .sqlexecute import FIELD_TYPES, SQLExecute
+from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
from .clibuffer import cli_is_multiline
from .completion_refresher import CompletionRefresher
from .config import (write_default_config, get_mylogin_cnf_path,
@@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
strip_matching_quotes)
from .key_bindings import mycli_bindings
from .lexer import MyCliLexer
-from .__init__ import __version__
+from . import __version__
from .compat import WIN
from .packages.filepaths import dir_path_exists, guess_socket_location
@@ -66,6 +68,11 @@ except ImportError:
from urllib.parse import urlparse
from urllib.parse import unquote
+try:
+ import importlib.resources as resources
+except ImportError:
+ # Python < 3.7
+ import importlib_resources as resources
try:
import paramiko
@@ -75,7 +82,10 @@ except ImportError:
# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
-PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
+SUPPORT_INFO = (
+ 'Home: http://mycli.net\n'
+ 'Bug tracker: https://github.com/dbcli/mycli/issues'
+)
class MyCli(object):
@@ -89,7 +99,7 @@ class MyCli(object):
'/etc/my.cnf',
'/etc/mysql/my.cnf',
'/usr/local/etc/my.cnf',
- '~/.my.cnf'
+ os.path.expanduser('~/.my.cnf'),
]
# check XDG_CONFIG_HOME exists and not an empty string
@@ -102,7 +112,6 @@ class MyCli(object):
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
]
- default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
def __init__(self, sqlexecute=None, prompt=None,
@@ -122,7 +131,7 @@ class MyCli(object):
self.cnf_files = [defaults_file]
# Load config.
- config_files = ([self.default_config_file] + self.system_config_files +
+ config_files = (self.system_config_files +
[myclirc] + [self.pwd_config_file])
c = self.config = read_config_files(config_files)
self.multi_line = c['main'].as_bool('multi_line')
@@ -154,7 +163,7 @@ class MyCli(object):
# Write user config if system config wasn't the last config loaded.
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
- write_default_config(self.default_config_file, myclirc)
+ write_default_config(myclirc)
# audit log
if self.logfile is None and 'audit_log' in c['main']:
@@ -326,20 +335,33 @@ class MyCli(object):
cnf = read_config_files(files, list_values=False)
sections = ['client', 'mysqld']
+ key_transformations = {
+ 'mysqld': {
+ 'socket': 'default_socket',
+ 'port': 'default_port',
+ },
+ }
+
if self.login_path and self.login_path != 'client':
sections.append(self.login_path)
if self.defaults_suffix:
sections.extend([sect + self.defaults_suffix for sect in sections])
- def get(key):
- result = None
- for sect in cnf:
- if sect in sections and key in cnf[sect]:
- result = strip_matching_quotes(cnf[sect][key])
- return result
+ configuration = defaultdict(lambda: None)
+ for key in keys:
+ for section in cnf:
+ if (
+ section not in sections or
+ key not in cnf[section]
+ ):
+ continue
+ new_key = key_transformations.get(section, {}).get(key) or key
+ configuration[new_key] = strip_matching_quotes(
+ cnf[section][key])
+
+ return configuration
- return {x: get(x) for x in keys}
def merge_ssl_with_cnf(self, ssl, cnf):
"""Merge SSL configuration dict with cnf dict"""
@@ -367,7 +389,7 @@ class MyCli(object):
def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
- ssh_password='', ssh_key_filename='', init_command=''):
+ ssh_password='', ssh_key_filename='', init_command='', password_file=''):
cnf = {'database': None,
'user': None,
@@ -375,6 +397,7 @@ class MyCli(object):
'host': None,
'port': None,
'socket': None,
+ 'default_socket': None,
'default-character-set': None,
'local-infile': None,
'loose-local-infile': None,
@@ -388,18 +411,23 @@ class MyCli(object):
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
# Fall back to config values only if user did not specify a value.
-
database = database or cnf['database']
- # Socket interface not supported for SSH connections
- if port or (host and host != 'localhost') or (ssh_host and ssh_port):
- socket = ''
- else:
- socket = socket or cnf['socket'] or guess_socket_location()
user = user or cnf['user'] or os.getenv('USER')
host = host or cnf['host']
- port = int(port or cnf['port'] or 3306)
+ port = port or cnf['port']
ssl = ssl or {}
+ port = port and int(port)
+ if not port:
+ port = 3306
+ if not host or host == 'localhost':
+ socket = (
+ cnf['socket'] or
+ cnf['default_socket'] or
+ guess_socket_location()
+ )
+
+
passwd = passwd if isinstance(passwd, str) else cnf['password']
charset = charset or cnf['default-character-set'] or 'utf8'
@@ -417,6 +445,10 @@ class MyCli(object):
if not any(v for v in ssl.values()):
ssl = None
+ # if the passwd is not specfied try to set it using the password_file option
+ password_from_file = self.get_password_from_file(password_file)
+ passwd = passwd or password_from_file
+
# Connect to the database.
def _connect():
@@ -427,9 +459,12 @@ class MyCli(object):
ssh_password, ssh_key_filename, init_command
)
except OperationalError as e:
- if ('Access denied for user' in e.args[1]):
- new_passwd = click.prompt('Password', hide_input=True,
- show_default=False, type=str, err=True)
+ if e.args[0] == ERROR_CODE_ACCESS_DENIED:
+ if password_from_file:
+ new_passwd = password_from_file
+ else:
+ new_passwd = click.prompt('Password', hide_input=True,
+ show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
@@ -484,6 +519,17 @@ class MyCli(object):
self.echo(str(e), err=True, fg='red')
exit(1)
+ def get_password_from_file(self, password_file):
+ password_from_file = None
+ if password_file:
+ if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
+ and os.access(password_file, os.R_OK):
+ with open(password_file) as fp:
+ password_from_file = fp.readline()
+ password_from_file = password_from_file.rstrip().lstrip()
+
+ return password_from_file
+
def handle_editor_command(self, text):
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
The reason for a while loop is because a user might edit a query
@@ -542,9 +588,6 @@ class MyCli(object):
if self.smart_completion:
self.refresh_completions()
- author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
- sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
-
history_file = os.path.expanduser(
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
if dir_path_exists(history_file):
@@ -559,12 +602,10 @@ class MyCli(object):
key_bindings = mycli_bindings(self)
if not self.less_chatty:
- print(' '.join(sqlexecute.server_type()))
+ print(sqlexecute.server_info)
print('mycli', __version__)
- print('Chat: https://gitter.im/dbcli/mycli')
- print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
- print('Home: http://mycli.net')
- print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
+ print(SUPPORT_INFO)
+ print('Thanks to the contributor -', thanks_picker())
def get_message():
prompt = self.get_prompt(self.prompt_format)
@@ -862,8 +903,8 @@ class MyCli(object):
if not output_via_pager:
# doesn't fit, flush buffer
- for line in buf:
- click.secho(line)
+ for buf_line in buf:
+ click.secho(buf_line)
buf = []
else:
click.secho(line)
@@ -933,7 +974,7 @@ class MyCli(object):
string = string.replace('\\u', sqlexecute.user or '(none)')
string = string.replace('\\h', host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)')
- string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
+ string = string.replace('\\t', sqlexecute.server_info.species.name)
string = string.replace('\\n', "\n")
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M'))
@@ -1083,7 +1124,7 @@ class MyCli(object):
help='Warn before running a destructive query.')
@click.option('--local-infile', type=bool,
help='Enable/disable LOAD DATA LOCAL INFILE.')
-@click.option('--login-path', type=str,
+@click.option('-g', '--login-path', type=str,
help='Read this path from the login file.')
@click.option('-e', '--execute', type=str,
help='Execute command and quit.')
@@ -1091,6 +1132,8 @@ class MyCli(object):
help='SQL statement to execute after connecting.')
@click.option('--charset', type=str,
help='Character set for MySQL session.')
+@click.option('--password-file', type=click.Path(),
+ help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
@click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname,
version, verbose, prompt, logfile, defaults_group_suffix,
@@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
- init_command, charset):
+ init_command, charset, password_file):
"""A MySQL terminal client with auto-completion and syntax highlighting.
\b
@@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename,
init_command=init_command,
- charset=charset
+ charset=charset,
+ password_file=password_file
)
mycli.logger.debug('Launch Params: \n'
@@ -1328,9 +1372,15 @@ def is_select(status):
return status.split(None, 1)[0].lower() == 'select'
-def thanks_picker(files=()):
+def thanks_picker():
+ import mycli
+ lines = (
+ resources.read_text(mycli, 'AUTHORS') +
+ resources.read_text(mycli, 'SPONSORS')
+ ).split('\n')
+
contents = []
- for line in fileinput.input(files=files):
+ for line in lines:
m = re.match(r'^ *\* (.*)', line)
if m:
contents.append(m.group(1))
@@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
try:
with open(ssh_config_path) as f:
ssh_config.parse(f)
+ except FileNotFoundError as e:
+ click.secho(str(e), err=True, fg='red')
+ sys.exit(1)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
@@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
err=True, fg='red'
)
sys.exit(1)
- except FileNotFoundError as e:
- click.secho(str(e), err=True, fg='red')
- sys.exit(1)
else:
return ssh_config