summaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--test/features/connection.feature35
-rw-r--r--test/features/environment.py48
-rw-r--r--test/features/steps/auto_vertical.py3
-rw-r--r--test/features/steps/connection.py71
-rw-r--r--test/features/steps/utils.py12
-rw-r--r--test/features/steps/wrappers.py55
-rw-r--r--test/test_main.py9
-rw-r--r--test/test_sqlexecute.py22
8 files changed, 227 insertions, 28 deletions
diff --git a/test/features/connection.feature b/test/features/connection.feature
new file mode 100644
index 0000000..b06935e
--- /dev/null
+++ b/test/features/connection.feature
@@ -0,0 +1,35 @@
+Feature: connect to a database:
+
+ @requires_local_db
+ Scenario: run mycli on localhost without port
+ When we run mycli with arguments "host=localhost" without arguments "port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli on TCP host without port
+ When we run mycli without arguments "port"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ Scenario: run mycli with port but without host
+ When we run mycli without arguments "host"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ @requires_local_db
+ Scenario: run mycli without host and port
+ When we run mycli without arguments "host port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli with my.cnf configuration
+ When we create my.cnf file
+ When we run mycli without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+ Scenario: run mycli with mylogin.cnf configuration
+ When we create mylogin.cnf file
+ When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+
diff --git a/test/features/environment.py b/test/features/environment.py
index 98c2004..1ea0f08 100644
--- a/test/features/environment.py
+++ b/test/features/environment.py
@@ -1,4 +1,5 @@
import os
+import shutil
import sys
from tempfile import mkstemp
@@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt
test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
+SELF_CONNECTING_FEATURES = (
+ 'test/features/connection.feature',
+)
+
+
+MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
+MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
+MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
+MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
+
+
+def get_db_name_from_context(context):
+ return context.config.userdata.get(
+ 'my_test_db', None
+ ) or "mycli_behave_tests"
+
+
+
def before_all(context):
"""Set env parameters."""
os.environ['LINES'] = "100"
@@ -22,7 +41,7 @@ def before_all(context):
test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
login_path_file = os.path.join(test_dir, 'mylogin.cnf')
- os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
context.package_root = os.path.abspath(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
@@ -33,8 +52,7 @@ def before_all(context):
context.exit_sent = False
vi = '_'.join([str(x) for x in sys.version_info[:3]])
- db_name = context.config.userdata.get(
- 'my_test_db', None) or "mycli_behave_tests"
+ db_name = get_db_name_from_context(context)
db_name_full = '{0}_{1}'.format(db_name, vi)
# Store get params from config/environment variables
@@ -104,11 +122,18 @@ def before_step(context, _):
context.atprompt = False
-def before_scenario(context, _):
+def before_scenario(context, arg):
with open(test_log_file, 'w') as f:
f.write('')
- run_cli(context)
- wait_prompt(context)
+ if arg.location.filename not in SELF_CONNECTING_FEATURES:
+ run_cli(context)
+ wait_prompt(context)
+
+ if os.path.exists(MY_CNF_PATH):
+ shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_PATH):
+ shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
def after_scenario(context, _):
@@ -134,6 +159,17 @@ def after_scenario(context, _):
context.cli.sendcontrol('d')
context.cli.expect_exact(pexpect.EOF, timeout=5)
+ if os.path.exists(MY_CNF_BACKUP_PATH):
+ shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
+ shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
+ elif os.path.exists(MYLOGIN_CNF_PATH):
+ # This file was moved in `before_scenario`.
+ # If it exists now, it has been created during a test
+ os.remove(MYLOGIN_CNF_PATH)
+
+
# TODO: uncomment to debug a failure
# def after_step(context, step):
# if step.status == "failed":
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
index 974740d..e1cb26f 100644
--- a/test/features/steps/auto_vertical.py
+++ b/test/features/steps/auto_vertical.py
@@ -3,11 +3,12 @@ from textwrap import dedent
from behave import then, when
import wrappers
+from utils import parse_cli_args_to_dict
@when('we run dbcli with {arg}')
def step_run_cli_with_arg(context, arg):
- wrappers.run_cli(context, run_args=arg.split('='))
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
@when('we execute a small query')
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
new file mode 100644
index 0000000..e16dd86
--- /dev/null
+++ b/test/features/steps/connection.py
@@ -0,0 +1,71 @@
+import io
+import os
+import shlex
+
+from behave import when, then
+import pexpect
+
+import wrappers
+from test.features.steps.utils import parse_cli_args_to_dict
+from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
+from test.utils import HOST, PORT, USER, PASSWORD
+from mycli.config import encrypt_mylogin_cnf
+
+
+TEST_LOGIN_PATH = 'test_login_path'
+
+
+@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
+@when('we run mycli without arguments "{excluded_args}"')
+def step_run_cli_without_args(context, excluded_args, exact_args=''):
+ wrappers.run_cli(
+ context,
+ run_args=parse_cli_args_to_dict(exact_args),
+ exclude_args=parse_cli_args_to_dict(excluded_args).keys()
+ )
+
+
+@then('status contains "{expression}"')
+def status_contains(context, expression):
+ wrappers.expect_exact(context, f'{expression}', timeout=5)
+
+ # Normally, the shutdown after scenario waits for the prompt.
+ # But we may have changed the prompt, depending on parameters,
+ # so let's wait for its last character
+ context.cli.expect_exact('>')
+ context.atprompt = True
+
+
+@when('we create my.cnf file')
+def step_create_my_cnf_file(context):
+ my_cnf = (
+ '[client]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MY_CNF_PATH, 'w') as f:
+ f.write(my_cnf)
+
+
+@when('we create mylogin.cnf file')
+def step_create_mylogin_cnf_file(context):
+ os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
+ mylogin_cnf = (
+ f'[{TEST_LOGIN_PATH}]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MYLOGIN_CNF_PATH, 'wb') as f:
+ input_file = io.StringIO(mylogin_cnf)
+ f.write(encrypt_mylogin_cnf(input_file).read())
+
+
+@then('we are logged in')
+def we_are_logged_in(context):
+ db_name = get_db_name_from_context(context)
+ context.cli.expect_exact(f'{db_name}>', timeout=5)
+ context.atprompt = True
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
new file mode 100644
index 0000000..1ae63d2
--- /dev/null
+++ b/test/features/steps/utils.py
@@ -0,0 +1,12 @@
+import shlex
+
+
+def parse_cli_args_to_dict(cli_args: str):
+ args_dict = {}
+ for arg in shlex.split(cli_args):
+ if '=' in arg:
+ key, value = arg.split('=')
+ args_dict[key] = value
+ else:
+ args_dict[arg] = None
+ return args_dict
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
index de833dd..6408f23 100644
--- a/test/features/steps/wrappers.py
+++ b/test/features/steps/wrappers.py
@@ -3,6 +3,7 @@ import pexpect
import sys
import textwrap
+
try:
from StringIO import StringIO
except ImportError:
@@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout):
timedout = False
try:
context.cli.expect_exact(expected, timeout=timeout)
- except pexpect.exceptions.TIMEOUT:
+ except pexpect.TIMEOUT:
timedout = True
if timedout:
# Strip color codes out of the output.
@@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout):
context.conf['pager_boundary'], expected), timeout=timeout)
-def run_cli(context, run_args=None):
+def run_cli(context, run_args=None, exclude_args=None):
"""Run the process using pexpect."""
- run_args = run_args or []
- if context.conf.get('host', None):
- run_args.extend(('-h', context.conf['host']))
- if context.conf.get('user', None):
- run_args.extend(('-u', context.conf['user']))
- if context.conf.get('pass', None):
- run_args.extend(('-p', context.conf['pass']))
- if context.conf.get('dbname', None):
- run_args.extend(('-D', context.conf['dbname']))
- if context.conf.get('defaults-file', None):
- run_args.extend(('--defaults-file', context.conf['defaults-file']))
- if context.conf.get('myclirc', None):
- run_args.extend(('--myclirc', context.conf['myclirc']))
+ run_args = run_args or {}
+ rendered_args = []
+ exclude_args = set(exclude_args) if exclude_args else set()
+
+ conf = dict(**context.conf)
+ conf.update(run_args)
+
+ def add_arg(name, key, value):
+ if name not in exclude_args:
+ if value is not None:
+ rendered_args.extend((key, value))
+ else:
+ rendered_args.append(key)
+
+ if conf.get('host', None):
+ add_arg('host', '-h', conf['host'])
+ if conf.get('user', None):
+ add_arg('user', '-u', conf['user'])
+ if conf.get('pass', None):
+ add_arg('pass', '-p', conf['pass'])
+ if conf.get('port', None):
+ add_arg('port', '-P', str(conf['port']))
+ if conf.get('dbname', None):
+ add_arg('dbname', '-D', conf['dbname'])
+ if conf.get('defaults-file', None):
+ add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
+ if conf.get('myclirc', None):
+ add_arg('myclirc', '--myclirc', conf['myclirc'])
+ if conf.get('login_path'):
+ add_arg('login_path', '--login-path', conf['login_path'])
+
+ for arg_name, arg_value in conf.items():
+ if arg_name.startswith('-'):
+ add_arg(arg_name, arg_name, arg_value)
+
try:
cli_cmd = context.conf['cli_command']
except KeyError:
@@ -73,7 +96,7 @@ def run_cli(context, run_args=None):
'"'
).format(sys.executable)
- cmd_parts = [cli_cmd] + run_args
+ cmd_parts = [cli_cmd] + rendered_args
cmd = ' '.join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
diff --git a/test/test_main.py b/test/test_main.py
index 707c359..00fdc1b 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -3,8 +3,9 @@ import os
import click
from click.testing import CliRunner
-from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT
+from mycli.main import MyCli, cli, thanks_picker
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from mycli.sqlexecute import ServerInfo
from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
from textwrap import dedent
@@ -140,10 +141,7 @@ def test_batch_mode_csv(executor):
def test_thanks_picker_utf8():
- author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS')
- sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS')
-
- name = thanks_picker((author_file, sponsor_file))
+ name = thanks_picker()
assert name and isinstance(name, str)
@@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
host = 'test'
user = 'test'
dbname = 'test'
+ server_info = ServerInfo.from_version_string('unknown')
port = 0
def server_type(self):
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
index 5168bf6..0f38a97 100644
--- a/test/test_sqlexecute.py
+++ b/test/test_sqlexecute.py
@@ -3,6 +3,7 @@ import os
import pytest
import pymysql
+from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
@@ -270,3 +271,24 @@ def test_multiple_results(executor):
'status': '1 row in set'}
]
assert results == expected
+
+
+@pytest.mark.parametrize(
+ 'version_string, species, parsed_version_string, version',
+ (
+ ('5.7.32-35', 'Percona', '5.7.32', 50732),
+ ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
+ ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
+ ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
+ ('unexpected version string', None, '', 0),
+ ('', None, '', 0),
+ (None, None, '', 0),
+ )
+)
+def test_version_parsing(version_string, species, parsed_version_string, version):
+ server_info = ServerInfo.from_version_string(version_string)
+ assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
+ assert server_info.version_str == parsed_version_string
+ assert server_info.version == version