summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/ci.yml14
-rw-r--r--README.md26
-rw-r--r--changelog.md47
-rw-r--r--mycli/AUTHORS3
-rw-r--r--mycli/__init__.py2
-rw-r--r--mycli/clibuffer.py1
-rw-r--r--mycli/config.py104
-rw-r--r--mycli/key_bindings.py8
-rwxr-xr-xmycli/main.py140
-rw-r--r--mycli/packages/completion_engine.py2
-rw-r--r--mycli/packages/parseutils.py14
-rw-r--r--mycli/packages/special/iocommands.py2
-rw-r--r--mycli/packages/tabular_output/sql_format.py1
-rw-r--r--mycli/sqlcompleter.py2
-rw-r--r--mycli/sqlexecute.py102
-rw-r--r--pytest.ini2
-rwxr-xr-xrelease.py1
-rwxr-xr-xsetup.py8
-rw-r--r--test/features/connection.feature35
-rw-r--r--test/features/environment.py48
-rw-r--r--test/features/steps/auto_vertical.py3
-rw-r--r--test/features/steps/connection.py71
-rw-r--r--test/features/steps/utils.py12
-rw-r--r--test/features/steps/wrappers.py55
-rw-r--r--test/test_main.py9
-rw-r--r--test/test_sqlexecute.py22
26 files changed, 565 insertions, 169 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 413b749..0a14472 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,13 +7,20 @@ on:
jobs:
linux:
-
- runs-on: ubuntu-latest
-
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
+ include:
+ - python-version: 3.6
+ os: ubuntu-16.04 # MySQL 5.7.32
+ - python-version: 3.7
+ os: ubuntu-18.04 # MySQL 5.7.32
+ - python-version: 3.8
+ os: ubuntu-18.04 # MySQL 5.7.32
+ - python-version: 3.9
+ os: ubuntu-20.04 # MySQL 8.0.22
+ runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
@@ -42,6 +49,7 @@ jobs:
- name: Pytest / behave
env:
PYTEST_PASSWORD: root
+ PYTEST_HOST: 127.0.0.1
run: |
./setup.py test --pytest-args="--cov-report= --cov=mycli"
diff --git a/README.md b/README.md
index c709eb8..cc04a91 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
# mycli
-[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli)
-[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli)
-[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
+[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli)
+[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli)
+[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python)
A command line client for MySQL that can do auto-completion and syntax highlighting.
@@ -53,6 +53,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
-h, --host TEXT Host address of the database.
-P, --port INTEGER Port number to use for connection. Honors
$MYSQL_TCP_PORT.
+
-u, --user TEXT User name to connect to the database.
-S, --socket TEXT The socket file to use for connection.
-p, --password TEXT Password to connect to the database.
@@ -63,8 +64,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
--ssh-password TEXT Password to connect to ssh server.
--ssh-key-filename TEXT Private key filename (identify file) for the
ssh connection.
+
--ssh-config-path TEXT Path to ssh configuration.
- --ssh-config-host TEXT Host for ssh server in ssh configurations (requires paramiko).
+ --ssh-config-host TEXT Host to connect to ssh server reading from ssh
+ configuration.
+
--ssl-ca PATH CA file in PEM format.
--ssl-capath TEXT CA directory.
--ssl-cert PATH X509 cert in PEM format.
@@ -73,33 +77,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu
--ssl-verify-server-cert Verify server's "Common Name" in its cert
against hostname used when connecting. This
option is disabled by default.
+
-V, --version Output mycli's version.
-v, --verbose Verbose output.
-D, --database TEXT Database to use.
-d, --dsn TEXT Use DSN configured into the [alias_dsn]
section of myclirc file.
+
--list-dsn list of DSN configured into the [alias_dsn]
section of myclirc file.
- --list-ssh-config list ssh configurations in the ssh config (requires paramiko).
+
+ --list-ssh-config list ssh configurations in the ssh config
+ (requires paramiko).
+
-R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
-l, --logfile FILENAME Log every query and its results to a file.
--defaults-group-suffix TEXT Read MySQL config groups with the specified
suffix.
+
--defaults-file PATH Only read MySQL options from the given file.
--myclirc PATH Location of myclirc file.
--auto-vertical-output Automatically switch to vertical output mode
if the result is wider than the terminal
width.
+
-t, --table Display batch output in table format.
--csv Display batch output in CSV format.
--warn / --no-warn Warn before running a destructive query.
--local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
- --login-path TEXT Read this path from the login file.
+ -g, --login-path TEXT Read this path from the login file.
-e, --execute TEXT Execute command and quit.
--init-command TEXT SQL statement to execute after connecting.
--charset TEXT Character set for MySQL session.
+ --password-file PATH File or FIFO path containing the password
+ to connect to the db if not specified otherwise
--help Show this message and exit.
+
Features
--------
diff --git a/changelog.md b/changelog.md
index fe6e268..95e594f 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,19 +1,60 @@
+TBD:
+====
+
+*
+
+1.24.1:
+=======
+
+Bug Fixes:
+---------
+* Restore dependency on cryptography for the interactive password prompt
+
+
+1.24.0
+======
+
+Bug Fixes:
+----------
+* Allow `FileNotFound` exception for SSH config files.
+* Fix startup error on MySQL < 5.0.22
+* Check error code rather than message for Access Denied error
+* Fix login with ~/.my.cnf files
+
+Features:
+---------
+* Add `-g` shortcut to option `--login-path`.
+* Alt-Enter dispatches the command in multi-line mode.
+* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
+
+Internal:
+---------
+* Remove unused function is_open_quote()
+* Use importlib, instead of file links, to locate resources
+* Test various host-port combinations in command line arguments
+* Switched from Cryptography to pyaes for decrypting mylogin.cnf
+
+
1.23.2
-===
+======
Bug Fixes:
----------
* Ensure `--port` is always an int.
1.23.1
-===
+======
Bug Fixes:
----------
* Allow `--host` without `--port` to make a TCP connection.
1.23.0
-===
+======
+
+Bug Fixes:
+----------
+* Fix config file include logic
Features:
---------
diff --git a/mycli/AUTHORS b/mycli/AUTHORS
index 221ce8b..8cdea91 100644
--- a/mycli/AUTHORS
+++ b/mycli/AUTHORS
@@ -75,6 +75,8 @@ Contributors:
* Zach DeCook
* kevinhwang91
* KITAGAWA Yasutaka
+ * Nicolas Palumbo
+ * Andy Teijelo PĂ©rez
* bitkeen
* Morgan Mitchell
* Massimiliano Torromeo
@@ -82,6 +84,7 @@ Contributors:
* xeron
* 0xflotus
* Seamile
+ * Jerome Provensal
Creator:
--------
diff --git a/mycli/__init__.py b/mycli/__init__.py
index 375471f..785c3b8 100644
--- a/mycli/__init__.py
+++ b/mycli/__init__.py
@@ -1 +1 @@
-__version__ = '1.23.2'
+__version__ = '1.24.1'
diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py
index c0cb5c1..81353b6 100644
--- a/mycli/clibuffer.py
+++ b/mycli/clibuffer.py
@@ -1,7 +1,6 @@
from prompt_toolkit.enums import DEFAULT_BUFFER
from prompt_toolkit.filters import Condition
from prompt_toolkit.application import get_app
-from .packages.parseutils import is_open_quote
from .packages import special
diff --git a/mycli/config.py b/mycli/config.py
index e0f2d1f..5d71109 100644
--- a/mycli/config.py
+++ b/mycli/config.py
@@ -1,5 +1,3 @@
-import io
-import shutil
from copy import copy
from io import BytesIO, TextIOWrapper
import logging
@@ -7,11 +5,16 @@ import os
from os.path import exists
import struct
import sys
-from typing import Union
+from typing import Union, IO
from configobj import ConfigObj, ConfigObjError
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-from cryptography.hazmat.backends import default_backend
+import pyaes
+
+try:
+ import importlib.resources as resources
+except ImportError:
+ # Python < 3.7
+ import importlib_resources as resources
try:
basestring
@@ -49,9 +52,9 @@ def read_config_file(f, list_values=True):
config = ConfigObj(f, interpolation=False, encoding='utf8',
list_values=list_values)
except ConfigObjError as e:
- log(logger, logging.ERROR, "Unable to parse line {0} of config file "
+ log(logger, logging.WARNING, "Unable to parse line {0} of config file "
"'{1}'.".format(e.line_number, f))
- log(logger, logging.ERROR, "Using successfully parsed config values.")
+ log(logger, logging.WARNING, "Using successfully parsed config values.")
return e.config
except (IOError, OSError) as e:
log(logger, logging.WARNING, "You don't have permission to read "
@@ -61,7 +64,7 @@ def read_config_file(f, list_values=True):
return config
-def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
+def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
"""Get a list of configuration files that are included into config_path
with !includedir directive.
@@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list:
def read_config_files(files, list_values=True):
"""Read and merge a list of config files."""
- config = ConfigObj(list_values=list_values)
+ config = create_default_config(list_values=list_values)
_files = copy(files)
while _files:
_file = _files.pop(0)
@@ -112,12 +115,21 @@ def read_config_files(files, list_values=True):
return config
-def write_default_config(source, destination, overwrite=False):
+def create_default_config(list_values=True):
+ import mycli
+ default_config_file = resources.open_text(mycli, 'myclirc')
+ return read_config_file(default_config_file, list_values=list_values)
+
+
+def write_default_config(destination, overwrite=False):
+ import mycli
+ default_config = resources.read_text(mycli, 'myclirc')
destination = os.path.expanduser(destination)
if not overwrite and exists(destination):
return
- shutil.copyfile(source, destination)
+ with open(destination, 'w') as f:
+ f.write(default_config)
def get_mylogin_cnf_path():
@@ -160,6 +172,58 @@ def open_mylogin_cnf(name):
return TextIOWrapper(plaintext)
+# TODO reuse code between encryption an decryption
+def encrypt_mylogin_cnf(plaintext: IO[str]):
+ """Encryption of .mylogin.cnf file, analogous to calling
+ mysql_config_editor.
+
+ Code is based on the python implementation by Kristian Koehntopp
+ https://github.com/isotopp/mysql-config-coder
+
+ """
+ def realkey(key):
+ """Create the AES key from the login key."""
+ rkey = bytearray(16)
+ for i in range(len(key)):
+ rkey[i % 16] ^= key[i]
+ return bytes(rkey)
+
+ def encode_line(plaintext, real_key, buf_len):
+ aes = pyaes.AESModeOfOperationECB(real_key)
+ text_len = len(plaintext)
+ pad_len = buf_len - text_len
+ pad_chr = bytes(chr(pad_len), "utf8")
+ plaintext = plaintext.encode() + pad_chr * pad_len
+ encrypted_text = b''.join(
+ [aes.encrypt(plaintext[i: i + 16])
+ for i in range(0, len(plaintext), 16)]
+ )
+ return encrypted_text
+
+ LOGIN_KEY_LENGTH = 20
+ key = os.urandom(LOGIN_KEY_LENGTH)
+ real_key = realkey(key)
+
+ outfile = BytesIO()
+
+ outfile.write(struct.pack("i", 0))
+ outfile.write(key)
+
+ while True:
+ line = plaintext.readline()
+ if not line:
+ break
+ real_len = len(line)
+ pad_len = (int(real_len / 16) + 1) * 16
+
+ outfile.write(struct.pack("i", pad_len))
+ x = encode_line(line, real_key, pad_len)
+ outfile.write(x)
+
+ outfile.seek(0)
+ return outfile
+
+
def read_and_decrypt_mylogin_cnf(f):
"""Read and decrypt the contents of .mylogin.cnf.
@@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f):
return None
rkey = struct.pack('16B', *rkey)
- # Create a decryptor object using the key.
- decryptor = _get_decryptor(rkey)
-
# Create a bytes buffer to hold the plaintext.
plaintext = BytesIO()
+ aes = pyaes.AESModeOfOperationECB(rkey)
while True:
# Read the length of the ciphertext.
@@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f):
# Read cipher_len bytes from the file and decrypt.
cipher = f.read(cipher_len)
- plain = _remove_pad(decryptor.update(cipher))
+ plain = _remove_pad(
+ b''.join([aes.decrypt(cipher[i: i + 16])
+ for i in range(0, cipher_len, 16)])
+ )
if plain is False:
continue
plaintext.write(plain)
@@ -244,7 +309,7 @@ def str_to_bool(s):
elif s.lower() in false_values:
return False
else:
- raise ValueError('not a recognized boolean value: %s'.format(s))
+ raise ValueError('not a recognized boolean value: {0}'.format(s))
def strip_matching_quotes(s):
@@ -260,15 +325,8 @@ def strip_matching_quotes(s):
return s
-def _get_decryptor(key):
- """Get the AES decryptor."""
- c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
- return c.decryptor()
-
-
def _remove_pad(line):
"""Remove the pad from the *line*."""
- pad_length = ord(line[-1:])
try:
# Determine pad length.
pad_length = ord(line[-1:])
diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py
index 57b917b..4a24c82 100644
--- a/mycli/key_bindings.py
+++ b/mycli/key_bindings.py
@@ -78,8 +78,12 @@ def mycli_bindings(mycli):
@kb.add('escape', 'enter')
def _(event):
- """Introduces a line break regardless of multi-line mode or not."""
+ """Introduces a line break in multi-line mode, or dispatches the
+ command in single-line mode."""
_logger.debug('Detected alt-enter key.')
- event.app.current_buffer.insert_text('\n')
+ if mycli.multi_line:
+ event.app.current_buffer.validate_and_handle()
+ else:
+ event.app.current_buffer.insert_text('\n')
return kb
diff --git a/mycli/main.py b/mycli/main.py
index f2b2fd8..3f08e9c 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -1,9 +1,12 @@
+from collections import defaultdict
+from io import open
import os
import sys
import traceback
import logging
import threading
import re
+import stat
import fileinput
from collections import namedtuple
try:
@@ -13,7 +16,6 @@ except ImportError:
from time import time
from datetime import datetime
from random import choice
-from io import open
from pymysql import OperationalError
from cli_helpers.tabular_output import TabularOutputFormatter
@@ -43,7 +45,7 @@ from .packages.special.favoritequeries import FavoriteQueries
from .sqlcompleter import SQLCompleter
from .clitoolbar import create_toolbar_tokens_func
from .clistyle import style_factory, style_factory_output
-from .sqlexecute import FIELD_TYPES, SQLExecute
+from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
from .clibuffer import cli_is_multiline
from .completion_refresher import CompletionRefresher
from .config import (write_default_config, get_mylogin_cnf_path,
@@ -51,7 +53,7 @@ from .config import (write_default_config, get_mylogin_cnf_path,
strip_matching_quotes)
from .key_bindings import mycli_bindings
from .lexer import MyCliLexer
-from .__init__ import __version__
+from . import __version__
from .compat import WIN
from .packages.filepaths import dir_path_exists, guess_socket_location
@@ -66,6 +68,11 @@ except ImportError:
from urllib.parse import urlparse
from urllib.parse import unquote
+try:
+ import importlib.resources as resources
+except ImportError:
+ # Python < 3.7
+ import importlib_resources as resources
try:
import paramiko
@@ -75,7 +82,10 @@ except ImportError:
# Query tuples are used for maintaining history
Query = namedtuple('Query', ['query', 'successful', 'mutating'])
-PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
+SUPPORT_INFO = (
+ 'Home: http://mycli.net\n'
+ 'Bug tracker: https://github.com/dbcli/mycli/issues'
+)
class MyCli(object):
@@ -89,7 +99,7 @@ class MyCli(object):
'/etc/my.cnf',
'/etc/mysql/my.cnf',
'/usr/local/etc/my.cnf',
- '~/.my.cnf'
+ os.path.expanduser('~/.my.cnf'),
]
# check XDG_CONFIG_HOME exists and not an empty string
@@ -102,7 +112,6 @@ class MyCli(object):
os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
]
- default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
def __init__(self, sqlexecute=None, prompt=None,
@@ -122,7 +131,7 @@ class MyCli(object):
self.cnf_files = [defaults_file]
# Load config.
- config_files = ([self.default_config_file] + self.system_config_files +
+ config_files = (self.system_config_files +
[myclirc] + [self.pwd_config_file])
c = self.config = read_config_files(config_files)
self.multi_line = c['main'].as_bool('multi_line')
@@ -154,7 +163,7 @@ class MyCli(object):
# Write user config if system config wasn't the last config loaded.
if c.filename not in self.system_config_files and not os.path.exists(myclirc):
- write_default_config(self.default_config_file, myclirc)
+ write_default_config(myclirc)
# audit log
if self.logfile is None and 'audit_log' in c['main']:
@@ -326,20 +335,33 @@ class MyCli(object):
cnf = read_config_files(files, list_values=False)
sections = ['client', 'mysqld']
+ key_transformations = {
+ 'mysqld': {
+ 'socket': 'default_socket',
+ 'port': 'default_port',
+ },
+ }
+
if self.login_path and self.login_path != 'client':
sections.append(self.login_path)
if self.defaults_suffix:
sections.extend([sect + self.defaults_suffix for sect in sections])
- def get(key):
- result = None
- for sect in cnf:
- if sect in sections and key in cnf[sect]:
- result = strip_matching_quotes(cnf[sect][key])
- return result
+ configuration = defaultdict(lambda: None)
+ for key in keys:
+ for section in cnf:
+ if (
+ section not in sections or
+ key not in cnf[section]
+ ):
+ continue
+ new_key = key_transformations.get(section, {}).get(key) or key
+ configuration[new_key] = strip_matching_quotes(
+ cnf[section][key])
+
+ return configuration
- return {x: get(x) for x in keys}
def merge_ssl_with_cnf(self, ssl, cnf):
"""Merge SSL configuration dict with cnf dict"""
@@ -367,7 +389,7 @@ class MyCli(object):
def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
- ssh_password='', ssh_key_filename='', init_command=''):
+ ssh_password='', ssh_key_filename='', init_command='', password_file=''):
cnf = {'database': None,
'user': None,
@@ -375,6 +397,7 @@ class MyCli(object):
'host': None,
'port': None,
'socket': None,
+ 'default_socket': None,
'default-character-set': None,
'local-infile': None,
'loose-local-infile': None,
@@ -388,18 +411,23 @@ class MyCli(object):
cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
# Fall back to config values only if user did not specify a value.
-
database = database or cnf['database']
- # Socket interface not supported for SSH connections
- if port or (host and host != 'localhost') or (ssh_host and ssh_port):
- socket = ''
- else:
- socket = socket or cnf['socket'] or guess_socket_location()
user = user or cnf['user'] or os.getenv('USER')
host = host or cnf['host']
- port = int(port or cnf['port'] or 3306)
+ port = port or cnf['port']
ssl = ssl or {}
+ port = port and int(port)
+ if not port:
+ port = 3306
+ if not host or host == 'localhost':
+ socket = (
+ cnf['socket'] or
+ cnf['default_socket'] or
+ guess_socket_location()
+ )
+
+
passwd = passwd if isinstance(passwd, str) else cnf['password']
charset = charset or cnf['default-character-set'] or 'utf8'
@@ -417,6 +445,10 @@ class MyCli(object):
if not any(v for v in ssl.values()):
ssl = None
+ # if the passwd is not specfied try to set it using the password_file option
+ password_from_file = self.get_password_from_file(password_file)
+ passwd = passwd or password_from_file
+
# Connect to the database.
def _connect():
@@ -427,9 +459,12 @@ class MyCli(object):
ssh_password, ssh_key_filename, init_command
)
except OperationalError as e:
- if ('Access denied for user' in e.args[1]):
- new_passwd = click.prompt('Password', hide_input=True,
- show_default=False, type=str, err=True)
+ if e.args[0] == ERROR_CODE_ACCESS_DENIED:
+ if password_from_file:
+ new_passwd = password_from_file
+ else:
+ new_passwd = click.prompt('Password', hide_input=True,
+ show_default=False, type=str, err=True)
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
@@ -484,6 +519,17 @@ class MyCli(object):
self.echo(str(e), err=True, fg='red')
exit(1)
+ def get_password_from_file(self, password_file):
+ password_from_file = None
+ if password_file:
+ if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
+ and os.access(password_file, os.R_OK):
+ with open(password_file) as fp:
+ password_from_file = fp.readline()
+ password_from_file = password_from_file.rstrip().lstrip()
+
+ return password_from_file
+
def handle_editor_command(self, text):
r"""Editor command is any query that is prefixed or suffixed by a '\e'.
The reason for a while loop is because a user might edit a query
@@ -542,9 +588,6 @@ class MyCli(object):
if self.smart_completion:
self.refresh_completions()
- author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
- sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
-
history_file = os.path.expanduser(
os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
if dir_path_exists(history_file):
@@ -559,12 +602,10 @@ class MyCli(object):
key_bindings = mycli_bindings(self)
if not self.less_chatty:
- print(' '.join(sqlexecute.server_type()))
+ print(sqlexecute.server_info)
print('mycli', __version__)
- print('Chat: https://gitter.im/dbcli/mycli')
- print('Mail: https://groups.google.com/forum/#!forum/mycli-users')
- print('Home: http://mycli.net')
- print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file]))
+ print(SUPPORT_INFO)
+ print('Thanks to the contributor -', thanks_picker())
def get_message():
prompt = self.get_prompt(self.prompt_format)
@@ -862,8 +903,8 @@ class MyCli(object):
if not output_via_pager:
# doesn't fit, flush buffer
- for line in buf:
- click.secho(line)
+ for buf_line in buf:
+ click.secho(buf_line)
buf = []
else:
click.secho(line)
@@ -933,7 +974,7 @@ class MyCli(object):
string = string.replace('\\u', sqlexecute.user or '(none)')
string = string.replace('\\h', host or '(none)')
string = string.replace('\\d', sqlexecute.dbname or '(none)')
- string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli')
+ string = string.replace('\\t', sqlexecute.server_info.species.name)
string = string.replace('\\n', "\n")
string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
string = string.replace('\\m', now.strftime('%M'))
@@ -1083,7 +1124,7 @@ class MyCli(object):
help='Warn before running a destructive query.')
@click.option('--local-infile', type=bool,
help='Enable/disable LOAD DATA LOCAL INFILE.')
-@click.option('--login-path', type=str,
+@click.option('-g', '--login-path', type=str,
help='Read this path from the login file.')
@click.option('-e', '--execute', type=str,
help='Execute command and quit.')
@@ -1091,6 +1132,8 @@ class MyCli(object):
help='SQL statement to execute after connecting.')
@click.option('--charset', type=str,
help='Character set for MySQL session.')
+@click.option('--password-file', type=click.Path(),
+ help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
@click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname,
version, verbose, prompt, logfile, defaults_group_suffix,
@@ -1099,7 +1142,7 @@ def cli(database, user, host, port, socket, password, dbname,
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
- init_command, charset):
+ init_command, charset, password_file):
"""A MySQL terminal client with auto-completion and syntax highlighting.
\b
@@ -1225,7 +1268,8 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_password=ssh_password,
ssh_key_filename=ssh_key_filename,
init_command=init_command,
- charset=charset
+ charset=charset,
+ password_file=password_file
)
mycli.logger.debug('Launch Params: \n'
@@ -1328,9 +1372,15 @@ def is_select(status):
return status.split(None, 1)[0].lower() == 'select'
-def thanks_picker(files=()):
+def thanks_picker():
+ import mycli
+ lines = (
+ resources.read_text(mycli, 'AUTHORS') +
+ resources.read_text(mycli, 'SPONSORS')
+ ).split('\n')
+
contents = []
- for line in fileinput.input(files=files):
+ for line in lines:
m = re.match(r'^ *\* (.*)', line)
if m:
contents.append(m.group(1))
@@ -1350,6 +1400,9 @@ def read_ssh_config(ssh_config_path):
try:
with open(ssh_config_path) as f:
ssh_config.parse(f)
+ except FileNotFoundError as e:
+ click.secho(str(e), err=True, fg='red')
+ sys.exit(1)
# Paramiko prior to version 2.7 raises Exception on parse errors.
# In 2.7 it has become paramiko.ssh_exception.SSHException,
# but let's catch everything for compatibility
@@ -1359,9 +1412,6 @@ def read_ssh_config(ssh_config_path):
err=True, fg='red'
)
sys.exit(1)
- except FileNotFoundError as e:
- click.secho(str(e), err=True, fg='red')
- sys.exit(1)
else:
return ssh_config
diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py
index 3cff2cc..c7db06c 100644
--- a/mycli/packages/completion_engine.py
+++ b/mycli/packages/completion_engine.py
@@ -1,5 +1,3 @@
-import os
-import sys
import sqlparse
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils import last_word, extract_tables, find_prev_keyword
diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py
index 268e04e..fa5f2c9 100644
--- a/mycli/packages/parseutils.py
+++ b/mycli/packages/parseutils.py
@@ -12,7 +12,8 @@ cleanup_regex = {
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
# This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'),
- }
+}
+
def last_word(text, include='alphanum_underscore'):
r"""
@@ -226,14 +227,6 @@ def is_destructive(queries):
return False
-def is_open_quote(sql):
- """Returns true if the query contains an unclosed quote."""
-
- # parsed can contain one or more semi-colon separated commands
- parsed = sqlparse.parse(sql)
- return any(_parsed_is_open_quote(p) for p in parsed)
-
-
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
print (extract_tables(sql))
@@ -263,5 +256,4 @@ def is_dropping_database(queries, dbname):
)
if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
result = keywords[0].normalized == "DROP"
- else:
- return result
+ return result
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
index 58066b8..01f3c7b 100644
--- a/mycli/packages/special/iocommands.py
+++ b/mycli/packages/special/iocommands.py
@@ -302,7 +302,7 @@ def execute_system_command(arg, **_):
usage = "Syntax: system [command].\n"
if not arg:
- return [(None, None, None, usage)]
+ return [(None, None, None, usage)]
try:
command = arg.strip()
diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py
index 730e633..e6587bd 100644
--- a/mycli/packages/tabular_output/sql_format.py
+++ b/mycli/packages/tabular_output/sql_format.py
@@ -1,6 +1,5 @@
"""Format adapter for sql."""
-from cli_helpers.utils import filter_dict_by_key
from mycli.packages.parseutils import extract_tables
supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py
index 73b9b44..3656aa6 100644
--- a/mycli/sqlcompleter.py
+++ b/mycli/sqlcompleter.py
@@ -72,7 +72,7 @@ class SQLCompleter(Completer):
if name and ((not self.name_pattern.match(name))
or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)):
- name = '`%s`' % name
+ name = '`%s`' % name
return name
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
index 7534982..9461438 100644
--- a/mycli/sqlexecute.py
+++ b/mycli/sqlexecute.py
@@ -1,6 +1,8 @@
+import enum
import logging
+import re
+
import pymysql
-import sqlparse
from .packages import special
from pymysql.constants import FIELD_TYPE
from pymysql.converters import (convert_datetime,
@@ -18,17 +20,71 @@ FIELD_TYPES.update({
FIELD_TYPE.NULL: type(None)
})
+
+ERROR_CODE_ACCESS_DENIED = 1045
+
+
+class ServerSpecies(enum.Enum):
+ MySQL = 'MySQL'
+ MariaDB = 'MariaDB'
+ Percona = 'Percona'
+ Unknown = 'MySQL'
+
+
+class ServerInfo:
+ def __init__(self, species, version_str):
+ self.species = species
+ self.version_str = version_str
+ self.version = self.calc_mysql_version_value(version_str)
+
+ @staticmethod
+ def calc_mysql_version_value(version_str) -> int:
+ if not version_str or not isinstance(version_str, str):
+ return 0
+ try:
+ major, minor, patch = version_str.split('.')
+ except ValueError:
+ return 0
+ else:
+ return int(major) * 10_000 + int(minor) * 100 + int(patch)
+
+ @classmethod
+ def from_version_string(cls, version_string):
+ if not version_string:
+ return cls(ServerSpecies.Unknown, '')
+
+ re_species = (
+ (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
+ ServerSpecies.Percona),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
+ ServerSpecies.MySQL),
+ )
+ for regexp, species in re_species:
+ match = re.search(regexp, version_string)
+ if match is not None:
+ parsed_version = match.group('version')
+ detected_species = species
+ break
+ else:
+ detected_species = ServerSpecies.Unknown
+ parsed_version = ''
+
+ return cls(detected_species, parsed_version)
+
+ def __str__(self):
+ if self.species:
+ return f'{self.species.value} {self.version_str}'
+ else:
+ return self.version_str
+
+
class SQLExecute(object):
databases_query = '''SHOW DATABASES'''
tables_query = '''SHOW TABLES'''
- version_query = '''SELECT @@VERSION'''
-
- version_comment_query = '''SELECT @@VERSION_COMMENT'''
- version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
-
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
@@ -52,7 +108,7 @@ class SQLExecute(object):
self.charset = charset
self.local_infile = local_infile
self.ssl = ssl
- self._server_type = None
+ self.server_info = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
@@ -157,6 +213,7 @@ class SQLExecute(object):
self.init_command = init_command
# retrieve connection id
self.reset_connection_id()
+ self.server_info = ServerInfo.from_version_string(conn.server_version)
def run(self, statement):
"""Execute the sql in the database and return the results. The results
@@ -273,37 +330,6 @@ class SQLExecute(object):
for row in cur:
yield row
- def server_type(self):
- if self._server_type:
- return self._server_type
- with self.conn.cursor() as cur:
- _logger.debug('Version Query. sql: %r', self.version_query)
- cur.execute(self.version_query)
- version = cur.fetchone()[0]
- if version[0] == '4':
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query_mysql4)
- cur.execute(self.version_comment_query_mysql4)
- version_comment = cur.fetchone()[1].lower()
- if isinstance(version_comment, bytes):
- # with python3 this query returns bytes
- version_comment = version_comment.decode('utf-8')
- else:
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query)
- cur.execute(self.version_comment_query)
- version_comment = cur.fetchone()[0].lower()
-
- if 'mariadb' in version_comment:
- product_type = 'mariadb'
- elif 'percona' in version_comment:
- product_type = 'percona'
- else:
- product_type = 'mysql'
-
- self._server_type = (product_type, version)
- return self._server_type
-
def get_connection_id(self):
if not self.connection_id:
self.reset_connection_id()
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..5422131
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --ignore=mycli/packages/paramiko_stub/__init__.py
diff --git a/release.py b/release.py
index 30c41b3..3f18f03 100755
--- a/release.py
+++ b/release.py
@@ -1,6 +1,5 @@
"""A script to publish a release of mycli to PyPI."""
-import io
from optparse import OptionParser
import re
import subprocess
diff --git a/setup.py b/setup.py
index 4aa7f91..1d619f3 100755
--- a/setup.py
+++ b/setup.py
@@ -18,16 +18,20 @@ description = 'CLI for MySQL Database. With auto-completion and syntax highlight
install_requirements = [
'click >= 7.0',
+ 'cryptography >= 1.0.0',
'Pygments >= 1.6',
'prompt_toolkit>=3.0.6,<4.0.0',
'PyMySQL >= 0.9.2',
'sqlparse>=0.3.0,<0.4.0',
'configobj >= 5.0.5',
- 'cryptography >= 1.0.0',
'cli_helpers[styles] >= 2.0.1',
- 'pyperclip >= 1.8.1'
+ 'pyperclip >= 1.8.1',
+ 'pyaes >= 1.6.1'
]
+if sys.version_info.minor < 9:
+ install_requirements.append('importlib_resources >= 5.0.0')
+
class lint(Command):
description = 'check code against PEP 8 (and fix violations)'
diff --git a/test/features/connection.feature b/test/features/connection.feature
new file mode 100644
index 0000000..b06935e
--- /dev/null
+++ b/test/features/connection.feature
@@ -0,0 +1,35 @@
+Feature: connect to a database:
+
+ @requires_local_db
+ Scenario: run mycli on localhost without port
+ When we run mycli with arguments "host=localhost" without arguments "port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli on TCP host without port
+ When we run mycli without arguments "port"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ Scenario: run mycli with port but without host
+ When we run mycli without arguments "host"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ @requires_local_db
+ Scenario: run mycli without host and port
+ When we run mycli without arguments "host port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli with my.cnf configuration
+ When we create my.cnf file
+ When we run mycli without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+ Scenario: run mycli with mylogin.cnf configuration
+ When we create mylogin.cnf file
+ When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+
diff --git a/test/features/environment.py b/test/features/environment.py
index 98c2004..1ea0f08 100644
--- a/test/features/environment.py
+++ b/test/features/environment.py
@@ -1,4 +1,5 @@
import os
+import shutil
import sys
from tempfile import mkstemp
@@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
+SELF_CONNECTING_FEATURES = (
+ 'test/features/connection.feature',
+)
+
+
+MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
+MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
+MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
+MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
+
+
+def get_db_name_from_context(context):
+ return context.config.userdata.get(
+ 'my_test_db', None
+ ) or "mycli_behave_tests"
+
+
+
def before_all(context):
"""Set env parameters."""
os.environ['LINES'] = "100"
@@ -22,7 +41,7 @@ def before_all(context):
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
- os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
@@ -33,8 +52,7 @@ def before_all(context):
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
- db_name = context.config.userdata.get(
- 'my_test_db', None) or "mycli_behave_tests"
+ db_name = get_db_name_from_context(context)
db_name_full = '{0}_{1}'.format(db_name, vi)
# Store get params from config/environment variables
@@ -104,11 +122,18 @@ def before_step(context, _):
context.atprompt = False
-def before_scenario(context, _):
+def before_scenario(context, arg):
with open(test_log_file, 'w') as f:
f.write('')
- run_cli(context)
- wait_prompt(context)
+ if arg.location.filename not in SELF_CONNECTING_FEATURES:
+ run_cli(context)
+ wait_prompt(context)
+
+ if os.path.exists(MY_CNF_PATH):
+ shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_PATH):
+ shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
def after_scenario(context, _):
@@ -134,6 +159,17 @@ def after_scenario(context, _):
context.cli.sendcontrol('d')
context.cli.expect_exact(pexpect.EOF, timeout=5)
+ if os.path.exists(MY_CNF_BACKUP_PATH):
+ shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
+ shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
+ elif os.path.exists(MYLOGIN_CNF_PATH):
+ # This file was moved in `before_scenario`.
+ # If it exists now, it has been created during a test
+ os.remove(MYLOGIN_CNF_PATH)
+
+
# TODO: uncomment to debug a failure
# def after_step(context, step):
# if step.status == "failed":
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
index 974740d..e1cb26f 100644
--- a/test/features/steps/auto_vertical.py
+++ b/test/features/steps/auto_vertical.py
@@ -3,11 +3,12 @@ from textwrap import dedent
from behave import then, when
import wrappers
+from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}')
def step_run_cli_with_arg(context, arg):
- wrappers.run_cli(context, run_args=arg.split('='))
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query')
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
new file mode 100644
index 0000000..e16dd86
--- /dev/null
+++ b/test/features/steps/connection.py
@@ -0,0 +1,71 @@
+import io
+import os
+import shlex
+
+from behave import when, then
+import pexpect
+
+import wrappers
+from test.features.steps.utils import parse_cli_args_to_dict
+from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
+from test.utils import HOST, PORT, USER, PASSWORD
+from mycli.config import encrypt_mylogin_cnf
+
+
+TEST_LOGIN_PATH = 'test_login_path'
+
+
+@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
+@when('we run mycli without arguments "{excluded_args}"')
+def step_run_cli_without_args(context, excluded_args, exact_args=''):
+ wrappers.run_cli(
+ context,
+ run_args=parse_cli_args_to_dict(exact_args),
+ exclude_args=parse_cli_args_to_dict(excluded_args).keys()
+ )
+
+
+@then('status contains "{expression}"')
+def status_contains(context, expression):
+ wrappers.expect_exact(context, f'{expression}', timeout=5)
+
+ # Normally, the shutdown after scenario waits for the prompt.
+ # But we may have changed the prompt, depending on parameters,
+ # so let's wait for its last character
+ context.cli.expect_exact('>')
+ context.atprompt = True
+
+
+@when('we create my.cnf file')
+def step_create_my_cnf_file(context):
+ my_cnf = (
+ '[client]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MY_CNF_PATH, 'w') as f:
+ f.write(my_cnf)
+
+
+@when('we create mylogin.cnf file')
+def step_create_mylogin_cnf_file(context):
+ os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
+ mylogin_cnf = (
+ f'[{TEST_LOGIN_PATH}]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MYLOGIN_CNF_PATH, 'wb') as f:
+ input_file = io.StringIO(mylogin_cnf)
+ f.write(encrypt_mylogin_cnf(input_file).read())
+
+
+@then('we are logged in')
+def we_are_logged_in(context):
+ db_name = get_db_name_from_context(context)
+ context.cli.expect_exact(f'{db_name}>', timeout=5)
+ context.atprompt = True
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
new file mode 100644
index 0000000..1ae63d2
--- /dev/null
+++ b/test/features/steps/utils.py
@@ -0,0 +1,12 @@
+import shlex
+
+
+def parse_cli_args_to_dict(cli_args: str):
+ args_dict = {}
+ for arg in shlex.split(cli_args):
+ if '=' in arg:
+ key, value = arg.split('=')
+ args_dict[key] = value
+ else:
+ args_dict[arg] = None
+ return args_dict
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
index de833dd..6408f23 100644
--- a/test/features/steps/wrappers.py
+++ b/test/features/steps/wrappers.py
@@ -3,6 +3,7 @@ import pexpect
import sys
import textwrap
+
try:
from StringIO import StringIO
except ImportError:
@@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout):
timedout = False
try:
context.cli.expect_exact(expected, timeout=timeout)
- except pexpect.exceptions.TIMEOUT:
+ except pexpect.TIMEOUT:
timedout = True
if timedout:
# Strip color codes out of the output.
@@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout):
context.conf['pager_boundary'], expected), timeout=timeout)
-def run_cli(context, run_args=None):
+def run_cli(context, run_args=None, exclude_args=None):
"""Run the process using pexpect."""
- run_args = run_args or []
- if context.conf.get('host', None):
- run_args.extend(('-h', context.conf['host']))
- if context.conf.get('user', None):
- run_args.extend(('-u', context.conf['user']))
- if context.conf.get('pass', None):
- run_args.extend(('-p', context.conf['pass']))
- if context.conf.get('dbname', None):
- run_args.extend(('-D', context.conf['dbname']))
- if context.conf.get('defaults-file', None):
- run_args.extend(('--defaults-file', context.conf['defaults-file']))
- if context.conf.get('myclirc', None):
- run_args.extend(('--myclirc', context.conf['myclirc']))
+ run_args = run_args or {}
+ rendered_args = []
+ exclude_args = set(exclude_args) if exclude_args else set()
+
+ conf = dict(**context.conf)
+ conf.update(run_args)
+
+ def add_arg(name, key, value):
+ if name not in exclude_args:
+ if value is not None:
+ rendered_args.extend((key, value))
+ else:
+ rendered_args.append(key)
+
+ if conf.get('host', None):
+ add_arg('host', '-h', conf['host'])
+ if conf.get('user', None):
+ add_arg('user', '-u', conf['user'])
+ if conf.get('pass', None):
+ add_arg('pass', '-p', conf['pass'])
+ if conf.get('port', None):
+ add_arg('port', '-P', str(conf['port']))
+ if conf.get('dbname', None):
+ add_arg('dbname', '-D', conf['dbname'])
+ if conf.get('defaults-file', None):
+ add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
+ if conf.get('myclirc', None):
+ add_arg('myclirc', '--myclirc', conf['myclirc'])
+ if conf.get('login_path'):
+ add_arg('login_path', '--login-path', conf['login_path'])
+
+ for arg_name, arg_value in conf.items():
+ if arg_name.startswith('-'):
+ add_arg(arg_name, arg_name, arg_value)
+
try:
cli_cmd = context.conf['cli_command']
except KeyError:
@@ -73,7 +96,7 @@ def run_cli(context, run_args=None):
'"'
).format(sys.executable)
- cmd_parts = [cli_cmd] + run_args
+ cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
diff --git a/test/test_main.py b/test/test_main.py
index 707c359..00fdc1b 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -3,8 +3,9 @@ import os
import click
from click.testing import CliRunner
-from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
+from mycli.main import MyCli, cli, thanks_picker
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from mycli.sqlexecute import ServerInfo
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
from textwrap import dedent
@@ -140,10 +141,7 @@ def test_batch_mode_csv(executor):
def test_thanks_picker_utf8():
- author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
- sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
-
- name = thanks_picker((author_file, sponsor_file))
+ name = thanks_picker()
assert name and isinstance(name, str)
@@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
host = 'test'
user = 'test'
dbname = 'test'
+ server_info = ServerInfo.from_version_string('unknown')
port = 0
def server_type(self):
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
index 5168bf6..0f38a97 100644
--- a/test/test_sqlexecute.py
+++ b/test/test_sqlexecute.py
@@ -3,6 +3,7 @@ import os
import pytest
import pymysql
+from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
@@ -270,3 +271,24 @@ def test_multiple_results(executor):
'status': '1 row in set'}
]
assert results == expected
+
+
+@pytest.mark.parametrize(
+ 'version_string, species, parsed_version_string, version',
+ (
+ ('5.7.32-35', 'Percona', '5.7.32', 50732),
+ ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
+ ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
+ ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
+ ('unexpected version string', None, '', 0),
+ ('', None, '', 0),
+ (None, None, '', 0),
+ )
+)
+def test_version_parsing(version_string, species, parsed_version_string, version):
+ server_info = ServerInfo.from_version_string(version_string)
+ assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
+ assert server_info.version_str == parsed_version_string
+ assert server_info.version == version