diff options
Diffstat (limited to 'test')
29 files changed, 1340 insertions, 1560 deletions
diff --git a/test/conftest.py b/test/conftest.py index 1325596..5575b40 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,13 +1,12 @@ import pytest -from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, - db_connection, SSH_USER, SSH_HOST, SSH_PORT) +from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT import mycli.sqlexecute @pytest.fixture(scope="function") def connection(): - create_db('mycli_test_db') - connection = db_connection('mycli_test_db') + create_db("mycli_test_db") + connection = db_connection("mycli_test_db") yield connection connection.close() @@ -22,8 +21,18 @@ def cursor(connection): @pytest.fixture def executor(connection): return mycli.sqlexecute.SQLExecute( - database='mycli_test_db', user=USER, - host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, - ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + database="mycli_test_db", + user=USER, + host=HOST, + password=PASSWORD, + port=PORT, + socket=None, + charset=CHARSET, + local_infile=False, + ssl=None, + ssh_user=SSH_USER, + ssh_host=SSH_HOST, + ssh_port=SSH_PORT, + ssh_password=None, + ssh_key_filename=None, ) diff --git a/test/features/db_utils.py b/test/features/db_utils.py index be550e9..175cc1b 100644 --- a/test/features/db_utils.py +++ b/test/features/db_utils.py @@ -1,8 +1,7 @@ import pymysql -def create_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None): """Create test database. :param hostname: string @@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None, """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) - cr.execute('create database ' + dbname) + cr.execute("drop database if exists " + dbname) + cr.execute("create database " + dbname) cn.close() @@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname): """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) return cn -def drop_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None): """Drop database. :param hostname: string @@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None, """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) + cr.execute("drop database if exists " + dbname) close_cn(cn) diff --git a/test/features/environment.py b/test/features/environment.py index 1ea0f08..a3d3764 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -9,96 +9,72 @@ import pexpect from steps.wrappers import run_cli, wait_prompt -test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log") -SELF_CONNECTING_FEATURES = ( - 'test/features/connection.feature', -) +SELF_CONNECTING_FEATURES = ("test/features/connection.feature",) -MY_CNF_PATH = os.path.expanduser('~/.my.cnf') -MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' -MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') -MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' +MY_CNF_PATH = os.path.expanduser("~/.my.cnf") +MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup" +MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf") +MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup" def get_db_name_from_context(context): - return context.config.userdata.get( - 'my_test_db', None - ) or "mycli_behave_tests" - + return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests" def before_all(context): """Set env parameters.""" - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['EDITOR'] = 'ex' - os.environ['LC_ALL'] = 'en_US.UTF-8' - os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' - os.environ['MYCLI_HISTFILE'] = os.devnull + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["EDITOR"] = "ex" + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + os.environ["MYCLI_HISTFILE"] = os.devnull - test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - login_path_file = os.path.join(test_dir, 'mylogin.cnf') -# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + # test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + # login_path_file = os.path.join(test_dir, "mylogin.cnf") + # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file - context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, - '.coveragerc') + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc") context.exit_sent = False - vi = '_'.join([str(x) for x in sys.version_info[:3]]) + vi = "_".join([str(x) for x in sys.version_info[:3]]) db_name = get_db_name_from_context(context) - db_name_full = '{0}_{1}'.format(db_name, vi) + db_name_full = "{0}_{1}".format(db_name, vi) # Store get params from config/environment variables context.conf = { - 'host': context.config.userdata.get( - 'my_test_host', - os.getenv('PYTEST_HOST', 'localhost') - ), - 'port': context.config.userdata.get( - 'my_test_port', - int(os.getenv('PYTEST_PORT', '3306')) - ), - 'user': context.config.userdata.get( - 'my_test_user', - os.getenv('PYTEST_USER', 'root') - ), - 'pass': context.config.userdata.get( - 'my_test_pass', - os.getenv('PYTEST_PASSWORD', None) - ), - 'cli_command': context.config.userdata.get( - 'my_cli_command', None) or - sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', - 'dbname': db_name, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - 'pager_boundary': '---boundary---', + "host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")), + "port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))), + "user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")), + "pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)), + "cli_command": context.config.userdata.get("my_cli_command", None) + or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + "dbname": db_name, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", } _, my_cnf = mkstemp() - with open(my_cnf, 'w') as f: + with open(my_cnf, "w") as f: f.write( - '[client]\n' - 'pager={0} {1} {2}\n'.format( - sys.executable, os.path.join(context.package_root, - 'test/features/wrappager.py'), - context.conf['pager_boundary']) + "[client]\n" "pager={0} {1} {2}\n".format( + sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"] + ) ) - context.conf['defaults-file'] = my_cnf - context.conf['myclirc'] = os.path.join(context.package_root, 'test', - 'myclirc') + context.conf["defaults-file"] = my_cnf + context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc") - context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], - context.conf['user'], - context.conf['pass'], - context.conf['dbname']) + context.cn = dbutils.create_db( + context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"] + ) context.fixture_data = fixutils.read_fixture_files() @@ -106,12 +82,10 @@ def before_all(context): def after_all(context): """Unset env parameters.""" dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['port'], - context.conf['user'], context.conf['pass'], - context.conf['dbname']) + dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]) # Restore env vars. - #for k, v in context.pgenv.items(): + # for k, v in context.pgenv.items(): # if k in os.environ and v is None: # del os.environ[k] # elif v: @@ -123,8 +97,8 @@ def before_step(context, _): def before_scenario(context, arg): - with open(test_log_file, 'w') as f: - f.write('') + with open(test_log_file, "w") as f: + f.write("") if arg.location.filename not in SELF_CONNECTING_FEATURES: run_cli(context) wait_prompt(context) @@ -140,23 +114,18 @@ def after_scenario(context, _): """Cleans up after each test complete.""" with open(test_log_file) as f: for line in f: - if 'error' in line.lower(): - raise RuntimeError(f'Error in log file: {line}') + if "error" in line.lower(): + raise RuntimeError(f"Error in log file: {line}") - if hasattr(context, 'cli') and not context.exit_sent: + if hasattr(context, "cli") and not context.exit_sent: # Quit nicely. if not context.atprompt: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - context.cli.expect_exact( - '{0}@{1}:{2}>'.format( - user, host, dbname - ), - timeout=5 - ) - context.cli.sendcontrol('c') - context.cli.sendcontrol('d') + context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5) + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") context.cli.expect_exact(pexpect.EOF, timeout=5) if os.path.exists(MY_CNF_BACKUP_PATH): diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py index f85e0f6..514e41f 100644 --- a/test/features/fixture_utils.py +++ b/test/features/fixture_utils.py @@ -1,5 +1,4 @@ import os -import io def read_fixture_lines(filename): @@ -20,9 +19,9 @@ def read_fixture_files(): fixture_dict = {} current_dir = os.path.dirname(__file__) - fixture_dir = os.path.join(current_dir, 'fixture_data/') + fixture_dir = os.path.join(current_dir, "fixture_data/") for filename in os.listdir(fixture_dir): - if filename not in ['.', '..']: + if filename not in [".", ".."]: fullname = os.path.join(fixture_dir, filename) fixture_dict[filename] = read_fixture_lines(fullname) diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index e1cb26f..ad20067 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -6,41 +6,42 @@ import wrappers from utils import parse_cli_args_to_dict -@when('we run dbcli with {arg}') +@when("we run dbcli with {arg}") def step_run_cli_with_arg(context, arg): wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) -@when('we execute a small query') +@when("we execute a small query") def step_execute_small_query(context): - context.cli.sendline('select 1') + context.cli.sendline("select 1") -@when('we execute a large query') +@when("we execute a large query") def step_execute_large_query(context): - context.cli.sendline( - 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) -@then('we see small results in horizontal format') +@then("we see small results in horizontal format") def step_see_small_results(context): - wrappers.expect_pager(context, dedent("""\ + wrappers.expect_pager( + context, + dedent("""\ +---+\r | 1 |\r +---+\r | 1 |\r +---+\r \r - """), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=5, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then('we see large results in vertical format') +@then("we see large results in vertical format") def step_see_large_results(context): - rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] - expected = ('***************************[ 1. row ]' - '***************************\r\n' + - '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)] + expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n") wrappers.expect_pager(context, expected, timeout=10) - wrappers.expect_exact(context, '1 row in set', timeout=2) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py index 425ef67..ec1e47a 100644 --- a/test/features/steps/basic_commands.py +++ b/test/features/steps/basic_commands.py @@ -5,18 +5,18 @@ to call the step in "*.feature" file. """ -from behave import when +from behave import when, then from textwrap import dedent import tempfile import wrappers -@when('we run dbcli') +@when("we run dbcli") def step_run_cli(context): wrappers.run_cli(context) -@when('we wait for prompt') +@when("we wait for prompt") def step_wait_prompt(context): wrappers.wait_prompt(context) @@ -24,77 +24,75 @@ def step_wait_prompt(context): @when('we send "ctrl + d"') def step_ctrl_d(context): """Send Ctrl + D to hopefully exit.""" - context.cli.sendcontrol('d') + context.cli.sendcontrol("d") context.exit_sent = True -@when('we send "\?" command') +@when(r'we send "\?" command') def step_send_help(context): - """Send \? + r"""Send \? to see help. """ - context.cli.sendline('\\?') - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\\?") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we send source command') +@when("we send source command") def step_send_source_command(context): with tempfile.NamedTemporaryFile() as f: - f.write(b'\?') + f.write(b"\\?") f.flush() - context.cli.sendline('\. {0}'.format(f.name)) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\\. {0}".format(f.name)) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we run query to check application_name') +@when("we run query to check application_name") def step_check_application_name(context): context.cli.sendline( "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" ) -@then(u'we see found') +@then("we see found") def step_see_found(context): wrappers.expect_exact( context, - context.conf['pager_boundary'] + '\r' + dedent(''' + context.conf["pager_boundary"] + + "\r" + + dedent(""" +-------+\r | found |\r +-------+\r | found |\r +-------+\r \r - ''') + context.conf['pager_boundary'], - timeout=5 + """) + + context.conf["pager_boundary"], + timeout=5, ) -@then(u'we confirm the destructive warning') -def step_confirm_destructive_command(context): +@then("we confirm the destructive warning") +def step_confirm_destructive_command(context): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) + context.cli.sendline("y") -@when(u'we answer the destructive warning with "{confirmation}"') -def step_confirm_destructive_command(context, confirmation): +@when('we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) -@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') -def step_confirm_destructive_command(context, confirmation, text): +@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) wrappers.expect_exact(context, text, timeout=2) # we must exit the Click loop, or the feature will hang - context.cli.sendline('n') + context.cli.sendline("n") diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py index e16dd86..80d0653 100644 --- a/test/features/steps/connection.py +++ b/test/features/steps/connection.py @@ -1,9 +1,7 @@ import io import os -import shlex from behave import when, then -import pexpect import wrappers from test.features.steps.utils import parse_cli_args_to_dict @@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD from mycli.config import encrypt_mylogin_cnf -TEST_LOGIN_PATH = 'test_login_path' +TEST_LOGIN_PATH = "test_login_path" @when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') @when('we run mycli without arguments "{excluded_args}"') -def step_run_cli_without_args(context, excluded_args, exact_args=''): - wrappers.run_cli( - context, - run_args=parse_cli_args_to_dict(exact_args), - exclude_args=parse_cli_args_to_dict(excluded_args).keys() - ) +def step_run_cli_without_args(context, excluded_args, exact_args=""): + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys()) @then('status contains "{expression}"') def status_contains(context, expression): - wrappers.expect_exact(context, f'{expression}', timeout=5) + wrappers.expect_exact(context, f"{expression}", timeout=5) # Normally, the shutdown after scenario waits for the prompt. # But we may have changed the prompt, depending on parameters, # so let's wait for its last character - context.cli.expect_exact('>') + context.cli.expect_exact(">") context.atprompt = True -@when('we create my.cnf file') +@when("we create my.cnf file") def step_create_my_cnf_file(context): - my_cnf = ( - '[client]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MY_CNF_PATH, 'w') as f: + my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n" + with open(MY_CNF_PATH, "w") as f: f.write(my_cnf) -@when('we create mylogin.cnf file') +@when("we create mylogin.cnf file") def step_create_mylogin_cnf_file(context): - os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) - mylogin_cnf = ( - f'[{TEST_LOGIN_PATH}]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MYLOGIN_CNF_PATH, 'wb') as f: + os.environ.pop("MYSQL_TEST_LOGIN_FILE", None) + mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n" + with open(MYLOGIN_CNF_PATH, "wb") as f: input_file = io.StringIO(mylogin_cnf) f.write(encrypt_mylogin_cnf(input_file).read()) -@then('we are logged in') +@then("we are logged in") def we_are_logged_in(context): db_name = get_db_name_from_context(context) - context.cli.expect_exact(f'{db_name}>', timeout=5) + context.cli.expect_exact(f"{db_name}>", timeout=5) context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py index 841f37d..56ff114 100644 --- a/test/features/steps/crud_database.py +++ b/test/features/steps/crud_database.py @@ -11,105 +11,99 @@ import wrappers from behave import when, then -@when('we create database') +@when("we create database") def step_db_create(context): """Send create database.""" - context.cli.sendline('create database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"])) - context.response = { - 'database_name': context.conf['dbname_tmp'] - } + context.response = {"database_name": context.conf["dbname_tmp"]} -@when('we drop database') +@when("we drop database") def step_db_drop(context): """Send drop database.""" - context.cli.sendline('drop database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"])) -@when('we connect to test database') +@when("we connect to test database") def step_db_connect_test(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use {0};'.format(db_name)) + context.cli.sendline("use {0};".format(db_name)) -@when('we connect to quoted test database') +@when("we connect to quoted test database") def step_db_connect_quoted_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use `{0}`;'.format(db_name)) + context.cli.sendline("use `{0}`;".format(db_name)) -@when('we connect to tmp database') +@when("we connect to tmp database") def step_db_connect_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname_tmp'] + db_name = context.conf["dbname_tmp"] context.currentdb = db_name - context.cli.sendline('use {0}'.format(db_name)) + context.cli.sendline("use {0}".format(db_name)) -@when('we connect to dbserver') +@when("we connect to dbserver") def step_db_connect_dbserver(context): """Send connect to database.""" - context.currentdb = 'mysql' - context.cli.sendline('use mysql') + context.currentdb = "mysql" + context.cli.sendline("use mysql") -@then('dbcli exits') +@then("dbcli exits") def step_wait_exit(context): """Make sure the cli exits.""" wrappers.expect_exact(context, pexpect.EOF, timeout=5) -@then('we see dbcli prompt') +@then("we see dbcli prompt") def step_see_prompt(context): """Wait to see the prompt.""" - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) + wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname)) -@then('we see help output') +@then("we see help output") def step_see_help(context): - for expected_line in context.fixture_data['help_commands.txt']: + for expected_line in context.fixture_data["help_commands.txt"]: wrappers.expect_exact(context, expected_line, timeout=1) -@then('we see database created') +@then("we see database created") def step_see_db_created(context): """Wait to see create database output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see database dropped') +@then("we see database dropped") def step_see_db_dropped(context): """Wait to see drop database output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see database dropped and no default database') +@then("we see database dropped and no default database") def step_see_db_dropped_no_default(context): """Wait to see drop database output.""" - user = context.conf['user'] - host = context.conf['host'] - database = '(none)' + user = context.conf["user"] + host = context.conf["host"] + database = "(none)" context.currentdb = None - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) - wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) + wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database)) -@then('we see database connected') +@then("we see database connected") def step_see_db_connected(context): """Wait to see drop database output.""" - wrappers.expect_exact( - context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, 'You are now connected to database "', timeout=2) wrappers.expect_exact(context, '"', timeout=2) - wrappers.expect_exact(context, ' as user "{0}"'.format( - context.conf['user']), timeout=2) + wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py index f715f0c..48a6408 100644 --- a/test/features/steps/crud_table.py +++ b/test/features/steps/crud_table.py @@ -10,103 +10,109 @@ from behave import when, then from textwrap import dedent -@when('we create table') +@when("we create table") def step_create_table(context): """Send create table.""" - context.cli.sendline('create table a(x text);') + context.cli.sendline("create table a(x text);") -@when('we insert into table') +@when("we insert into table") def step_insert_into_table(context): """Send insert into table.""" - context.cli.sendline('''insert into a(x) values('xxx');''') + context.cli.sendline("""insert into a(x) values('xxx');""") -@when('we update table') +@when("we update table") def step_update_table(context): """Send insert into table.""" - context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") -@when('we select from table') +@when("we select from table") def step_select_from_table(context): """Send select from table.""" - context.cli.sendline('select * from a;') + context.cli.sendline("select * from a;") -@when('we delete from table') +@when("we delete from table") def step_delete_from_table(context): """Send deete from table.""" - context.cli.sendline('''delete from a where x = 'yyy';''') + context.cli.sendline("""delete from a where x = 'yyy';""") -@when('we drop table') +@when("we drop table") def step_drop_table(context): """Send drop table.""" - context.cli.sendline('drop table a;') + context.cli.sendline("drop table a;") -@then('we see table created') +@then("we see table created") def step_see_table_created(context): """Wait to see create table output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see record inserted') +@then("we see record inserted") def step_see_record_inserted(context): """Wait to see insert output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see record updated') +@then("we see record updated") def step_see_record_updated(context): """Wait to see update output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see data selected') +@then("we see data selected") def step_see_data_selected(context): """Wait to see select output.""" wrappers.expect_pager( - context, dedent("""\ + context, + dedent("""\ +-----+\r | x |\r +-----+\r | yyy |\r +-----+\r \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then('we see record deleted') +@then("we see record deleted") def step_see_data_deleted(context): """Wait to see delete output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see table dropped') +@then("we see table dropped") def step_see_table_dropped(context): """Wait to see drop output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@when('we select null') +@when("we select null") def step_select_null(context): """Send select null.""" - context.cli.sendline('select null;') + context.cli.sendline("select null;") -@then('we see null selected') +@then("we see null selected") def step_see_null_selected(context): """Wait to see null output.""" wrappers.expect_pager( - context, dedent("""\ + context, + dedent("""\ +--------+\r | NULL |\r +--------+\r | <null> |\r +--------+\r \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py index bbabf43..07d5c77 100644 --- a/test/features/steps/iocommands.py +++ b/test/features/steps/iocommands.py @@ -5,101 +5,93 @@ from behave import when, then from textwrap import dedent -@when('we start external editor providing a file name') +@when("we start external editor providing a file name") def step_edit_file(context): """Edit file with external editor.""" - context.editor_file_name = os.path.join( - context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) - context.cli.sendline('\e {0}'.format( - os.path.basename(context.editor_file_name))) - wrappers.expect_exact( - context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name))) + wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) + wrappers.expect_exact(context, "\r\n:", timeout=2) @when('we type "{query}" in the editor') def step_edit_type_sql(context, query): - context.cli.sendline('i') + context.cli.sendline("i") context.cli.sendline(query) - context.cli.sendline('.') - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline(".") + wrappers.expect_exact(context, "\r\n:", timeout=2) -@when('we exit the editor') +@when("we exit the editor") def step_edit_quit(context): - context.cli.sendline('x') + context.cli.sendline("x") wrappers.expect_exact(context, "written", timeout=2) @then('we see "{query}" in prompt') def step_edit_done_sql(context, query): - for match in query.split(' '): + for match in query.split(" "): wrappers.expect_exact(context, match, timeout=5) # Cleanup the command line. - context.cli.sendcontrol('c') + context.cli.sendcontrol("c") # Cleanup the edited file. if context.editor_file_name and os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) -@when(u'we tee output') +@when("we tee output") def step_tee_ouptut(context): - context.tee_file_name = os.path.join( - context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) - context.cli.sendline('tee {0}'.format( - os.path.basename(context.tee_file_name))) + context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name))) -@when(u'we select "select {param}"') +@when('we select "select {param}"') def step_query_select_number(context, param): - context.cli.sendline(u'select {}'.format(param)) - wrappers.expect_pager(context, dedent(u"""\ + context.cli.sendline("select {}".format(param)) + wrappers.expect_pager( + context, + dedent( + """\ +{dashes}+\r | {param} |\r +{dashes}+\r | {param} |\r +{dashes}+\r \r - """.format(param=param, dashes='-' * (len(param) + 2)) - ), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """.format(param=param, dashes="-" * (len(param) + 2)) + ), + timeout=5, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then(u'we see result "{result}"') +@then('we see result "{result}"') def step_see_result(context, result): - wrappers.expect_exact( - context, - u"| {} |".format(result), - timeout=2 - ) + wrappers.expect_exact(context, "| {} |".format(result), timeout=2) -@when(u'we query "{query}"') +@when('we query "{query}"') def step_query(context, query): context.cli.sendline(query) -@when(u'we notee output') +@when("we notee output") def step_notee_output(context): - context.cli.sendline('notee') + context.cli.sendline("notee") -@then(u'we see 123456 in tee output') +@then("we see 123456 in tee output") def step_see_123456_in_ouput(context): with open(context.tee_file_name) as f: - assert '123456' in f.read() + assert "123456" in f.read() if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) -@then(u'delimiter is set to "{delimiter}"') +@then('delimiter is set to "{delimiter}"') def delimiter_is_set(context, delimiter): - wrappers.expect_exact( - context, - u'Changed delimiter to {}'.format(delimiter), - timeout=2 - ) + wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py index bc1f866..93d68ba 100644 --- a/test/features/steps/named_queries.py +++ b/test/features/steps/named_queries.py @@ -9,82 +9,79 @@ import wrappers from behave import when, then -@when('we save a named query') +@when("we save a named query") def step_save_named_query(context): """Send \fs command.""" - context.cli.sendline('\\fs foo SELECT 12345') + context.cli.sendline("\\fs foo SELECT 12345") -@when('we use a named query') +@when("we use a named query") def step_use_named_query(context): """Send \f command.""" - context.cli.sendline('\\f foo') + context.cli.sendline("\\f foo") -@when('we delete a named query') +@when("we delete a named query") def step_delete_named_query(context): """Send \fd command.""" - context.cli.sendline('\\fd foo') + context.cli.sendline("\\fd foo") -@then('we see the named query saved') +@then("we see the named query saved") def step_see_named_query_saved(context): """Wait to see query saved.""" - wrappers.expect_exact(context, 'Saved.', timeout=2) + wrappers.expect_exact(context, "Saved.", timeout=2) -@then('we see the named query executed') +@then("we see the named query executed") def step_see_named_query_executed(context): """Wait to see select output.""" - wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + wrappers.expect_exact(context, "SELECT 12345", timeout=2) -@then('we see the named query deleted') +@then("we see the named query deleted") def step_see_named_query_deleted(context): """Wait to see query deleted.""" - wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + wrappers.expect_exact(context, "foo: Deleted", timeout=2) -@when('we save a named query with parameters') +@when("we save a named query with parameters") def step_save_named_query_with_parameters(context): """Send \fs command for query with parameters.""" context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') -@when('we use named query with parameters') +@when("we use named query with parameters") def step_use_named_query_with_parameters(context): """Send \f command with parameters.""" context.cli.sendline('\\f foo_args 101 second "third value"') -@then('we see the named query with parameters executed') +@then("we see the named query with parameters executed") def step_see_named_query_with_parameters_executed(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'SELECT 101, "second", "third value"', timeout=2) + wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2) -@when('we use named query with too few parameters') +@when("we use named query with too few parameters") def step_use_named_query_with_too_few_parameters(context): """Send \f command with missing parameters.""" - context.cli.sendline('\\f foo_args 101') + context.cli.sendline("\\f foo_args 101") -@then('we see the named query with parameters fail with missing parameters') +@then("we see the named query with parameters fail with missing parameters") def step_see_named_query_with_parameters_fail_with_missing_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'missing substitution for $2 in query:', timeout=2) + wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2) -@when('we use named query with too many parameters') +@when("we use named query with too many parameters") def step_use_named_query_with_too_many_parameters(context): """Send \f command with extra parameters.""" - context.cli.sendline('\\f foo_args 101 102 103 104') + context.cli.sendline("\\f foo_args 101 102 103 104") -@then('we see the named query with parameters fail with extra parameters') +@then("we see the named query with parameters fail with extra parameters") def step_see_named_query_with_parameters_fail_with_extra_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'query does not have substitution parameter $4:', timeout=2) + wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py index e8b99e3..1b50a00 100644 --- a/test/features/steps/specials.py +++ b/test/features/steps/specials.py @@ -9,10 +9,10 @@ import wrappers from behave import when, then -@when('we refresh completions') +@when("we refresh completions") def step_refresh_completions(context): """Send refresh command.""" - context.cli.sendline('rehash') + context.cli.sendline("rehash") @then('we see text "{text}"') @@ -20,8 +20,8 @@ def step_see_text(context, text): """Wait to see given text message.""" wrappers.expect_exact(context, text, timeout=2) -@then('we see completions refresh started') + +@then("we see completions refresh started") def step_see_refresh_started(context): """Wait to see refresh output.""" - wrappers.expect_exact( - context, 'Auto-completion refresh started in the background.', timeout=2) + wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2) diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py index 1ae63d2..873f9d4 100644 --- a/test/features/steps/utils.py +++ b/test/features/steps/utils.py @@ -4,8 +4,8 @@ import shlex def parse_cli_args_to_dict(cli_args: str): args_dict = {} for arg in shlex.split(cli_args): - if '=' in arg: - key, value = arg.split('=') + if "=" in arg: + key, value = arg.split("=") args_dict[key] = value else: args_dict[arg] = None diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index 6408f23..f9325c6 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout): timedout = True if timedout: # Strip color codes out of the output. - actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', - '', context.cli.before) + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) raise Exception( - textwrap.dedent('''\ + textwrap.dedent("""\ Expected: --- {0!r} @@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout): --- {2!r} --- - ''').format( - expected, - actual, - context.logfile.getvalue() - ) + """).format(expected, actual, context.logfile.getvalue()) ) def expect_pager(context, expected, timeout): - expect_exact(context, "{0}\r\n{1}{0}\r\n".format( - context.conf['pager_boundary'], expected), timeout=timeout) + expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout) def run_cli(context, run_args=None, exclude_args=None): @@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None): else: rendered_args.append(key) - if conf.get('host', None): - add_arg('host', '-h', conf['host']) - if conf.get('user', None): - add_arg('user', '-u', conf['user']) - if conf.get('pass', None): - add_arg('pass', '-p', conf['pass']) - if conf.get('port', None): - add_arg('port', '-P', str(conf['port'])) - if conf.get('dbname', None): - add_arg('dbname', '-D', conf['dbname']) - if conf.get('defaults-file', None): - add_arg('defaults_file', '--defaults-file', conf['defaults-file']) - if conf.get('myclirc', None): - add_arg('myclirc', '--myclirc', conf['myclirc']) - if conf.get('login_path'): - add_arg('login_path', '--login-path', conf['login_path']) + if conf.get("host", None): + add_arg("host", "-h", conf["host"]) + if conf.get("user", None): + add_arg("user", "-u", conf["user"]) + if conf.get("pass", None): + add_arg("pass", "-p", conf["pass"]) + if conf.get("port", None): + add_arg("port", "-P", str(conf["port"])) + if conf.get("dbname", None): + add_arg("dbname", "-D", conf["dbname"]) + if conf.get("defaults-file", None): + add_arg("defaults_file", "--defaults-file", conf["defaults-file"]) + if conf.get("myclirc", None): + add_arg("myclirc", "--myclirc", conf["myclirc"]) + if conf.get("login_path"): + add_arg("login_path", "--login-path", conf["login_path"]) for arg_name, arg_value in conf.items(): - if arg_name.startswith('-'): + if arg_name.startswith("-"): add_arg(arg_name, arg_name, arg_value) try: - cli_cmd = context.conf['cli_command'] + cli_cmd = context.conf["cli_command"] except KeyError: - cli_cmd = ( - '{0!s} -c "' - 'import coverage ; ' - 'coverage.process_startup(); ' - 'import mycli.main; ' - 'mycli.main.cli()' - '"' - ).format(sys.executable) + cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format( + sys.executable + ) cmd_parts = [cli_cmd] + rendered_args - cmd = ' '.join(cmd_parts) + cmd = " ".join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() context.cli.logfile = context.logfile context.exit_sent = False - context.currentdb = context.conf['dbname'] + context.currentdb = context.conf["dbname"] def wait_prompt(context, prompt=None): """Make sure prompt is displayed.""" if prompt is None: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - prompt = '{0}@{1}:{2}>'.format( - user, host, dbname), + prompt = ("{0}@{1}:{2}>".format(user, host, dbname),) expect_exact(context, prompt, timeout=5) context.atprompt = True diff --git a/test/myclirc b/test/myclirc index 7d96c45..58f7279 100644 --- a/test/myclirc +++ b/test/myclirc @@ -153,6 +153,7 @@ output.null = "#808080" # Favorite queries. [favorite_queries] check = 'select "✔"' +foo_args = 'SELECT $1, "$2", "$3"' # Use the -d option to reference a DSN. # Special characters in passwords and other strings can be escaped with URL encoding. diff --git a/test/test_clistyle.py b/test/test_clistyle.py index f82cdf0..ab40444 100644 --- a/test/test_clistyle.py +++ b/test/test_clistyle.py @@ -1,4 +1,5 @@ """Test the mycli.clistyle module.""" + import pytest from pygments.style import Style @@ -10,9 +11,9 @@ from mycli.clistyle import style_factory @pytest.mark.skip(reason="incompatible with new prompt toolkit") def test_style_factory(): """Test that a Pygments Style class is created.""" - header = 'bold underline #ansired' - cli_style = {'Token.Output.Header': header} - style = style_factory('default', cli_style) + header = "bold underline #ansired" + cli_style = {"Token.Output.Header": header} + style = style_factory("default", cli_style) assert isinstance(style(), Style) assert Token.Output.Header in style.styles @@ -22,6 +23,6 @@ def test_style_factory(): @pytest.mark.skip(reason="incompatible with new prompt toolkit") def test_style_factory_unknown_name(): """Test that an unrecognized name will not throw an error.""" - style = style_factory('foobar', {}) + style = style_factory("foobar", {}) assert isinstance(style(), Style) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py index 318b632..3104065 100644 --- a/test/test_completion_engine.py +++ b/test/test_completion_engine.py @@ -8,494 +8,528 @@ def sorted_dicts(dicts): def test_select_suggests_cols_with_visible_table_scope(): - suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_select_suggests_cols_with_qualified_table_scope(): - suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [('sch', 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE ', - 'SELECT * FROM tabl WHERE (', - 'SELECT * FROM tabl WHERE foo = ', - 'SELECT * FROM tabl WHERE bar OR ', - 'SELECT * FROM tabl WHERE foo = 1 AND ', - 'SELECT * FROM tabl WHERE (bar > 10 AND ', - 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', - 'SELECT * FROM tabl WHERE 10 < ', - 'SELECT * FROM tabl WHERE foo BETWEEN ', - 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', -]) + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [("sch", "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) def test_where_suggests_columns_functions(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE foo IN (', - 'SELECT * FROM tabl WHERE foo IN (bar, ', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE foo IN (", + "SELECT * FROM tabl WHERE foo IN (bar, ", + ], +) def test_where_in_suggests_columns(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_where_equals_any_suggests_columns_or_keywords(): - text = 'SELECT * FROM tabl WHERE foo = ANY(' + text = "SELECT * FROM tabl WHERE foo = ANY(" suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_lparen_suggests_cols(): - suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_operand_inside_function_suggests_cols1(): - suggestion = suggest_type( - 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_operand_inside_function_suggests_cols2(): - suggestion = suggest_type( - 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_select_suggests_cols_and_funcs(): - suggestions = suggest_type('SELECT ', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': []}, - {'type': 'column', 'tables': []}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM ', - 'INSERT INTO ', - 'COPY ', - 'UPDATE ', - 'DESCRIBE ', - 'DESC ', - 'EXPLAIN ', - 'SELECT * FROM foo JOIN ', -]) + suggestions = suggest_type("SELECT ", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": []}, + {"type": "column", "tables": []}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM ", + "INSERT INTO ", + "COPY ", + "UPDATE ", + "DESCRIBE ", + "DESC ", + "EXPLAIN ", + "SELECT * FROM foo JOIN ", + ], +) def test_expression_suggests_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM sch.', - 'INSERT INTO sch.', - 'COPY sch.', - 'UPDATE sch.', - 'DESCRIBE sch.', - 'DESC sch.', - 'EXPLAIN sch.', - 'SELECT * FROM foo JOIN sch.', -]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + "INSERT INTO sch.", + "COPY sch.", + "UPDATE sch.", + "DESCRIBE sch.", + "DESC sch.", + "EXPLAIN sch.", + "SELECT * FROM foo JOIN sch.", + ], +) def test_expression_suggests_qualified_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}, - {'type': 'view', 'schema': 'sch'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}]) def test_truncate_suggests_tables_and_schemas(): - suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}]) def test_truncate_suggests_qualified_tables(): - suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}]) + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}]) def test_distinct_suggests_cols(): - suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') - assert suggestions == [{'type': 'column', 'tables': []}] + suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ") + assert suggestions == [{"type": "column", "tables": []}] def test_col_comma_suggests_cols(): - suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tbl']}, - {'type': 'column', 'tables': [(None, 'tbl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tbl"]}, + {"type": "column", "tables": [(None, "tbl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_table_comma_suggests_tables_and_schemas(): - suggestions = suggest_type('SELECT a, b FROM tbl1, ', - 'SELECT a, b FROM tbl1, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_into_suggests_tables_and_schemas(): - suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_insert_into_lparen_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_insert_into_lparen_partial_text_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_insert_into_lparen_comma_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_partially_typed_col_name_suggests_col_names(): - suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', - 'SELECT * FROM tabl WHERE col_n') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): - suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'table', 'schema': 'tabl'}, - {'type': 'view', 'schema': 'tabl'}, - {'type': 'function', 'schema': 'tabl'}]) + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "table", "schema": "tabl"}, + {"type": "view", "schema": "tabl"}, + {"type": "function", "schema": "tabl"}, + ] + ) def test_dot_suggests_cols_of_an_alias(): - suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 't1'}, - {'type': 'view', 'schema': 't1'}, - {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, - {'type': 'function', 'schema': 't1'}]) + suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": "t1"}, + {"type": "view", "schema": "t1"}, + {"type": "column", "tables": [(None, "tabl1", "t1")]}, + {"type": "function", "schema": "t1"}, + ] + ) def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): - suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.a, t2.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, - {'type': 'table', 'schema': 't2'}, - {'type': 'view', 'schema': 't2'}, - {'type': 'function', 'schema': 't2'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (', - 'SELECT * FROM foo WHERE EXISTS (', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', - 'SELECT 1 AS', -]) + suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl2", "t2")]}, + {"type": "table", "schema": "t2"}, + {"type": "view", "schema": "t2"}, + {"type": "function", "schema": "t2"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + "SELECT 1 AS", + ], +) def test_sub_select_suggests_keyword(expression): suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] + assert suggestion == [{"type": "keyword"}] -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (S', - 'SELECT * FROM foo WHERE EXISTS (S', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) def test_sub_select_partial_text_suggests_keyword(expression): suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] + assert suggestion == [{"type": "keyword"}] def test_outer_table_reference_in_exists_subquery_suggests_columns(): - q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." suggestions = suggest_type(q, q) assert suggestions == [ - {'type': 'column', 'tables': [(None, 'foo', 'f')]}, - {'type': 'table', 'schema': 'f'}, - {'type': 'view', 'schema': 'f'}, - {'type': 'function', 'schema': 'f'}] - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (SELECT * FROM ', - 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', -]) + {"type": "column", "tables": [(None, "foo", "f")]}, + {"type": "table", "schema": "f"}, + {"type": "view", "schema": "f"}, + {"type": "function", "schema": "f"}, + ] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (SELECT * FROM ", + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) def test_sub_select_table_name_completion(expression): suggestion = suggest_type(expression, expression) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_sub_select_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', - 'SELECT * FROM (SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['abc']}, - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["abc"]}, + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) @pytest.mark.xfail def test_sub_select_multiple_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', - 'SELECT * FROM (SELECT a, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}]) + suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ") + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}] + ) def test_sub_select_dot_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', - 'SELECT * FROM (SELECT t.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', 't')]}, - {'type': 'table', 'schema': 't'}, - {'type': 'view', 'schema': 't'}, - {'type': 'function', 'schema': 't'}]) - - -@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) -@pytest.mark.parametrize('tbl_alias', ['', 'foo']) + suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", "t")]}, + {"type": "table", "schema": "t"}, + {"type": "view", "schema": "t"}, + {"type": "function", "schema": "t"}, + ] + ) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +@pytest.mark.parametrize("tbl_alias", ["", "foo"]) def test_join_suggests_tables_and_schemas(tbl_alias, join_type): - text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) + text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type) suggestion = suggest_type(text, text) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', -]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) def test_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', 'a')]}, - {'type': 'table', 'schema': 'a'}, - {'type': 'view', 'schema': 'a'}, - {'type': 'function', 'schema': 'a'}]) - - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "abc", "a")]}, + {"type": "table", "schema": "a"}, + {"type": "view", "schema": "a"}, + {"type": "function", "schema": "a"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) def test_join_alias_dot_suggests_cols2(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'def', 'd')]}, - {'type': 'table', 'schema': 'd'}, - {'type': 'view', 'schema': 'd'}, - {'type': 'function', 'schema': 'd'}]) - - -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "def", "d")]}, + {"type": "table", "schema": "d"}, + {"type": "view", "schema": "d"}, + {"type": "function", "schema": "d"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) def test_on_suggests_aliases(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + ], +) def test_on_suggests_tables(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on a.id = ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) def test_on_suggests_aliases_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + ], +) def test_on_suggests_tables_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] -@pytest.mark.parametrize('col_list', ['', 'col1, ']) +@pytest.mark.parametrize("col_list", ["", "col1, "]) def test_join_using_suggests_common_columns(col_list): - text = 'select * from abc inner join def using (' + col_list - assert suggest_type(text, text) == [ - {'type': 'column', - 'tables': [(None, 'abc', None), (None, 'def', None)], - 'drop_unique': True}] - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', -]) + text = "select * from abc inner join def using (" + col_list + assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}] + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.", + ], +) def test_two_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, - {'type': 'table', 'schema': 'g'}, - {'type': 'view', 'schema': 'g'}, - {'type': 'function', 'schema': 'g'}]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "ghi", "g")]}, + {"type": "table", "schema": "g"}, + {"type": "view", "schema": "g"}, + {"type": "function", "schema": "g"}, + ] + ) + def test_2_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select * from a; select from b", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) # Should work even if first statement is invalid - suggestions = suggest_type('select * from; select * from ', - 'select * from; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("select * from; select * from ", "select * from; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_2_statements_1st_current(): - suggestions = suggest_type('select * from ; select * from b', - 'select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select from a; select * from b', - 'select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['a']}, - {'type': 'column', 'tables': [(None, 'a', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select from a; select * from b", "select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["a"]}, + {"type": "column", "tables": [(None, "a", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_3_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ; select * from c', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b; select * from c', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_create_db_with_template(): - suggestions = suggest_type('create database foo with template ', - 'create database foo with template ') + suggestions = suggest_type("create database foo with template ", "create database foo with template ") - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) -@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t']) +@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"]) def test_specials_included_for_initial_completion(initial_text): suggestions = suggest_type(initial_text, initial_text) - assert sorted_dicts(suggestions) == \ - sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}]) def test_specials_not_included_after_initial_token(): - suggestions = suggest_type('create table foo (dt d', - 'create table foo (dt d') + suggestions = suggest_type("create table foo (dt d", "create table foo (dt d") - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}]) def test_drop_schema_qualified_table_suggests_only_tables(): - text = 'DROP TABLE schema_name.table_name' + text = "DROP TABLE schema_name.table_name" suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'table', 'schema': 'schema_name'}] + assert suggestions == [{"type": "table", "schema": "schema_name"}] -@pytest.mark.parametrize('text', [',', ' ,', 'sel ,']) +@pytest.mark.parametrize("text", [",", " ,", "sel ,"]) def test_handle_pre_completion_comma_gracefully(text): suggestions = suggest_type(text, text) @@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text): def test_cross_join(): - text = 'select * from v1 cross join v2 JOIN v1.id, ' + text = "select * from v1 cross join v2 JOIN v1.id, " suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) -@pytest.mark.parametrize('expression', [ - 'SELECT 1 AS ', - 'SELECT 1 FROM tabl AS ', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT 1 AS ", + "SELECT 1 FROM tabl AS ", + ], +) def test_after_as(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == set() -@pytest.mark.parametrize('expression', [ - '\\. ', - 'select 1; \\. ', - 'select 1;\\. ', - 'select 1 ; \\. ', - 'source ', - 'truncate table test; source ', - 'truncate table test ; source ', - 'truncate table test;source ', -]) +@pytest.mark.parametrize( + "expression", + [ + "\\. ", + "select 1; \\. ", + "select 1;\\. ", + "select 1 ; \\. ", + "source ", + "truncate table test; source ", + "truncate table test ; source ", + "truncate table test;source ", + ], +) def test_source_is_file(expression): suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'file_name'}] + assert suggestions == [{"type": "file_name"}] -@pytest.mark.parametrize("expression", [ - "\\f ", -]) +@pytest.mark.parametrize( + "expression", + [ + "\\f ", + ], +) def test_favorite_name_suggestion(expression): suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'favoritequery'}] + assert suggestions == [{"type": "favoritequery"}] def test_order_by(): - text = 'select * from foo order by ' + text = "select * from foo order by " suggestions = suggest_type(text, text) - assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] + assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}] def test_quoted_where(): text = "'where i=';" suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'keyword'}] + assert suggestions == [{"type": "keyword"}] diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py index 31359cf..6f192d0 100644 --- a/test/test_completion_refresher.py +++ b/test/test_completion_refresher.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch @pytest.fixture def refresher(): from mycli.completion_refresher import CompletionRefresher + return CompletionRefresher() @@ -18,8 +19,7 @@ def test_ctor(refresher): """ assert len(refresher.refreshers) > 0 actual_handlers = list(refresher.refreshers.keys()) - expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions', - 'special_commands', 'show_commands', 'keywords'] + expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"] assert expected_handlers == actual_handlers @@ -32,12 +32,12 @@ def test_refresh_called_once(refresher): callbacks = Mock() sqlexecute = Mock() - with patch.object(refresher, '_bg_refresh') as bg_refresh: + with patch.object(refresher, "_bg_refresh") as bg_refresh: actual = refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual) == 1 assert len(actual[0]) == 4 - assert actual[0][3] == 'Auto-completion refresh started in the background.' + assert actual[0][3] == "Auto-completion refresh started in the background." bg_refresh.assert_called_with(sqlexecute, callbacks, {}) @@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher): time.sleep(1) # Wait for the thread to work. assert len(actual1) == 1 assert len(actual1[0]) == 4 - assert actual1[0][3] == 'Auto-completion refresh started in the background.' + assert actual1[0][3] == "Auto-completion refresh started in the background." actual2 = refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual2) == 1 assert len(actual2[0]) == 4 - assert actual2[0][3] == 'Auto-completion refresh restarted.' + assert actual2[0][3] == "Auto-completion refresh restarted." def test_refresh_with_callbacks(refresher): @@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher): sqlexecute_class = Mock() sqlexecute = Mock() - with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class): + with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class): # Set refreshers to 0: we're not testing refresh logic here refresher.refreshers = {} refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. - assert (callbacks[0].call_count == 1) + assert callbacks[0].call_count == 1 diff --git a/test/test_config.py b/test/test_config.py index 7f2b244..859ca02 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -1,4 +1,5 @@ """Unit tests for the mycli.config module.""" + from io import BytesIO, StringIO, TextIOWrapper import os import struct @@ -6,21 +7,26 @@ import sys import tempfile import pytest -from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, - read_and_decrypt_mylogin_cnf, read_config_file, - str_to_bool, strip_matching_quotes) +from mycli.config import ( + get_mylogin_cnf_path, + open_mylogin_cnf, + read_and_decrypt_mylogin_cnf, + read_config_file, + str_to_bool, + strip_matching_quotes, +) -LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'mylogin.cnf')) +LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf")) def open_bmylogin_cnf(name): """Open contents of *name* in a BytesIO buffer.""" - with open(name, 'rb') as f: + with open(name, "rb") as f: buf = BytesIO() buf.write(f.read()) return buf + def test_read_mylogin_cnf(): """Tests that a login path file can be read and decrypted.""" mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) @@ -28,7 +34,7 @@ def test_read_mylogin_cnf(): assert isinstance(mylogin_cnf, TextIOWrapper) contents = mylogin_cnf.read() - for word in ('[test]', 'user', 'password', 'host', 'port'): + for word in ("[test]", "user", "password", "host", "port"): assert word in contents @@ -46,7 +52,7 @@ def test_corrupted_login_key(): buf.seek(4) # Write null bytes over half the login key - buf.write(b'\0\0\0\0\0\0\0\0\0\0') + buf.write(b"\0\0\0\0\0\0\0\0\0\0") buf.seek(0) mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) @@ -63,58 +69,58 @@ def test_corrupted_pad(): # Skip option group len_buf = buf.read(4) - cipher_len, = struct.unpack("<i", len_buf) + (cipher_len,) = struct.unpack("<i", len_buf) buf.read(cipher_len) # Corrupt the pad for the user line len_buf = buf.read(4) - cipher_len, = struct.unpack("<i", len_buf) + (cipher_len,) = struct.unpack("<i", len_buf) buf.read(cipher_len - 1) - buf.write(b'\0') + buf.write(b"\0") buf.seek(0) mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf)) contents = mylogin_cnf.read() - for word in ('[test]', 'password', 'host', 'port'): + for word in ("[test]", "password", "host", "port"): assert word in contents - assert 'user' not in contents + assert "user" not in contents def test_get_mylogin_cnf_path(): """Tests that the path for .mylogin.cnf is detected.""" original_env = None - if 'MYSQL_TEST_LOGIN_FILE' in os.environ: - original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') - is_windows = sys.platform == 'win32' + if "MYSQL_TEST_LOGIN_FILE" in os.environ: + original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE") + is_windows = sys.platform == "win32" login_cnf_path = get_mylogin_cnf_path() if original_env is not None: - os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env if login_cnf_path is not None: - assert login_cnf_path.endswith('.mylogin.cnf') + assert login_cnf_path.endswith(".mylogin.cnf") if is_windows is True: - assert 'MySQL' in login_cnf_path + assert "MySQL" in login_cnf_path else: - home_dir = os.path.expanduser('~') + home_dir = os.path.expanduser("~") assert login_cnf_path.startswith(home_dir) def test_alternate_get_mylogin_cnf_path(): """Tests that the alternate path for .mylogin.cnf is detected.""" original_env = None - if 'MYSQL_TEST_LOGIN_FILE' in os.environ: - original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + if "MYSQL_TEST_LOGIN_FILE" in os.environ: + original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE") _, temp_path = tempfile.mkstemp() - os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path + os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path login_cnf_path = get_mylogin_cnf_path() if original_env is not None: - os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env assert temp_path == login_cnf_path @@ -124,17 +130,17 @@ def test_str_to_bool(): assert str_to_bool(False) is False assert str_to_bool(True) is True - assert str_to_bool('False') is False - assert str_to_bool('True') is True - assert str_to_bool('TRUE') is True - assert str_to_bool('1') is True - assert str_to_bool('0') is False - assert str_to_bool('on') is True - assert str_to_bool('off') is False - assert str_to_bool('off') is False + assert str_to_bool("False") is False + assert str_to_bool("True") is True + assert str_to_bool("TRUE") is True + assert str_to_bool("1") is True + assert str_to_bool("0") is False + assert str_to_bool("on") is True + assert str_to_bool("off") is False + assert str_to_bool("off") is False with pytest.raises(ValueError): - str_to_bool('foo') + str_to_bool("foo") with pytest.raises(TypeError): str_to_bool(None) @@ -143,19 +149,19 @@ def test_str_to_bool(): def test_read_config_file_list_values_default(): """Test that reading a config file uses list_values by default.""" - f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n") config = read_config_file(f) - assert config['main']['weather'] == u"cloudy with a chance of meatballs" + assert config["main"]["weather"] == "cloudy with a chance of meatballs" def test_read_config_file_list_values_off(): """Test that you can disable list_values when reading a config file.""" - f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n") config = read_config_file(f, list_values=False) - assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + assert config["main"]["weather"] == "'cloudy with a chance of meatballs'" def test_strip_quotes_with_matching_quotes(): @@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes(): def test_strip_quotes_with_empty_string(): """Test that an empty string is handled during unquoting.""" - assert '' == strip_matching_quotes('') + assert "" == strip_matching_quotes("") def test_strip_quotes_with_none(): diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py index 21e389c..aee6e05 100644 --- a/test/test_dbspecial.py +++ b/test/test_dbspecial.py @@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime def test_u_suggests_databases(): - suggestions = suggest_type('\\u ', '\\u ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'database'}]) + suggestions = suggest_type("\\u ", "\\u ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) def test_describe_table(): - suggestions = suggest_type('\\dt', '\\dt ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("\\dt", "\\dt ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_list_or_show_create_tables(): - suggestions = suggest_type('\\dt+', '\\dt+ ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("\\dt+", "\\dt+ ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_format_uptime(): seconds = 59 - assert '59 sec' == format_uptime(seconds) + assert "59 sec" == format_uptime(seconds) seconds = 120 - assert '2 min 0 sec' == format_uptime(seconds) + assert "2 min 0 sec" == format_uptime(seconds) seconds = 54890 - assert '15 hours 14 min 50 sec' == format_uptime(seconds) + assert "15 hours 14 min 50 sec" == format_uptime(seconds) seconds = 598244 - assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds) + assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds) seconds = 522600 - assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds) + assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds) diff --git a/test/test_main.py b/test/test_main.py index 589d6cd..b0f8d4c 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -13,52 +13,62 @@ from textwrap import dedent from collections import namedtuple from tempfile import NamedTemporaryFile -from textwrap import dedent test_dir = os.path.abspath(os.path.dirname(__file__)) project_dir = os.path.dirname(test_dir) -default_config_file = os.path.join(project_dir, 'test', 'myclirc') -login_path_file = os.path.join(test_dir, 'mylogin.cnf') - -os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file -CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, - '--password', PASSWORD, '--myclirc', default_config_file, - '--defaults-file', default_config_file, - 'mycli_test_db'] +default_config_file = os.path.join(project_dir, "test", "myclirc") +login_path_file = os.path.join(test_dir, "mylogin.cnf") + +os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file +CLI_ARGS = [ + "--user", + USER, + "--host", + HOST, + "--port", + PORT, + "--password", + PASSWORD, + "--myclirc", + default_config_file, + "--defaults-file", + default_config_file, + "mycli_test_db", +] @dbtest def test_execute_arg(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql]) assert result.exit_code == 0 - assert 'abc' in result.output + assert "abc" in result.output - result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) + result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql]) assert result.exit_code == 0 - assert 'abc' in result.output + assert "abc" in result.output - expected = 'a\nabc\n' + expected = "a\nabc\n" assert expected in result.output @dbtest def test_execute_arg_with_table(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) - expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"]) + expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n" assert result.exit_code == 0 assert expected in result.output @@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor): @dbtest def test_execute_arg_with_csv(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"]) expected = '"a"\n"abc"\n' assert result.exit_code == 0 @@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor): @dbtest def test_batch_mode(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) + sql = "select count(*) from test;\n" "select * from test limit 1;" runner = CliRunner() result = runner.invoke(cli, args=CLI_ARGS, input=sql) assert result.exit_code == 0 - assert 'count(*)\n3\na\nabc\n' in "".join(result.output) + assert "count(*)\n3\na\nabc\n" in "".join(result.output) @dbtest def test_batch_mode_table(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) + sql = "select count(*) from test;\n" "select * from test limit 1;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) + result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql) - expected = (dedent("""\ + expected = dedent("""\ +----------+ | count(*) | +----------+ @@ -118,7 +122,7 @@ def test_batch_mode_table(executor): | a | +-----+ | abc | - +-----+""")) + +-----+""") assert result.exit_code == 0 assert expected in result.output @@ -126,14 +130,13 @@ def test_batch_mode_table(executor): @dbtest def test_batch_mode_csv(executor): - run(executor, '''create table test(a text, b text)''') - run(executor, - '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''') + run(executor, """create table test(a text, b text)""") + run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""") - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) + result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql) expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' @@ -150,15 +153,15 @@ def test_help_strings_end_with_periods(): """Make sure click options have help text that end with a period.""" for param in cli.params: if isinstance(param, click.core.Option): - assert hasattr(param, 'help') - assert param.help.endswith('.') + assert hasattr(param, "help") + assert param.help.endswith(".") def test_command_descriptions_end_with_periods(): """Make sure that mycli commands' descriptions end with a period.""" MyCli() for _, command in SPECIAL_COMMANDS.items(): - assert command[3].endswith('.') + assert command[3].endswith(".") def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): @@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): clickoutput = "" m = MyCli(myclirc=default_config_file) - class TestOutput(): + class TestOutput: def get_size(self): - size = namedtuple('Size', 'rows columns') + size = namedtuple("Size", "rows columns") size.columns, size.rows = terminal_size return size - class TestExecute(): - host = 'test' - user = 'test' - dbname = 'test' - server_info = ServerInfo.from_version_string('unknown') + class TestExecute: + host = "test" + user = "test" + dbname = "test" + server_info = ServerInfo.from_version_string("unknown") port = 0 def server_type(self): - return ['test'] + return ["test"] - class PromptBuffer(): + class PromptBuffer: output = TestOutput() m.prompt_app = PromptBuffer() @@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): global clickoutput clickoutput += s + "\n" - monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager) - monkeypatch.setattr(click, 'secho', secho) + monkeypatch.setattr(click, "echo_via_pager", echo_via_pager) + monkeypatch.setattr(click, "secho", secho) m.output(testdata) if clickoutput.endswith("\n"): clickoutput = clickoutput[:-1] @@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): def test_conditional_pager(monkeypatch): - testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( - " ") + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ") # User didn't set pager, output doesn't fit screen -> pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=True - ) + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True) # User didn't set pager, output fits screen -> no pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False) # User manually configured pager, output doesn't fit screen -> pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True) # User manually configured pager, output fit screen -> pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True) - SPECIAL_COMMANDS['nopager'].handler() - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) - SPECIAL_COMMANDS['pager'].handler('') + SPECIAL_COMMANDS["nopager"].handler() + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False) + SPECIAL_COMMANDS["pager"].handler("") def test_reserved_space_is_integer(monkeypatch): """Make sure that reserved space is returned as an integer.""" + def stub_terminal_size(): return (5, 5) with monkeypatch.context() as m: - m.setattr(shutil, 'get_terminal_size', stub_terminal_size) + m.setattr(shutil, "get_terminal_size", stub_terminal_size) mycli = MyCli() assert isinstance(mycli.get_reserved_space(), int) @@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch): def test_list_dsn(): runner = CliRunner() # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as myclirc: - myclirc.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as myclirc: + myclirc.write( + dedent("""\ [alias_dsn] test = mysql://test/test - """)) + """) + ) myclirc.flush() - args = ['--list-dsn', '--myclirc', myclirc.name] + args = ["--list-dsn", "--myclirc", myclirc.name] result = runner.invoke(cli, args=args) assert result.output == "test\n" - result = runner.invoke(cli, args=args + ['--verbose']) + result = runner.invoke(cli, args=args + ["--verbose"]) assert result.output == "test : mysql://test/test\n" - + # delete=False means we should try to clean up try: if os.path.exists(myclirc.name): @@ -287,41 +262,41 @@ def test_list_dsn(): except Exception as e: print(f"An error occurred while attempting to delete the file: {e}") - - def test_prettify_statement(): - statement = 'SELECT 1' + statement = "SELECT 1" m = MyCli() pretty_statement = m.handle_prettify_binding(statement) - assert pretty_statement == 'SELECT\n 1;' + assert pretty_statement == "SELECT\n 1;" def test_unprettify_statement(): - statement = 'SELECT\n 1' + statement = "SELECT\n 1" m = MyCli() unpretty_statement = m.handle_unprettify_binding(statement) - assert unpretty_statement == 'SELECT 1;' + assert unpretty_statement == "SELECT 1;" def test_list_ssh_config(): runner = CliRunner() # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ Host test Hostname test.example.com User joe Port 22222 IdentityFile ~/.ssh/gateway - """)) + """) + ) ssh_config.flush() - args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name] result = runner.invoke(cli, args=args) assert "test\n" in result.output - result = runner.invoke(cli, args=args + ['--verbose']) + result = runner.invoke(cli, args=args + ["--verbose"]) assert "test : test.example.com\n" in result.output - + # delete=False means we should try to clean up try: if os.path.exists(ssh_config.name): @@ -343,7 +318,7 @@ def test_dsn(monkeypatch): pass class MockMyCli: - config = {'alias_dsn': {}} + config = {"alias_dsn": {}} def __init__(self, **args): self.logger = Logger() @@ -357,97 +332,109 @@ def test_dsn(monkeypatch): pass import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) runner = CliRunner() # When a user supplies a DSN as database argument to mycli, # use these values. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] - ) + result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 1 and \ - MockMyCli.connect_args["database"] == "dsn_database" + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] == "dsn_passwd" + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 1 + and MockMyCli.connect_args["database"] == "dsn_database" + ) MockMyCli.connect_args = None # When a use supplies a DSN as database argument to mycli, # and used command line arguments, use the command line # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "3", - "--database", "arg_database", - ]) + result = runner.invoke( + mycli.main.cli, + args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "3", + "--database", + "arg_database", + ], + ) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 3 and \ - MockMyCli.connect_args["database"] == "arg_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 3 + and MockMyCli.connect_args["database"] == "arg_database" + ) + + MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}} MockMyCli.connect_args = None # When a user uses a DSN from the configuration file (alias_dsn), # use these values. - result = runner.invoke(cli, args=['--dsn', 'test']) + result = runner.invoke(cli, args=["--dsn", "test"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "alias_dsn_user" and \ - MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ - MockMyCli.connect_args["host"] == "alias_dsn_host" and \ - MockMyCli.connect_args["port"] == 4 and \ - MockMyCli.connect_args["database"] == "alias_dsn_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } + assert ( + MockMyCli.connect_args["user"] == "alias_dsn_user" + and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" + and MockMyCli.connect_args["host"] == "alias_dsn_host" + and MockMyCli.connect_args["port"] == 4 + and MockMyCli.connect_args["database"] == "alias_dsn_database" + ) + + MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}} MockMyCli.connect_args = None # When a user uses a DSN from the configuration file (alias_dsn) # and used command line arguments, use the command line arguments. - result = runner.invoke(cli, args=[ - '--dsn', 'test', '', - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "5", - "--database", "arg_database", - ]) + result = runner.invoke( + cli, + args=[ + "--dsn", + "test", + "", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "5", + "--database", + "arg_database", + ], + ) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 5 and \ - MockMyCli.connect_args["database"] == "arg_database" + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 5 + and MockMyCli.connect_args["database"] == "arg_database" + ) # Use a DSN without password - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user@dsn_host:6/dsn_database"] - ) + result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] is None and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 6 and \ - MockMyCli.connect_args["database"] == "dsn_database" + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] is None + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 6 + and MockMyCli.connect_args["database"] == "dsn_database" + ) def test_ssh_config(monkeypatch): @@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch): pass class MockMyCli: - config = {'alias_dsn': {}} + config = {"alias_dsn": {}} def __init__(self, **args): self.logger = Logger() @@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch): pass import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) runner = CliRunner() # Setup temporary configuration # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ Host test Hostname test.example.com User joe Port 22222 IdentityFile ~/.ssh/gateway - """)) + """) + ) ssh_config.flush() # When a user supplies a ssh config. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "joe" and \ - MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ - MockMyCli.connect_args["ssh_port"] == 22222 and \ - MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser( - "~") + "/.ssh/gateway" + result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "joe" + and MockMyCli.connect_args["ssh_host"] == "test.example.com" + and MockMyCli.connect_args["ssh_port"] == 22222 + and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway" + ) # When a user supplies a ssh config host as argument to mycli, # and used command line arguments, use the command line # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test", - "--ssh-user", "arg_user", - "--ssh-host", "arg_host", - "--ssh-port", "3", - "--ssh-key-filename", "/path/to/key" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "arg_user" and \ - MockMyCli.connect_args["ssh_host"] == "arg_host" and \ - MockMyCli.connect_args["ssh_port"] == 3 and \ - MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" - + result = runner.invoke( + mycli.main.cli, + args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", + "arg_user", + "--ssh-host", + "arg_host", + "--ssh-port", + "3", + "--ssh-key-filename", + "/path/to/key", + ], + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "arg_user" + and MockMyCli.connect_args["ssh_host"] == "arg_host" + and MockMyCli.connect_args["ssh_port"] == 3 + and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + ) + # delete=False means we should try to clean up try: if os.path.exists(ssh_config.name): @@ -542,9 +533,7 @@ def test_init_command_arg(executor): init_command = "set sql_select_limit=1000" sql = 'show variables like "sql_select_limit";' runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ["--init-command", init_command], input=sql - ) + result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql) expected = "sql_select_limit\t1000\n" assert result.exit_code == 0 @@ -553,18 +542,13 @@ def test_init_command_arg(executor): @dbtest def test_init_command_multiple_arg(executor): - init_command = 'set sql_select_limit=2000; set max_join_size=20000' - sql = ( - 'show variables like "sql_select_limit";\n' - 'show variables like "max_join_size"' - ) + init_command = "set sql_select_limit=2000; set max_join_size=20000" + sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"' runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ['--init-command', init_command], input=sql - ) + result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql) - expected_sql_select_limit = 'sql_select_limit\t2000\n' - expected_max_join_size = 'max_join_size\t20000\n' + expected_sql_select_limit = "sql_select_limit\t2000\n" + expected_max_join_size = "max_join_size\t20000\n" assert result.exit_code == 0 assert expected_sql_select_limit in result.output diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py index 0bc3bf8..31ac165 100644 --- a/test/test_naive_completion.py +++ b/test/test_naive_completion.py @@ -6,56 +6,48 @@ from prompt_toolkit.document import Document @pytest.fixture def completer(): import mycli.sqlcompleter as sqlcompleter + return sqlcompleter.SQLCompleter(smart_completion=False) @pytest.fixture def complete_event(): from unittest.mock import Mock + return Mock() def test_empty_string_completion(completer, complete_event): - text = '' + text = "" position = 0 - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list(map(Completion, completer.all_completions)) def test_select_keyword_completion(completer, complete_event): - text = 'SEL' - position = len('SEL') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([Completion(text='SELECT', start_position=-3)]) + text = "SEL" + position = len("SEL") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list([Completion(text="SELECT", start_position=-3)]) def test_function_name_completion(completer, complete_event): - text = 'SELECT MA' - position = len('SELECT MA') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "SELECT MA" + position = len("SELECT MA") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert sorted(x.text for x in result) == ["MASTER", "MAX"] def test_column_name_completion(completer, complete_event): - text = 'SELECT FROM users' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "SELECT FROM users" + position = len("SELECT ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list(map(Completion, completer.all_completions)) def test_special_name_completion(completer, complete_event): - text = '\\' - position = len('\\') - result = set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "\\" + position = len("\\") + result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) # Special commands will NOT be suggested during naive completion mode. assert result == set() diff --git a/test/test_parseutils.py b/test/test_parseutils.py index 920a08d..0925299 100644 --- a/test/test_parseutils.py +++ b/test/test_parseutils.py @@ -1,67 +1,72 @@ import pytest from mycli.packages.parseutils import ( - extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, - is_dropping_database) + extract_tables, + query_starts_with, + queries_start_with, + is_destructive, + query_has_where_clause, + is_dropping_database, +) def test_empty_string(): - tables = extract_tables('') + tables = extract_tables("") assert tables == [] def test_simple_select_single_table(): - tables = extract_tables('select * from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select * from abc") + assert tables == [(None, "abc", None)] def test_simple_select_single_table_schema_qualified(): - tables = extract_tables('select * from abc.def') - assert tables == [('abc', 'def', None)] + tables = extract_tables("select * from abc.def") + assert tables == [("abc", "def", None)] def test_simple_select_multiple_tables(): - tables = extract_tables('select * from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select * from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_simple_select_multiple_tables_schema_qualified(): - tables = extract_tables('select * from abc.def, ghi.jkl') - assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + tables = extract_tables("select * from abc.def, ghi.jkl") + assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)] def test_simple_select_with_cols_single_table(): - tables = extract_tables('select a,b from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select a,b from abc") + assert tables == [(None, "abc", None)] def test_simple_select_with_cols_single_table_schema_qualified(): - tables = extract_tables('select a,b from abc.def') - assert tables == [('abc', 'def', None)] + tables = extract_tables("select a,b from abc.def") + assert tables == [("abc", "def", None)] def test_simple_select_with_cols_multiple_tables(): - tables = extract_tables('select a,b from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select a,b from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_simple_select_with_cols_multiple_tables_with_schema(): - tables = extract_tables('select a,b from abc.def, def.ghi') - assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] + tables = extract_tables("select a,b from abc.def, def.ghi") + assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)] def test_select_with_hanging_comma_single_table(): - tables = extract_tables('select a, from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select a, from abc") + assert tables == [(None, "abc", None)] def test_select_with_hanging_comma_multiple_tables(): - tables = extract_tables('select a, from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select a, from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_select_with_hanging_period_multiple_tables(): - tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') - assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")] def test_simple_insert_single_table(): @@ -69,97 +74,80 @@ def test_simple_insert_single_table(): # sqlparse mistakenly assigns an alias to the table # assert tables == [(None, 'abc', None)] - assert tables == [(None, 'abc', 'abc')] + assert tables == [(None, "abc", "abc")] @pytest.mark.xfail def test_simple_insert_single_table_schema_qualified(): tables = extract_tables('insert into abc.def (id, name) values (1, "def")') - assert tables == [('abc', 'def', None)] + assert tables == [("abc", "def", None)] def test_simple_update_table(): - tables = extract_tables('update abc set id = 1') - assert tables == [(None, 'abc', None)] + tables = extract_tables("update abc set id = 1") + assert tables == [(None, "abc", None)] def test_simple_update_table_with_schema(): - tables = extract_tables('update abc.def set id = 1') - assert tables == [('abc', 'def', None)] + tables = extract_tables("update abc.def set id = 1") + assert tables == [("abc", "def", None)] def test_join_table(): - tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') - assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num") + assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")] def test_join_table_schema_qualified(): - tables = extract_tables( - 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') - assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")] def test_join_as_table(): - tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') - assert tables == [(None, 'my_table', 'm')] + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == [(None, "my_table", "m")] def test_query_starts_with(): - query = 'USE test;' - assert query_starts_with(query, ('use', )) is True + query = "USE test;" + assert query_starts_with(query, ("use",)) is True - query = 'DROP DATABASE test;' - assert query_starts_with(query, ('use', )) is False + query = "DROP DATABASE test;" + assert query_starts_with(query, ("use",)) is False def test_query_starts_with_comment(): - query = '# comment\nUSE test;' - assert query_starts_with(query, ('use', )) is True + query = "# comment\nUSE test;" + assert query_starts_with(query, ("use",)) is True def test_queries_start_with(): - sql = ( - '# comment\n' - 'show databases;' - 'use foo;' - ) - assert queries_start_with(sql, ('show', 'select')) is True - assert queries_start_with(sql, ('use', 'drop')) is True - assert queries_start_with(sql, ('delete', 'update')) is False + sql = "# comment\n" "show databases;" "use foo;" + assert queries_start_with(sql, ("show", "select")) is True + assert queries_start_with(sql, ("use", "drop")) is True + assert queries_start_with(sql, ("delete", "update")) is False def test_is_destructive(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'drop database foo;' - ) + sql = "use test;\n" "show databases;\n" "drop database foo;" assert is_destructive(sql) is True def test_is_destructive_update_with_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1 WHERE id = 1;' - ) + sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;" assert is_destructive(sql) is False def test_is_destructive_update_without_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1;' - ) + sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;" assert is_destructive(sql) is True @pytest.mark.parametrize( - ('sql', 'has_where_clause'), + ("sql", "has_where_clause"), [ - ('update test set dummy = 1;', False), - ('update test set dummy = 1 where id = 1);', True), + ("update test set dummy = 1;", False), + ("update test set dummy = 1 where id = 1);", True), ], ) def test_query_has_where_clause(sql, has_where_clause): @@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause): @pytest.mark.parametrize( - ('sql', 'dbname', 'is_dropping'), + ("sql", "dbname", "is_dropping"), [ - ('select bar from foo', 'foo', False), - ('drop database "foo";', '`foo`', True), - ('drop schema foo', 'foo', True), - ('drop schema foo', 'bar', False), - ('drop database bar', 'foo', False), - ('drop database foo', None, False), - ('drop database foo; create database foo', 'foo', False), - ('drop database foo; create database bar', 'foo', True), - ('select bar from foo; drop database bazz', 'foo', False), - ('select bar from foo; drop database bazz', 'bazz', True), - ('-- dropping database \n ' - 'drop -- really dropping \n ' - 'schema abc -- now it is dropped', - 'abc', - True) - ] + ("select bar from foo", "foo", False), + ('drop database "foo";', "`foo`", True), + ("drop schema foo", "foo", True), + ("drop schema foo", "bar", False), + ("drop database bar", "foo", False), + ("drop database foo", None, False), + ("drop database foo; create database foo", "foo", False), + ("drop database foo; create database bar", "foo", True), + ("select bar from foo; drop database bazz", "foo", False), + ("select bar from foo; drop database bazz", "bazz", True), + ("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True), + ], ) def test_is_dropping_database(sql, dbname, is_dropping): assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py index 2373fac..625e022 100644 --- a/test/test_prompt_utils.py +++ b/test/test_prompt_utils.py @@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query def test_confirm_destructive_query_notty(): - stdin = click.get_text_stream('stdin') + stdin = click.get_text_stream("stdin") assert stdin.isatty() is False - sql = 'drop database foo;' + sql = "drop database foo;" assert confirm_destructive_query(sql) is None diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 30b15ac..8ad40a4 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -43,49 +43,35 @@ def complete_event(): def test_special_name_completion(completer, complete_event): text = "\\d" position = len("\\d") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert result == [Completion(text="\\dt", start_position=-2)] def test_empty_string_completion(completer, complete_event): text = "" position = 0 - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) - assert ( - list(map(Completion, completer.keywords + completer.special_commands)) == result - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert list(map(Completion, completer.keywords + completer.special_commands)) == result def test_select_keyword_completion(completer, complete_event): text = "SEL" position = len("SEL") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list([Completion(text="SELECT", start_position=-3)]) def test_select_star(completer, complete_event): text = "SELECT * " position = len(text) - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list(map(Completion, completer.keywords)) def test_table_completion(completer, complete_event): text = "SELECT * FROM " position = len(text) - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="users", start_position=0), @@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event): def test_function_name_completion(completer, complete_event): text = "SELECT MA" position = len("SELECT MA") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="MAX", start_position=-2), @@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event): """ text = "SELECT from users" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event): """ text = "SELECT MAX( from users" position = len("SELECT MAX(") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="*", start_position=0), @@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event): """ text = "SELECT users. from users" position = len("SELECT users.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event): """ text = "SELECT u. from users u" position = len("SELECT u.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event): """ text = "SELECT id, from users u" position = len("SELECT id, ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event): """ text = "SELECT u.id, u. from users u" position = len("SELECT u.id, u.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event): """ text = "SELECT users.id, users. from users u" position = len("SELECT users.id, users.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event): def test_suggested_aliases_after_on(completer, complete_event): text = "SELECT u.name, o.id FROM users u JOIN orders o ON " position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="u", start_position=0), @@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event): def test_suggested_aliases_after_on_right_side(completer, complete_event): text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = " position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="u", start_position=0), @@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event): def test_suggested_tables_after_on(completer, complete_event): text = "SELECT users.name, orders.id FROM users JOIN orders ON " position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event): def test_suggested_tables_after_on_right_side(completer, complete_event): text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " - position = len( - "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " - ) - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event): def test_table_names_after_from(completer, complete_event): text = "SELECT * FROM " position = len("SELECT * FROM ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event): def test_auto_escaped_col_names(completer, complete_event): text = "SELECT from `select`" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [ Completion(text="*", start_position=0), Completion(text="id", start_position=0), Completion(text="`insert`", start_position=0), Completion(text="`ABC`", start_position=0), - ] + list(map(Completion, completer.functions)) + [ - Completion(text="select", start_position=0) - ] + list(map(Completion, completer.keywords)) + ] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list( + map(Completion, completer.keywords) + ) def test_un_escaped_table_names(completer, complete_event): text = "SELECT from réveillé" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -464,10 +392,6 @@ def dummy_list_path(dir_name): ) def test_file_name_completion(completer, complete_event, text, expected): position = len(text) - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) expected = list((Completion(txt, pos) for txt, pos in expected)) assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py index d0ca45f..bea5620 100644 --- a/test/test_special_iocommands.py +++ b/test/test_special_iocommands.py @@ -17,11 +17,11 @@ def test_set_get_pager(): assert mycli.packages.special.is_pager_enabled() mycli.packages.special.set_pager_enabled(False) assert not mycli.packages.special.is_pager_enabled() - mycli.packages.special.set_pager('less') - assert os.environ['PAGER'] == "less" + mycli.packages.special.set_pager("less") + assert os.environ["PAGER"] == "less" mycli.packages.special.set_pager(False) - assert os.environ['PAGER'] == "less" - del os.environ['PAGER'] + assert os.environ["PAGER"] == "less" + del os.environ["PAGER"] mycli.packages.special.set_pager(False) mycli.packages.special.disable_pager() assert not mycli.packages.special.is_pager_enabled() @@ -42,45 +42,44 @@ def test_set_get_expanded_output(): def test_editor_command(): - assert mycli.packages.special.editor_command(r'hello\e') - assert mycli.packages.special.editor_command(r'\ehello') - assert not mycli.packages.special.editor_command(r'hello') + assert mycli.packages.special.editor_command(r"hello\e") + assert mycli.packages.special.editor_command(r"\ehello") + assert not mycli.packages.special.editor_command(r"hello") - assert mycli.packages.special.get_filename(r'\e filename') == "filename" + assert mycli.packages.special.get_filename(r"\e filename") == "filename" - os.environ['EDITOR'] = 'true' - os.environ['VISUAL'] = 'true' + os.environ["EDITOR"] = "true" + os.environ["VISUAL"] = "true" # Set the editor to Notepad on Windows - if os.name != 'nt': - mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" + if os.name != "nt": + mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1" else: - pytest.skip('Skipping on Windows platform.') - + pytest.skip("Skipping on Windows platform.") def test_tee_command(): - mycli.packages.special.write_tee(u"hello world") # write without file set + mycli.packages.special.write_tee("hello world") # write without file set # keep Windows from locking the file with delete=False with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"tee " + f.name) - mycli.packages.special.write_tee(u"hello world") - if os.name=='nt': + mycli.packages.special.execute(None, "tee " + f.name) + mycli.packages.special.write_tee("hello world") + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"tee -o " + f.name) - mycli.packages.special.write_tee(u"hello world") + mycli.packages.special.execute(None, "tee -o " + f.name) + mycli.packages.special.write_tee("hello world") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"notee") - mycli.packages.special.write_tee(u"hello world") + mycli.packages.special.execute(None, "notee") + mycli.packages.special.write_tee("hello world") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" @@ -92,52 +91,49 @@ def test_tee_command(): os.remove(f.name) except Exception as e: print(f"An error occurred while attempting to delete the file: {e}") - def test_tee_command_error(): with pytest.raises(TypeError): - mycli.packages.special.execute(None, 'tee') + mycli.packages.special.execute(None, "tee") with pytest.raises(OSError): with tempfile.NamedTemporaryFile() as f: os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) - mycli.packages.special.execute(None, 'tee {}'.format(f.name)) + mycli.packages.special.execute(None, "tee {}".format(f.name)) @dbtest - @pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") def test_favorite_query(): with db_connection().cursor() as cur: - query = u'select "✔"' - mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) - assert next(mycli.packages.special.execute( - cur, u'\\f check'))[0] == "> " + query + query = 'select "✔"' + mycli.packages.special.execute(cur, "\\fs check {0}".format(query)) + assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query def test_once_command(): with pytest.raises(TypeError): - mycli.packages.special.execute(None, u"\\once") + mycli.packages.special.execute(None, "\\once") with pytest.raises(OSError): - mycli.packages.special.execute(None, u"\\once /proc/access-denied") + mycli.packages.special.execute(None, "\\once /proc/access-denied") - mycli.packages.special.write_once(u"hello world") # write without file set + mycli.packages.special.write_once("hello world") # write without file set # keep Windows from locking the file with delete=False with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"\\once " + f.name) - mycli.packages.special.write_once(u"hello world") - if os.name=='nt': + mycli.packages.special.execute(None, "\\once " + f.name) + mycli.packages.special.write_once("hello world") + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"\\once -o " + f.name) - mycli.packages.special.write_once(u"hello world line 1") - mycli.packages.special.write_once(u"hello world line 2") + mycli.packages.special.execute(None, "\\once -o " + f.name) + mycli.packages.special.write_once("hello world line 1") + mycli.packages.special.write_once("hello world line 2") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world line 1\r\nhello world line 2\r\n" else: assert f.read() == b"hello world line 1\nhello world line 2\n" @@ -151,52 +147,47 @@ def test_once_command(): def test_pipe_once_command(): with pytest.raises(IOError): - mycli.packages.special.execute(None, u"\\pipe_once") + mycli.packages.special.execute(None, "\\pipe_once") with pytest.raises(OSError): - mycli.packages.special.execute( - None, u"\\pipe_once /proc/access-denied") + mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied") - if os.name == 'nt': - mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') - mycli.packages.special.write_once(u"hello world") + if os.name == "nt": + mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') + mycli.packages.special.write_once("hello world") mycli.packages.special.unset_pipe_once_if_written() else: - mycli.packages.special.execute(None, u"\\pipe_once wc") - mycli.packages.special.write_once(u"hello world") - mycli.packages.special.unset_pipe_once_if_written() - # how to assert on wc output? + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, "\\pipe_once tee " + f.name) + mycli.packages.special.write_pipe_once("hello world") + mycli.packages.special.unset_pipe_once_if_written() + f.seek(0) + assert f.read() == b"hello world\n" def test_parseargfile(): """Test that parseargfile expands the user directory.""" - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'a'} - - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '~\\filename') + expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"} + + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename") else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '~/filename') - - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'w'} - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~\\filename') + assert expected == mycli.packages.special.iocommands.parseargfile("~/filename") + + expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"} + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename") else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~/filename') + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename") def test_parseargfile_no_file(): """Test that parseargfile raises a TypeError if there is no filename.""" with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('') + mycli.packages.special.iocommands.parseargfile("") with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('-o ') + mycli.packages.special.iocommands.parseargfile("-o ") @dbtest @@ -205,11 +196,9 @@ def test_watch_query_iteration(): the desired query and returns the given results.""" expected_value = "1" query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) + expected_title = "> {0!s}".format(query) with db_connection().cursor() as cur: - result = next(mycli.packages.special.iocommands.watch_query( - arg=query, cur=cur - )) + result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur)) assert result[0] == expected_title assert result[2][0] == expected_value @@ -230,14 +219,12 @@ def test_watch_query_full(): wait_interval = 1 expected_value = "1" query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) + expected_title = "> {0!s}".format(query) expected_results = 4 ctrl_c_process = send_ctrl_c(wait_interval) with db_connection().cursor() as cur: results = list( - result for result in mycli.packages.special.iocommands.watch_query( - arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur - ) + result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur) ) ctrl_c_process.join(1) assert len(results) == expected_results @@ -247,14 +234,12 @@ def test_watch_query_full(): @dbtest -@patch('click.clear') +@patch("click.clear") def test_watch_query_clear(clear_mock): """Test that the screen is cleared with the -c flag of `watch` command before execute the query.""" with db_connection().cursor() as cur: - watch_gen = mycli.packages.special.iocommands.watch_query( - arg='0.1 -c select 1;', cur=cur - ) + watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur) assert not clear_mock.called next(watch_gen) assert clear_mock.called @@ -271,19 +256,20 @@ def test_watch_query_bad_arguments(): watch_query = mycli.packages.special.iocommands.watch_query with db_connection().cursor() as cur: with pytest.raises(ProgrammingError): - next(watch_query('a select 1;', cur=cur)) + next(watch_query("a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('-a select 1;', cur=cur)) + next(watch_query("-a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('1 -a select 1;', cur=cur)) + next(watch_query("1 -a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('-c -a select 1;', cur=cur)) + next(watch_query("-c -a select 1;", cur=cur)) @dbtest -@patch('click.clear') +@patch("click.clear") def test_watch_query_interval_clear(clear_mock): """Test `watch` command with interval and clear flag.""" + def test_asserts(gen): clear_mock.reset_mock() start = time() @@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock): seconds = 1.0 watch_query = mycli.packages.special.iocommands.watch_query with db_connection().cursor() as cur: - test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), - cur=cur)) - test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), - cur=cur)) + test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur)) + test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur)) def test_split_sql_by_delimiter(): - for delimiter_str in (';', '$', '😀'): + for delimiter_str in (";", "$", "😀"): mycli.packages.special.set_delimiter(delimiter_str) sql_input = "select 1{} select \ufffc2".format(delimiter_str) - queries = ( - "select 1", - "select \ufffc2" - ) - for query, parsed_query in zip( - queries, mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) + queries = ("select 1", "select \ufffc2") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)): + assert query == parsed_query def test_switch_delimiter_within_query(): - mycli.packages.special.set_delimiter(';') + mycli.packages.special.set_delimiter(";") sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" - queries = ( - "select 1", - "delimiter $$ select 2 $$ select 3 $$", - "select 2", - "select 3" - ) - for query, parsed_query in zip( - queries, - mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) + queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)): + assert query == parsed_query def test_set_delimiter(): - - for delim in ('foo', 'bar'): + for delim in ("foo", "bar"): mycli.packages.special.set_delimiter(delim) assert mycli.packages.special.get_current_delimiter() == delim def teardown_function(): - mycli.packages.special.set_delimiter(';') + mycli.packages.special.set_delimiter(";") diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index ca186bc..17e082b 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies from .utils import run, dbtest, set_expanded_output, is_expanded_output -def assert_result_equal(result, title=None, rows=None, headers=None, - status=None, auto_status=True, assert_contains=False): +def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False): """Assert that an sqlexecute.run() result matches the expected values.""" if status is None and auto_status and rows: - status = '{} row{} in set'.format( - len(rows), 's' if len(rows) > 1 else '') - fields = {'title': title, 'rows': rows, 'headers': headers, - 'status': status} + status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "") + fields = {"title": title, "rows": rows, "headers": headers, "status": status} if assert_contains: # Do a loose match on the results using the *in* operator. @@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None, @dbtest def test_conn(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - results = run(executor, '''select * from test''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + results = run(executor, """select * from test""") - assert_result_equal(results, headers=['a'], rows=[('abc',)]) + assert_result_equal(results, headers=["a"], rows=[("abc",)]) @dbtest def test_bools(executor): - run(executor, '''create table test(a boolean)''') - run(executor, '''insert into test values(True)''') - results = run(executor, '''select * from test''') + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, """select * from test""") - assert_result_equal(results, headers=['a'], rows=[(1,)]) + assert_result_equal(results, headers=["a"], rows=[(1,)]) @dbtest def test_binary(executor): - run(executor, '''create table bt(geom linestring NOT NULL)''') - run(executor, "INSERT INTO bt VALUES " - "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") - results = run(executor, '''select * from bt''') - - geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' - b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' - b'\xac\xdeC@') + run(executor, """create table bt(geom linestring NOT NULL)""") + run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, """select * from bt""") + + geom = ( + b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n" + b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9" + b"\xac\xdeC@" + ) - assert_result_equal(results, headers=['geom'], rows=[(geom,)]) + assert_result_equal(results, headers=["geom"], rows=[(geom,)]) @dbtest @@ -63,49 +61,48 @@ def test_table_and_columns_query(executor): run(executor, "create table a(x text, y text)") run(executor, "create table b(z text)") - assert set(executor.tables()) == set([('a',), ('b',)]) - assert set(executor.table_columns()) == set( - [('a', 'x'), ('a', 'y'), ('b', 'z')]) + assert set(executor.tables()) == set([("a",), ("b",)]) + assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")]) @dbtest def test_database_list(executor): databases = executor.databases() - assert 'mycli_test_db' in databases + assert "mycli_test_db" in databases @dbtest def test_invalid_syntax(executor): with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, 'invalid syntax!') - assert 'You have an error in your SQL syntax;' in str(excinfo.value) + run(executor, "invalid syntax!") + assert "You have an error in your SQL syntax;" in str(excinfo.value) @dbtest def test_invalid_column_name(executor): with pytest.raises(pymysql.err.OperationalError) as excinfo: - run(executor, 'select invalid command') + run(executor, "select invalid command") assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) @dbtest def test_unicode_support_in_output(executor): run(executor, "create table unicodechars(t text)") - run(executor, u"insert into unicodechars (t) values ('é')") + run(executor, "insert into unicodechars (t) values ('é')") # See issue #24, this raises an exception without proper handling - results = run(executor, u"select * from unicodechars") - assert_result_equal(results, headers=['t'], rows=[(u'é',)]) + results = run(executor, "select * from unicodechars") + assert_result_equal(results, headers=["t"], rows=[("é",)]) @dbtest def test_multiple_queries_same_line(executor): results = run(executor, "select 'foo'; select 'bar'") - expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], - 'status': '1 row in set'}, - {'title': None, 'headers': ['bar'], 'rows': [('bar',)], - 'status': '1 row in set'}] + expected = [ + {"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"}, + {"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"}, + ] assert expected == results @@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor): def test_multiple_queries_same_line_syntaxerror(executor): with pytest.raises(pymysql.ProgrammingError) as excinfo: run(executor, "select 'foo'; invalid syntax") - assert 'You have an error in your SQL syntax;' in str(excinfo.value) + assert "You have an error in your SQL syntax;" in str(excinfo.value) @dbtest @@ -125,15 +122,13 @@ def test_favorite_query(executor): run(executor, "insert into test values('def')") results = run(executor, "\\fs test-a select * from test where a like 'a%'") - assert_result_equal(results, status='Saved.') + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-a") - assert_result_equal(results, - title="> select * from test where a like 'a%'", - headers=['a'], rows=[('abc',)], auto_status=False) + assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False) results = run(executor, "\\fd test-a") - assert_result_equal(results, status='test-a: Deleted') + assert_result_equal(results, status="test-a: Deleted") @dbtest @@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor): run(executor, "insert into test values('abc')") run(executor, "insert into test values('def')") - results = run(executor, - "\\fs test-ad select * from test where a like 'a%'; " - "select * from test where a like 'd%'") - assert_result_equal(results, status='Saved.') + results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'") + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-ad") - expected = [{'title': "> select * from test where a like 'a%'", - 'headers': ['a'], 'rows': [('abc',)], 'status': None}, - {'title': "> select * from test where a like 'd%'", - 'headers': ['a'], 'rows': [('def',)], 'status': None}] + expected = [ + {"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None}, + {"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None}, + ] assert expected == results results = run(executor, "\\fd test-ad") - assert_result_equal(results, status='test-ad: Deleted') + assert_result_equal(results, status="test-ad: Deleted") @dbtest @pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") def test_favorite_query_expanded_output(executor): set_expanded_output(False) - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") results = run(executor, "\\fs test-ae select * from test") - assert_result_equal(results, status='Saved.') + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-ae \\G") assert is_expanded_output() is True - assert_result_equal(results, title='> select * from test', - headers=['a'], rows=[('abc',)], auto_status=False) + assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False) set_expanded_output(False) results = run(executor, "\\fd test-ae") - assert_result_equal(results, status='test-ae: Deleted') + assert_result_equal(results, status="test-ae: Deleted") @dbtest def test_special_command(executor): - results = run(executor, '\\?') - assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), - headers='Command', assert_contains=True, - auto_status=False) + results = run(executor, "\\?") + assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False) @dbtest def test_cd_command_without_a_folder_name(executor): - results = run(executor, 'system cd') - assert_result_equal(results, status='No folder name was provided.') + results = run(executor, "system cd") + assert_result_equal(results, status="No folder name was provided.") @dbtest def test_system_command_not_found(executor): - results = run(executor, 'system xyz') - if os.name=='nt': - assert_result_equal(results, status='OSError: The system cannot find the file specified', - assert_contains=True) + results = run(executor, "system xyz") + if os.name == "nt": + assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True) else: - assert_result_equal(results, status='OSError: No such file or directory', - assert_contains=True) + assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True) @dbtest def test_system_command_output(executor): eol = os.linesep test_dir = os.path.abspath(os.path.dirname(__file__)) - test_file_path = os.path.join(test_dir, 'test.txt') - results = run(executor, 'system cat {0}'.format(test_file_path)) - assert_result_equal(results, status=f'mycli rocks!{eol}') + test_file_path = os.path.join(test_dir, "test.txt") + results = run(executor, "system cat {0}".format(test_file_path)) + assert_result_equal(results, status=f"mycli rocks!{eol}") @dbtest def test_cd_command_current_dir(executor): test_path = os.path.abspath(os.path.dirname(__file__)) - run(executor, 'system cd {0}'.format(test_path)) + run(executor, "system cd {0}".format(test_path)) assert os.getcwd() == test_path @dbtest def test_unicode_support(executor): - results = run(executor, u"SELECT '日本語' AS japanese;") - assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) + results = run(executor, "SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=["japanese"], rows=[("日本語",)]) @dbtest def test_timestamp_null(executor): - run(executor, '''create table ts_null(a timestamp null)''') - run(executor, '''insert into ts_null values(null)''') - results = run(executor, '''select * from ts_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) + run(executor, """create table ts_null(a timestamp null)""") + run(executor, """insert into ts_null values(null)""") + results = run(executor, """select * from ts_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_datetime_null(executor): - run(executor, '''create table dt_null(a datetime null)''') - run(executor, '''insert into dt_null values(null)''') - results = run(executor, '''select * from dt_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) + run(executor, """create table dt_null(a datetime null)""") + run(executor, """insert into dt_null values(null)""") + results = run(executor, """select * from dt_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_date_null(executor): - run(executor, '''create table date_null(a date null)''') - run(executor, '''insert into date_null values(null)''') - results = run(executor, '''select * from date_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) + run(executor, """create table date_null(a date null)""") + run(executor, """insert into date_null values(null)""") + results = run(executor, """select * from date_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_time_null(executor): - run(executor, '''create table time_null(a time null)''') - run(executor, '''insert into time_null values(null)''') - results = run(executor, '''select * from time_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) + run(executor, """create table time_null(a time null)""") + run(executor, """insert into time_null values(null)""") + results = run(executor, """select * from time_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_multiple_results(executor): - query = '''CREATE PROCEDURE dmtest() + query = """CREATE PROCEDURE dmtest() BEGIN SELECT 1; SELECT 2; - END''' + END""" executor.conn.cursor().execute(query) - results = run(executor, 'call dmtest;') + results = run(executor, "call dmtest;") expected = [ - {'title': None, 'rows': [(1,)], 'headers': ['1'], - 'status': '1 row in set'}, - {'title': None, 'rows': [(2,)], 'headers': ['2'], - 'status': '1 row in set'} + {"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"}, + {"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"}, ] assert results == expected @pytest.mark.parametrize( - 'version_string, species, parsed_version_string, version', + "version_string, species, parsed_version_string, version", ( - ('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100), - ('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200), - ('5.7.32-35', 'Percona', '5.7.32', 50732), - ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), - ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), - ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), - ('unexpected version string', None, '', 0), - ('', None, '', 0), - (None, None, '', 0), - ) + ("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100), + ("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200), + ("5.7.32-35", "Percona", "5.7.32", 50732), + ("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732), + ("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016), + ("5.1.5a-alpha", "MySQL", "5.1.5", 50105), + ("unexpected version string", None, "", 0), + ("", None, "", 0), + (None, None, "", 0), + ), ) def test_version_parsing(version_string, species, parsed_version_string, version): server_info = ServerInfo.from_version_string(version_string) diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py index bdc1dbf..45e97af 100644 --- a/test/test_tabular_output.py +++ b/test/test_tabular_output.py @@ -2,8 +2,6 @@ from textwrap import dedent -from mycli.packages.tabular_output import sql_format -from cli_helpers.tabular_output import TabularOutputFormatter from .utils import USER, PASSWORD, HOST, PORT, dbtest @@ -23,20 +21,17 @@ def mycli(): @dbtest def test_sql_output(mycli): """Test the sql output adapter.""" - headers = ['letters', 'number', 'optional', 'float', 'binary'] + headers = ["letters", "number", "optional", "float", "binary"] class FakeCursor(object): def __init__(self): - self.data = [ - ('abc', 1, None, 10.0, b'\xAA'), - ('d', 456, '1', 0.5, b'\xAA\xBB') - ] + self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")] self.description = [ (None, FIELD_TYPE.VARCHAR), (None, FIELD_TYPE.LONG), (None, FIELD_TYPE.LONG), (None, FIELD_TYPE.FLOAT), - (None, FIELD_TYPE.BLOB) + (None, FIELD_TYPE.BLOB), ] def __iter__(self): @@ -52,12 +47,11 @@ def test_sql_output(mycli): return self.description # Test sql-update output format - assert list(mycli.change_table_format("sql-update")) == \ - [(None, None, None, 'Changed table format to sql-update')] + assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) actual = "\n".join(output) - assert actual == dedent('''\ + assert actual == dedent("""\ UPDATE `DUAL` SET `number` = 1 , `optional` = NULL @@ -69,13 +63,12 @@ def test_sql_output(mycli): , `optional` = '1' , `float` = 0.5e0 , `binary` = X'aabb' - WHERE `letters` = 'd';''') + WHERE `letters` = 'd';""") # Test sql-update-2 output format - assert list(mycli.change_table_format("sql-update-2")) == \ - [(None, None, None, 'Changed table format to sql-update-2')] + assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ UPDATE `DUAL` SET `optional` = NULL , `float` = 10.0e0 @@ -85,34 +78,31 @@ def test_sql_output(mycli): `optional` = '1' , `float` = 0.5e0 , `binary` = X'aabb' - WHERE `letters` = 'd' AND `number` = 456;''') + WHERE `letters` = 'd' AND `number` = 456;""") # Test sql-insert output format (without table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") # Test sql-insert output format (with table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "SELECT * FROM `table`" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") # Test sql-insert output format (with database + table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "SELECT * FROM `database`.`table`" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") diff --git a/test/utils.py b/test/utils.py index ab12248..383f502 100644 --- a/test/utils.py +++ b/test/utils.py @@ -9,20 +9,18 @@ import pytest from mycli.main import special -PASSWORD = os.getenv('PYTEST_PASSWORD') -USER = os.getenv('PYTEST_USER', 'root') -HOST = os.getenv('PYTEST_HOST', 'localhost') -PORT = int(os.getenv('PYTEST_PORT', 3306)) -CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') -SSH_USER = os.getenv('PYTEST_SSH_USER', None) -SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) -SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) +PASSWORD = os.getenv("PYTEST_PASSWORD") +USER = os.getenv("PYTEST_USER", "root") +HOST = os.getenv("PYTEST_HOST", "localhost") +PORT = int(os.getenv("PYTEST_PORT", 3306)) +CHARSET = os.getenv("PYTEST_CHARSET", "utf8") +SSH_USER = os.getenv("PYTEST_SSH_USER", None) +SSH_HOST = os.getenv("PYTEST_SSH_HOST", None) +SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22) def db_connection(dbname=None): - conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, - password=PASSWORD, charset=CHARSET, - local_infile=False) + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False) conn.autocommit = True return conn @@ -30,20 +28,18 @@ def db_connection(dbname=None): try: db_connection() CAN_CONNECT_TO_DB = True -except: +except Exception: CAN_CONNECT_TO_DB = False -dbtest = pytest.mark.skipif( - not CAN_CONNECT_TO_DB, - reason="Need a mysql instance at localhost accessible by user 'root'") +dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'") def create_db(dbname): with db_connection().cursor() as cur: try: - cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''') - cur.execute('''CREATE DATABASE mycli_test_db''') - except: + cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""") + cur.execute("""CREATE DATABASE mycli_test_db""") + except Exception: pass @@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True): for title, rows, headers, status in executor.run(sql): rows = list(rows) if (rows_as_list and rows) else rows - result.append({'title': title, 'rows': rows, 'headers': headers, - 'status': status}) + result.append({"title": title, "rows": rows, "headers": headers, "status": status}) return result @@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds): Returns the `multiprocessing.Process` created. """ - ctrl_c_process = multiprocessing.Process( - target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) - ) + ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)) ctrl_c_process.start() return ctrl_c_process |