diff options
Diffstat (limited to '')
-rw-r--r-- | test/features/connection.feature | 35 | ||||
-rw-r--r-- | test/features/environment.py | 48 | ||||
-rw-r--r-- | test/features/steps/auto_vertical.py | 3 | ||||
-rw-r--r-- | test/features/steps/connection.py | 71 | ||||
-rw-r--r-- | test/features/steps/utils.py | 12 | ||||
-rw-r--r-- | test/features/steps/wrappers.py | 55 | ||||
-rw-r--r-- | test/test_main.py | 9 | ||||
-rw-r--r-- | test/test_sqlexecute.py | 22 |
8 files changed, 227 insertions, 28 deletions
diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 0000000..b06935e --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,35 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/environment.py b/test/features/environment.py index 98c2004..1ea0f08 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -1,4 +1,5 @@ import os +import shutil import sys from tempfile import mkstemp @@ -11,6 +12,24 @@ from steps.wrappers import run_cli, wait_prompt test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + + def before_all(context): """Set env parameters.""" os.environ['LINES'] = "100" @@ -22,7 +41,7 @@ def before_all(context): test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) login_path_file = os.path.join(test_dir, 'mylogin.cnf') - os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file context.package_root = os.path.abspath( os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) @@ -33,8 +52,7 @@ def before_all(context): context.exit_sent = False vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get( - 'my_test_db', None) or "mycli_behave_tests" + db_name = get_db_name_from_context(context) db_name_full = '{0}_{1}'.format(db_name, vi) # Store get params from config/environment variables @@ -104,11 +122,18 @@ def before_step(context, _): context.atprompt = False -def before_scenario(context, _): +def before_scenario(context, arg): with open(test_log_file, 'w') as f: f.write('') - run_cli(context) - wait_prompt(context) + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) def after_scenario(context, _): @@ -134,6 +159,17 @@ def after_scenario(context, _): context.cli.sendcontrol('d') context.cli.expect_exact(pexpect.EOF, timeout=5) + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. + # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + # TODO: uncomment to debug a failure # def after_step(context, step): # if step.status == "failed": diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index 974740d..e1cb26f 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -3,11 +3,12 @@ from textwrap import dedent from behave import then, when import wrappers +from utils import parse_cli_args_to_dict @when('we run dbcli with {arg}') def step_run_cli_with_arg(context, arg): - wrappers.run_cli(context, run_args=arg.split('=')) + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) @when('we execute a small query') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 0000000..e16dd86 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,71 @@ +import io +import os +import shlex + +from behave import when, then +import pexpect + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. + # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 0000000..1ae63d2 --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index de833dd..6408f23 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -3,6 +3,7 @@ import pexpect import sys import textwrap + try: from StringIO import StringIO except ImportError: @@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout): timedout = False try: context.cli.expect_exact(expected, timeout=timeout) - except pexpect.exceptions.TIMEOUT: + except pexpect.TIMEOUT: timedout = True if timedout: # Strip color codes out of the output. @@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout): context.conf['pager_boundary'], expected), timeout=timeout) -def run_cli(context, run_args=None): +def run_cli(context, run_args=None, exclude_args=None): """Run the process using pexpect.""" - run_args = run_args or [] - if context.conf.get('host', None): - run_args.extend(('-h', context.conf['host'])) - if context.conf.get('user', None): - run_args.extend(('-u', context.conf['user'])) - if context.conf.get('pass', None): - run_args.extend(('-p', context.conf['pass'])) - if context.conf.get('dbname', None): - run_args.extend(('-D', context.conf['dbname'])) - if context.conf.get('defaults-file', None): - run_args.extend(('--defaults-file', context.conf['defaults-file'])) - if context.conf.get('myclirc', None): - run_args.extend(('--myclirc', context.conf['myclirc'])) + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + try: cli_cmd = context.conf['cli_command'] except KeyError: @@ -73,7 +96,7 @@ def run_cli(context, run_args=None): '"' ).format(sys.executable) - cmd_parts = [cli_cmd] + run_args + cmd_parts = [cli_cmd] + rendered_args cmd = ' '.join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() diff --git a/test/test_main.py b/test/test_main.py index 707c359..00fdc1b 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -3,8 +3,9 @@ import os import click from click.testing import CliRunner -from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT +from mycli.main import MyCli, cli, thanks_picker from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.sqlexecute import ServerInfo from .utils import USER, HOST, PORT, PASSWORD, dbtest, run from textwrap import dedent @@ -140,10 +141,7 @@ def test_batch_mode_csv(executor): def test_thanks_picker_utf8(): - author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') - sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') - - name = thanks_picker((author_file, sponsor_file)) + name = thanks_picker() assert name and isinstance(name, str) @@ -177,6 +175,7 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): host = 'test' user = 'test' dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') port = 0 def server_type(self): diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index 5168bf6..0f38a97 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -3,6 +3,7 @@ import os import pytest import pymysql +from mycli.sqlexecute import ServerInfo, ServerSpecies from .utils import run, dbtest, set_expanded_output, is_expanded_output @@ -270,3 +271,24 @@ def test_multiple_results(executor): 'status': '1 row in set'} ] assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version |