summaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/conftest.py25
-rw-r--r--test/features/db_utils.py35
-rw-r--r--test/features/environment.py133
-rw-r--r--test/features/fixture_utils.py5
-rw-r--r--test/features/steps/auto_vertical.py33
-rw-r--r--test/features/steps/basic_commands.py64
-rw-r--r--test/features/steps/connection.py46
-rw-r--r--test/features/steps/crud_database.py80
-rw-r--r--test/features/steps/crud_table.py70
-rw-r--r--test/features/steps/iocommands.py78
-rw-r--r--test/features/steps/named_queries.py51
-rw-r--r--test/features/steps/specials.py10
-rw-r--r--test/features/steps/utils.py4
-rw-r--r--test/features/steps/wrappers.py72
-rw-r--r--test/myclirc1
-rw-r--r--test/test_clistyle.py9
-rw-r--r--test/test_completion_engine.py780
-rw-r--r--test/test_completion_refresher.py16
-rw-r--r--test/test_config.py82
-rw-r--r--test/test_dbspecial.py29
-rw-r--r--test/test_main.py450
-rw-r--r--test/test_naive_completion.py42
-rw-r--r--test/test_parseutils.py158
-rw-r--r--test/test_prompt_utils.py4
-rw-r--r--test/test_smart_completion_public_schema_only.py128
-rw-r--r--test/test_special_iocommands.py198
-rw-r--r--test/test_sqlexecute.py212
-rw-r--r--test/test_tabular_output.py46
-rw-r--r--test/utils.py39
29 files changed, 1340 insertions, 1560 deletions
diff --git a/test/conftest.py b/test/conftest.py
index 1325596..5575b40 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,13 +1,12 @@
import pytest
-from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
- db_connection, SSH_USER, SSH_HOST, SSH_PORT)
+from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT
import mycli.sqlexecute
@pytest.fixture(scope="function")
def connection():
- create_db('mycli_test_db')
- connection = db_connection('mycli_test_db')
+ create_db("mycli_test_db")
+ connection = db_connection("mycli_test_db")
yield connection
connection.close()
@@ -22,8 +21,18 @@ def cursor(connection):
@pytest.fixture
def executor(connection):
return mycli.sqlexecute.SQLExecute(
- database='mycli_test_db', user=USER,
- host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
- local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
- ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
+ database="mycli_test_db",
+ user=USER,
+ host=HOST,
+ password=PASSWORD,
+ port=PORT,
+ socket=None,
+ charset=CHARSET,
+ local_infile=False,
+ ssl=None,
+ ssh_user=SSH_USER,
+ ssh_host=SSH_HOST,
+ ssh_port=SSH_PORT,
+ ssh_password=None,
+ ssh_key_filename=None,
)
diff --git a/test/features/db_utils.py b/test/features/db_utils.py
index be550e9..175cc1b 100644
--- a/test/features/db_utils.py
+++ b/test/features/db_utils.py
@@ -1,8 +1,7 @@
import pymysql
-def create_db(hostname='localhost', port=3306, username=None,
- password=None, dbname=None):
+def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Create test database.
:param hostname: string
@@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
- host=hostname,
- port=port,
- user=username,
- password=password,
- charset='utf8mb4',
- cursorclass=pymysql.cursors.DictCursor
+ host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
- cr.execute('drop database if exists ' + dbname)
- cr.execute('create database ' + dbname)
+ cr.execute("drop database if exists " + dbname)
+ cr.execute("create database " + dbname)
cn.close()
@@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname):
"""
cn = pymysql.connect(
- host=hostname,
- port=port,
- user=username,
- password=password,
- db=dbname,
- charset='utf8mb4',
- cursorclass=pymysql.cursors.DictCursor
+ host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
return cn
-def drop_db(hostname='localhost', port=3306, username=None,
- password=None, dbname=None):
+def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None):
"""Drop database.
:param hostname: string
@@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None,
"""
cn = pymysql.connect(
- host=hostname,
- port=port,
- user=username,
- password=password,
- db=dbname,
- charset='utf8mb4',
- cursorclass=pymysql.cursors.DictCursor
+ host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor
)
with cn.cursor() as cr:
- cr.execute('drop database if exists ' + dbname)
+ cr.execute("drop database if exists " + dbname)
close_cn(cn)
diff --git a/test/features/environment.py b/test/features/environment.py
index 1ea0f08..a3d3764 100644
--- a/test/features/environment.py
+++ b/test/features/environment.py
@@ -9,96 +9,72 @@ import pexpect
from steps.wrappers import run_cli, wait_prompt
-test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
+test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log")
-SELF_CONNECTING_FEATURES = (
- 'test/features/connection.feature',
-)
+SELF_CONNECTING_FEATURES = ("test/features/connection.feature",)
-MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
-MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
-MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
-MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
+MY_CNF_PATH = os.path.expanduser("~/.my.cnf")
+MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup"
+MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf")
+MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup"
def get_db_name_from_context(context):
- return context.config.userdata.get(
- 'my_test_db', None
- ) or "mycli_behave_tests"
-
+ return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests"
def before_all(context):
"""Set env parameters."""
- os.environ['LINES'] = "100"
- os.environ['COLUMNS'] = "100"
- os.environ['EDITOR'] = 'ex'
- os.environ['LC_ALL'] = 'en_US.UTF-8'
- os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
- os.environ['MYCLI_HISTFILE'] = os.devnull
+ os.environ["LINES"] = "100"
+ os.environ["COLUMNS"] = "100"
+ os.environ["EDITOR"] = "ex"
+ os.environ["LC_ALL"] = "en_US.UTF-8"
+ os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
+ os.environ["MYCLI_HISTFILE"] = os.devnull
- test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
- login_path_file = os.path.join(test_dir, 'mylogin.cnf')
-# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+ # test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+ # login_path_file = os.path.join(test_dir, "mylogin.cnf")
+ # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
- context.package_root = os.path.abspath(
- os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
- os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
- '.coveragerc')
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc")
context.exit_sent = False
- vi = '_'.join([str(x) for x in sys.version_info[:3]])
+ vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = get_db_name_from_context(context)
- db_name_full = '{0}_{1}'.format(db_name, vi)
+ db_name_full = "{0}_{1}".format(db_name, vi)
# Store get params from config/environment variables
context.conf = {
- 'host': context.config.userdata.get(
- 'my_test_host',
- os.getenv('PYTEST_HOST', 'localhost')
- ),
- 'port': context.config.userdata.get(
- 'my_test_port',
- int(os.getenv('PYTEST_PORT', '3306'))
- ),
- 'user': context.config.userdata.get(
- 'my_test_user',
- os.getenv('PYTEST_USER', 'root')
- ),
- 'pass': context.config.userdata.get(
- 'my_test_pass',
- os.getenv('PYTEST_PASSWORD', None)
- ),
- 'cli_command': context.config.userdata.get(
- 'my_cli_command', None) or
- sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
- 'dbname': db_name,
- 'dbname_tmp': db_name_full + '_tmp',
- 'vi': vi,
- 'pager_boundary': '---boundary---',
+ "host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")),
+ "port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))),
+ "user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")),
+ "pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)),
+ "cli_command": context.config.userdata.get("my_cli_command", None)
+ or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
+ "dbname": db_name,
+ "dbname_tmp": db_name_full + "_tmp",
+ "vi": vi,
+ "pager_boundary": "---boundary---",
}
_, my_cnf = mkstemp()
- with open(my_cnf, 'w') as f:
+ with open(my_cnf, "w") as f:
f.write(
- '[client]\n'
- 'pager={0} {1} {2}\n'.format(
- sys.executable, os.path.join(context.package_root,
- 'test/features/wrappager.py'),
- context.conf['pager_boundary'])
+ "[client]\n" "pager={0} {1} {2}\n".format(
+ sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"]
+ )
)
- context.conf['defaults-file'] = my_cnf
- context.conf['myclirc'] = os.path.join(context.package_root, 'test',
- 'myclirc')
+ context.conf["defaults-file"] = my_cnf
+ context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc")
- context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
- context.conf['user'],
- context.conf['pass'],
- context.conf['dbname'])
+ context.cn = dbutils.create_db(
+ context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]
+ )
context.fixture_data = fixutils.read_fixture_files()
@@ -106,12 +82,10 @@ def before_all(context):
def after_all(context):
"""Unset env parameters."""
dbutils.close_cn(context.cn)
- dbutils.drop_db(context.conf['host'], context.conf['port'],
- context.conf['user'], context.conf['pass'],
- context.conf['dbname'])
+ dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"])
# Restore env vars.
- #for k, v in context.pgenv.items():
+ # for k, v in context.pgenv.items():
# if k in os.environ and v is None:
# del os.environ[k]
# elif v:
@@ -123,8 +97,8 @@ def before_step(context, _):
def before_scenario(context, arg):
- with open(test_log_file, 'w') as f:
- f.write('')
+ with open(test_log_file, "w") as f:
+ f.write("")
if arg.location.filename not in SELF_CONNECTING_FEATURES:
run_cli(context)
wait_prompt(context)
@@ -140,23 +114,18 @@ def after_scenario(context, _):
"""Cleans up after each test complete."""
with open(test_log_file) as f:
for line in f:
- if 'error' in line.lower():
- raise RuntimeError(f'Error in log file: {line}')
+ if "error" in line.lower():
+ raise RuntimeError(f"Error in log file: {line}")
- if hasattr(context, 'cli') and not context.exit_sent:
+ if hasattr(context, "cli") and not context.exit_sent:
# Quit nicely.
if not context.atprompt:
- user = context.conf['user']
- host = context.conf['host']
+ user = context.conf["user"]
+ host = context.conf["host"]
dbname = context.currentdb
- context.cli.expect_exact(
- '{0}@{1}:{2}>'.format(
- user, host, dbname
- ),
- timeout=5
- )
- context.cli.sendcontrol('c')
- context.cli.sendcontrol('d')
+ context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5)
+ context.cli.sendcontrol("c")
+ context.cli.sendcontrol("d")
context.cli.expect_exact(pexpect.EOF, timeout=5)
if os.path.exists(MY_CNF_BACKUP_PATH):
diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py
index f85e0f6..514e41f 100644
--- a/test/features/fixture_utils.py
+++ b/test/features/fixture_utils.py
@@ -1,5 +1,4 @@
import os
-import io
def read_fixture_lines(filename):
@@ -20,9 +19,9 @@ def read_fixture_files():
fixture_dict = {}
current_dir = os.path.dirname(__file__)
- fixture_dir = os.path.join(current_dir, 'fixture_data/')
+ fixture_dir = os.path.join(current_dir, "fixture_data/")
for filename in os.listdir(fixture_dir):
- if filename not in ['.', '..']:
+ if filename not in [".", ".."]:
fullname = os.path.join(fixture_dir, filename)
fixture_dict[filename] = read_fixture_lines(fullname)
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
index e1cb26f..ad20067 100644
--- a/test/features/steps/auto_vertical.py
+++ b/test/features/steps/auto_vertical.py
@@ -6,41 +6,42 @@ import wrappers
from utils import parse_cli_args_to_dict
-@when('we run dbcli with {arg}')
+@when("we run dbcli with {arg}")
def step_run_cli_with_arg(context, arg):
wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
-@when('we execute a small query')
+@when("we execute a small query")
def step_execute_small_query(context):
- context.cli.sendline('select 1')
+ context.cli.sendline("select 1")
-@when('we execute a large query')
+@when("we execute a large query")
def step_execute_large_query(context):
- context.cli.sendline(
- 'select {}'.format(','.join([str(n) for n in range(1, 50)])))
+ context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
-@then('we see small results in horizontal format')
+@then("we see small results in horizontal format")
def step_see_small_results(context):
- wrappers.expect_pager(context, dedent("""\
+ wrappers.expect_pager(
+ context,
+ dedent("""\
+---+\r
| 1 |\r
+---+\r
| 1 |\r
+---+\r
\r
- """), timeout=5)
- wrappers.expect_exact(context, '1 row in set', timeout=2)
+ """),
+ timeout=5,
+ )
+ wrappers.expect_exact(context, "1 row in set", timeout=2)
-@then('we see large results in vertical format')
+@then("we see large results in vertical format")
def step_see_large_results(context):
- rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
- expected = ('***************************[ 1. row ]'
- '***************************\r\n' +
- '{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
+ rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)]
+ expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n")
wrappers.expect_pager(context, expected, timeout=10)
- wrappers.expect_exact(context, '1 row in set', timeout=2)
+ wrappers.expect_exact(context, "1 row in set", timeout=2)
diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py
index 425ef67..ec1e47a 100644
--- a/test/features/steps/basic_commands.py
+++ b/test/features/steps/basic_commands.py
@@ -5,18 +5,18 @@ to call the step in "*.feature" file.
"""
-from behave import when
+from behave import when, then
from textwrap import dedent
import tempfile
import wrappers
-@when('we run dbcli')
+@when("we run dbcli")
def step_run_cli(context):
wrappers.run_cli(context)
-@when('we wait for prompt')
+@when("we wait for prompt")
def step_wait_prompt(context):
wrappers.wait_prompt(context)
@@ -24,77 +24,75 @@ def step_wait_prompt(context):
@when('we send "ctrl + d"')
def step_ctrl_d(context):
"""Send Ctrl + D to hopefully exit."""
- context.cli.sendcontrol('d')
+ context.cli.sendcontrol("d")
context.exit_sent = True
-@when('we send "\?" command')
+@when(r'we send "\?" command')
def step_send_help(context):
- """Send \?
+ r"""Send \?
to see help.
"""
- context.cli.sendline('\\?')
- wrappers.expect_exact(
- context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+ context.cli.sendline("\\?")
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
-@when(u'we send source command')
+@when("we send source command")
def step_send_source_command(context):
with tempfile.NamedTemporaryFile() as f:
- f.write(b'\?')
+ f.write(b"\\?")
f.flush()
- context.cli.sendline('\. {0}'.format(f.name))
- wrappers.expect_exact(
- context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+ context.cli.sendline("\\. {0}".format(f.name))
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
-@when(u'we run query to check application_name')
+@when("we run query to check application_name")
def step_check_application_name(context):
context.cli.sendline(
"SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
)
-@then(u'we see found')
+@then("we see found")
def step_see_found(context):
wrappers.expect_exact(
context,
- context.conf['pager_boundary'] + '\r' + dedent('''
+ context.conf["pager_boundary"]
+ + "\r"
+ + dedent("""
+-------+\r
| found |\r
+-------+\r
| found |\r
+-------+\r
\r
- ''') + context.conf['pager_boundary'],
- timeout=5
+ """)
+ + context.conf["pager_boundary"],
+ timeout=5,
)
-@then(u'we confirm the destructive warning')
-def step_confirm_destructive_command(context):
+@then("we confirm the destructive warning")
+def step_confirm_destructive_command(context): # noqa
"""Confirm destructive command."""
- wrappers.expect_exact(
- context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
- context.cli.sendline('y')
+ wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
+ context.cli.sendline("y")
-@when(u'we answer the destructive warning with "{confirmation}"')
-def step_confirm_destructive_command(context, confirmation):
+@when('we answer the destructive warning with "{confirmation}"')
+def step_confirm_destructive_command(context, confirmation): # noqa
"""Confirm destructive command."""
- wrappers.expect_exact(
- context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
-@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
-def step_confirm_destructive_command(context, confirmation, text):
+@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
+def step_confirm_destructive_command(context, confirmation, text): # noqa
"""Confirm destructive command."""
- wrappers.expect_exact(
- context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2)
context.cli.sendline(confirmation)
wrappers.expect_exact(context, text, timeout=2)
# we must exit the Click loop, or the feature will hang
- context.cli.sendline('n')
+ context.cli.sendline("n")
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
index e16dd86..80d0653 100644
--- a/test/features/steps/connection.py
+++ b/test/features/steps/connection.py
@@ -1,9 +1,7 @@
import io
import os
-import shlex
from behave import when, then
-import pexpect
import wrappers
from test.features.steps.utils import parse_cli_args_to_dict
@@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD
from mycli.config import encrypt_mylogin_cnf
-TEST_LOGIN_PATH = 'test_login_path'
+TEST_LOGIN_PATH = "test_login_path"
@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
@when('we run mycli without arguments "{excluded_args}"')
-def step_run_cli_without_args(context, excluded_args, exact_args=''):
- wrappers.run_cli(
- context,
- run_args=parse_cli_args_to_dict(exact_args),
- exclude_args=parse_cli_args_to_dict(excluded_args).keys()
- )
+def step_run_cli_without_args(context, excluded_args, exact_args=""):
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys())
@then('status contains "{expression}"')
def status_contains(context, expression):
- wrappers.expect_exact(context, f'{expression}', timeout=5)
+ wrappers.expect_exact(context, f"{expression}", timeout=5)
# Normally, the shutdown after scenario waits for the prompt.
# But we may have changed the prompt, depending on parameters,
# so let's wait for its last character
- context.cli.expect_exact('>')
+ context.cli.expect_exact(">")
context.atprompt = True
-@when('we create my.cnf file')
+@when("we create my.cnf file")
def step_create_my_cnf_file(context):
- my_cnf = (
- '[client]\n'
- f'host = {HOST}\n'
- f'port = {PORT}\n'
- f'user = {USER}\n'
- f'password = {PASSWORD}\n'
- )
- with open(MY_CNF_PATH, 'w') as f:
+ my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
+ with open(MY_CNF_PATH, "w") as f:
f.write(my_cnf)
-@when('we create mylogin.cnf file')
+@when("we create mylogin.cnf file")
def step_create_mylogin_cnf_file(context):
- os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
- mylogin_cnf = (
- f'[{TEST_LOGIN_PATH}]\n'
- f'host = {HOST}\n'
- f'port = {PORT}\n'
- f'user = {USER}\n'
- f'password = {PASSWORD}\n'
- )
- with open(MYLOGIN_CNF_PATH, 'wb') as f:
+ os.environ.pop("MYSQL_TEST_LOGIN_FILE", None)
+ mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n"
+ with open(MYLOGIN_CNF_PATH, "wb") as f:
input_file = io.StringIO(mylogin_cnf)
f.write(encrypt_mylogin_cnf(input_file).read())
-@then('we are logged in')
+@then("we are logged in")
def we_are_logged_in(context):
db_name = get_db_name_from_context(context)
- context.cli.expect_exact(f'{db_name}>', timeout=5)
+ context.cli.expect_exact(f"{db_name}>", timeout=5)
context.atprompt = True
diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py
index 841f37d..56ff114 100644
--- a/test/features/steps/crud_database.py
+++ b/test/features/steps/crud_database.py
@@ -11,105 +11,99 @@ import wrappers
from behave import when, then
-@when('we create database')
+@when("we create database")
def step_db_create(context):
"""Send create database."""
- context.cli.sendline('create database {0};'.format(
- context.conf['dbname_tmp']))
+ context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
- context.response = {
- 'database_name': context.conf['dbname_tmp']
- }
+ context.response = {"database_name": context.conf["dbname_tmp"]}
-@when('we drop database')
+@when("we drop database")
def step_db_drop(context):
"""Send drop database."""
- context.cli.sendline('drop database {0};'.format(
- context.conf['dbname_tmp']))
+ context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
-@when('we connect to test database')
+@when("we connect to test database")
def step_db_connect_test(context):
"""Send connect to database."""
- db_name = context.conf['dbname']
+ db_name = context.conf["dbname"]
context.currentdb = db_name
- context.cli.sendline('use {0};'.format(db_name))
+ context.cli.sendline("use {0};".format(db_name))
-@when('we connect to quoted test database')
+@when("we connect to quoted test database")
def step_db_connect_quoted_tmp(context):
"""Send connect to database."""
- db_name = context.conf['dbname']
+ db_name = context.conf["dbname"]
context.currentdb = db_name
- context.cli.sendline('use `{0}`;'.format(db_name))
+ context.cli.sendline("use `{0}`;".format(db_name))
-@when('we connect to tmp database')
+@when("we connect to tmp database")
def step_db_connect_tmp(context):
"""Send connect to database."""
- db_name = context.conf['dbname_tmp']
+ db_name = context.conf["dbname_tmp"]
context.currentdb = db_name
- context.cli.sendline('use {0}'.format(db_name))
+ context.cli.sendline("use {0}".format(db_name))
-@when('we connect to dbserver')
+@when("we connect to dbserver")
def step_db_connect_dbserver(context):
"""Send connect to database."""
- context.currentdb = 'mysql'
- context.cli.sendline('use mysql')
+ context.currentdb = "mysql"
+ context.cli.sendline("use mysql")
-@then('dbcli exits')
+@then("dbcli exits")
def step_wait_exit(context):
"""Make sure the cli exits."""
wrappers.expect_exact(context, pexpect.EOF, timeout=5)
-@then('we see dbcli prompt')
+@then("we see dbcli prompt")
def step_see_prompt(context):
"""Wait to see the prompt."""
- user = context.conf['user']
- host = context.conf['host']
+ user = context.conf["user"]
+ host = context.conf["host"]
dbname = context.currentdb
- wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
+ wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname))
-@then('we see help output')
+@then("we see help output")
def step_see_help(context):
- for expected_line in context.fixture_data['help_commands.txt']:
+ for expected_line in context.fixture_data["help_commands.txt"]:
wrappers.expect_exact(context, expected_line, timeout=1)
-@then('we see database created')
+@then("we see database created")
def step_see_db_created(context):
"""Wait to see create database output."""
- wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
-@then('we see database dropped')
+@then("we see database dropped")
def step_see_db_dropped(context):
"""Wait to see drop database output."""
- wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
-@then('we see database dropped and no default database')
+@then("we see database dropped and no default database")
def step_see_db_dropped_no_default(context):
"""Wait to see drop database output."""
- user = context.conf['user']
- host = context.conf['host']
- database = '(none)'
+ user = context.conf["user"]
+ host = context.conf["host"]
+ database = "(none)"
context.currentdb = None
- wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
- wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
+ wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
+ wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database))
-@then('we see database connected')
+@then("we see database connected")
def step_see_db_connected(context):
"""Wait to see drop database output."""
- wrappers.expect_exact(
- context, 'You are now connected to database "', timeout=2)
+ wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, '"', timeout=2)
- wrappers.expect_exact(context, ' as user "{0}"'.format(
- context.conf['user']), timeout=2)
+ wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2)
diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py
index f715f0c..48a6408 100644
--- a/test/features/steps/crud_table.py
+++ b/test/features/steps/crud_table.py
@@ -10,103 +10,109 @@ from behave import when, then
from textwrap import dedent
-@when('we create table')
+@when("we create table")
def step_create_table(context):
"""Send create table."""
- context.cli.sendline('create table a(x text);')
+ context.cli.sendline("create table a(x text);")
-@when('we insert into table')
+@when("we insert into table")
def step_insert_into_table(context):
"""Send insert into table."""
- context.cli.sendline('''insert into a(x) values('xxx');''')
+ context.cli.sendline("""insert into a(x) values('xxx');""")
-@when('we update table')
+@when("we update table")
def step_update_table(context):
"""Send insert into table."""
- context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
+ context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""")
-@when('we select from table')
+@when("we select from table")
def step_select_from_table(context):
"""Send select from table."""
- context.cli.sendline('select * from a;')
+ context.cli.sendline("select * from a;")
-@when('we delete from table')
+@when("we delete from table")
def step_delete_from_table(context):
"""Send deete from table."""
- context.cli.sendline('''delete from a where x = 'yyy';''')
+ context.cli.sendline("""delete from a where x = 'yyy';""")
-@when('we drop table')
+@when("we drop table")
def step_drop_table(context):
"""Send drop table."""
- context.cli.sendline('drop table a;')
+ context.cli.sendline("drop table a;")
-@then('we see table created')
+@then("we see table created")
def step_see_table_created(context):
"""Wait to see create table output."""
- wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
-@then('we see record inserted')
+@then("we see record inserted")
def step_see_record_inserted(context):
"""Wait to see insert output."""
- wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
-@then('we see record updated')
+@then("we see record updated")
def step_see_record_updated(context):
"""Wait to see update output."""
- wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
-@then('we see data selected')
+@then("we see data selected")
def step_see_data_selected(context):
"""Wait to see select output."""
wrappers.expect_pager(
- context, dedent("""\
+ context,
+ dedent("""\
+-----+\r
| x |\r
+-----+\r
| yyy |\r
+-----+\r
\r
- """), timeout=2)
- wrappers.expect_exact(context, '1 row in set', timeout=2)
+ """),
+ timeout=2,
+ )
+ wrappers.expect_exact(context, "1 row in set", timeout=2)
-@then('we see record deleted')
+@then("we see record deleted")
def step_see_data_deleted(context):
"""Wait to see delete output."""
- wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2)
-@then('we see table dropped')
+@then("we see table dropped")
def step_see_table_dropped(context):
"""Wait to see drop output."""
- wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+ wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2)
-@when('we select null')
+@when("we select null")
def step_select_null(context):
"""Send select null."""
- context.cli.sendline('select null;')
+ context.cli.sendline("select null;")
-@then('we see null selected')
+@then("we see null selected")
def step_see_null_selected(context):
"""Wait to see null output."""
wrappers.expect_pager(
- context, dedent("""\
+ context,
+ dedent("""\
+--------+\r
| NULL |\r
+--------+\r
| <null> |\r
+--------+\r
\r
- """), timeout=2)
- wrappers.expect_exact(context, '1 row in set', timeout=2)
+ """),
+ timeout=2,
+ )
+ wrappers.expect_exact(context, "1 row in set", timeout=2)
diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py
index bbabf43..07d5c77 100644
--- a/test/features/steps/iocommands.py
+++ b/test/features/steps/iocommands.py
@@ -5,101 +5,93 @@ from behave import when, then
from textwrap import dedent
-@when('we start external editor providing a file name')
+@when("we start external editor providing a file name")
def step_edit_file(context):
"""Edit file with external editor."""
- context.editor_file_name = os.path.join(
- context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
+ context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
- context.cli.sendline('\e {0}'.format(
- os.path.basename(context.editor_file_name)))
- wrappers.expect_exact(
- context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
- wrappers.expect_exact(context, '\r\n:', timeout=2)
+ context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name)))
+ wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
+ wrappers.expect_exact(context, "\r\n:", timeout=2)
@when('we type "{query}" in the editor')
def step_edit_type_sql(context, query):
- context.cli.sendline('i')
+ context.cli.sendline("i")
context.cli.sendline(query)
- context.cli.sendline('.')
- wrappers.expect_exact(context, '\r\n:', timeout=2)
+ context.cli.sendline(".")
+ wrappers.expect_exact(context, "\r\n:", timeout=2)
-@when('we exit the editor')
+@when("we exit the editor")
def step_edit_quit(context):
- context.cli.sendline('x')
+ context.cli.sendline("x")
wrappers.expect_exact(context, "written", timeout=2)
@then('we see "{query}" in prompt')
def step_edit_done_sql(context, query):
- for match in query.split(' '):
+ for match in query.split(" "):
wrappers.expect_exact(context, match, timeout=5)
# Cleanup the command line.
- context.cli.sendcontrol('c')
+ context.cli.sendcontrol("c")
# Cleanup the edited file.
if context.editor_file_name and os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
-@when(u'we tee output')
+@when("we tee output")
def step_tee_ouptut(context):
- context.tee_file_name = os.path.join(
- context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
+ context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]))
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
- context.cli.sendline('tee {0}'.format(
- os.path.basename(context.tee_file_name)))
+ context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name)))
-@when(u'we select "select {param}"')
+@when('we select "select {param}"')
def step_query_select_number(context, param):
- context.cli.sendline(u'select {}'.format(param))
- wrappers.expect_pager(context, dedent(u"""\
+ context.cli.sendline("select {}".format(param))
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+{dashes}+\r
| {param} |\r
+{dashes}+\r
| {param} |\r
+{dashes}+\r
\r
- """.format(param=param, dashes='-' * (len(param) + 2))
- ), timeout=5)
- wrappers.expect_exact(context, '1 row in set', timeout=2)
+ """.format(param=param, dashes="-" * (len(param) + 2))
+ ),
+ timeout=5,
+ )
+ wrappers.expect_exact(context, "1 row in set", timeout=2)
-@then(u'we see result "{result}"')
+@then('we see result "{result}"')
def step_see_result(context, result):
- wrappers.expect_exact(
- context,
- u"| {} |".format(result),
- timeout=2
- )
+ wrappers.expect_exact(context, "| {} |".format(result), timeout=2)
-@when(u'we query "{query}"')
+@when('we query "{query}"')
def step_query(context, query):
context.cli.sendline(query)
-@when(u'we notee output')
+@when("we notee output")
def step_notee_output(context):
- context.cli.sendline('notee')
+ context.cli.sendline("notee")
-@then(u'we see 123456 in tee output')
+@then("we see 123456 in tee output")
def step_see_123456_in_ouput(context):
with open(context.tee_file_name) as f:
- assert '123456' in f.read()
+ assert "123456" in f.read()
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
-@then(u'delimiter is set to "{delimiter}"')
+@then('delimiter is set to "{delimiter}"')
def delimiter_is_set(context, delimiter):
- wrappers.expect_exact(
- context,
- u'Changed delimiter to {}'.format(delimiter),
- timeout=2
- )
+ wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2)
diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py
index bc1f866..93d68ba 100644
--- a/test/features/steps/named_queries.py
+++ b/test/features/steps/named_queries.py
@@ -9,82 +9,79 @@ import wrappers
from behave import when, then
-@when('we save a named query')
+@when("we save a named query")
def step_save_named_query(context):
"""Send \fs command."""
- context.cli.sendline('\\fs foo SELECT 12345')
+ context.cli.sendline("\\fs foo SELECT 12345")
-@when('we use a named query')
+@when("we use a named query")
def step_use_named_query(context):
"""Send \f command."""
- context.cli.sendline('\\f foo')
+ context.cli.sendline("\\f foo")
-@when('we delete a named query')
+@when("we delete a named query")
def step_delete_named_query(context):
"""Send \fd command."""
- context.cli.sendline('\\fd foo')
+ context.cli.sendline("\\fd foo")
-@then('we see the named query saved')
+@then("we see the named query saved")
def step_see_named_query_saved(context):
"""Wait to see query saved."""
- wrappers.expect_exact(context, 'Saved.', timeout=2)
+ wrappers.expect_exact(context, "Saved.", timeout=2)
-@then('we see the named query executed')
+@then("we see the named query executed")
def step_see_named_query_executed(context):
"""Wait to see select output."""
- wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
+ wrappers.expect_exact(context, "SELECT 12345", timeout=2)
-@then('we see the named query deleted')
+@then("we see the named query deleted")
def step_see_named_query_deleted(context):
"""Wait to see query deleted."""
- wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
+ wrappers.expect_exact(context, "foo: Deleted", timeout=2)
-@when('we save a named query with parameters')
+@when("we save a named query with parameters")
def step_save_named_query_with_parameters(context):
"""Send \fs command for query with parameters."""
context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
-@when('we use named query with parameters')
+@when("we use named query with parameters")
def step_use_named_query_with_parameters(context):
"""Send \f command with parameters."""
context.cli.sendline('\\f foo_args 101 second "third value"')
-@then('we see the named query with parameters executed')
+@then("we see the named query with parameters executed")
def step_see_named_query_with_parameters_executed(context):
"""Wait to see select output."""
- wrappers.expect_exact(
- context, 'SELECT 101, "second", "third value"', timeout=2)
+ wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2)
-@when('we use named query with too few parameters')
+@when("we use named query with too few parameters")
def step_use_named_query_with_too_few_parameters(context):
"""Send \f command with missing parameters."""
- context.cli.sendline('\\f foo_args 101')
+ context.cli.sendline("\\f foo_args 101")
-@then('we see the named query with parameters fail with missing parameters')
+@then("we see the named query with parameters fail with missing parameters")
def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
"""Wait to see select output."""
- wrappers.expect_exact(
- context, 'missing substitution for $2 in query:', timeout=2)
+ wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2)
-@when('we use named query with too many parameters')
+@when("we use named query with too many parameters")
def step_use_named_query_with_too_many_parameters(context):
"""Send \f command with extra parameters."""
- context.cli.sendline('\\f foo_args 101 102 103 104')
+ context.cli.sendline("\\f foo_args 101 102 103 104")
-@then('we see the named query with parameters fail with extra parameters')
+@then("we see the named query with parameters fail with extra parameters")
def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
"""Wait to see select output."""
- wrappers.expect_exact(
- context, 'query does not have substitution parameter $4:', timeout=2)
+ wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2)
diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py
index e8b99e3..1b50a00 100644
--- a/test/features/steps/specials.py
+++ b/test/features/steps/specials.py
@@ -9,10 +9,10 @@ import wrappers
from behave import when, then
-@when('we refresh completions')
+@when("we refresh completions")
def step_refresh_completions(context):
"""Send refresh command."""
- context.cli.sendline('rehash')
+ context.cli.sendline("rehash")
@then('we see text "{text}"')
@@ -20,8 +20,8 @@ def step_see_text(context, text):
"""Wait to see given text message."""
wrappers.expect_exact(context, text, timeout=2)
-@then('we see completions refresh started')
+
+@then("we see completions refresh started")
def step_see_refresh_started(context):
"""Wait to see refresh output."""
- wrappers.expect_exact(
- context, 'Auto-completion refresh started in the background.', timeout=2)
+ wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2)
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
index 1ae63d2..873f9d4 100644
--- a/test/features/steps/utils.py
+++ b/test/features/steps/utils.py
@@ -4,8 +4,8 @@ import shlex
def parse_cli_args_to_dict(cli_args: str):
args_dict = {}
for arg in shlex.split(cli_args):
- if '=' in arg:
- key, value = arg.split('=')
+ if "=" in arg:
+ key, value = arg.split("=")
args_dict[key] = value
else:
args_dict[arg] = None
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
index 6408f23..f9325c6 100644
--- a/test/features/steps/wrappers.py
+++ b/test/features/steps/wrappers.py
@@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout):
timedout = True
if timedout:
# Strip color codes out of the output.
- actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
- '', context.cli.before)
+ actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
raise Exception(
- textwrap.dedent('''\
+ textwrap.dedent("""\
Expected:
---
{0!r}
@@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout):
---
{2!r}
---
- ''').format(
- expected,
- actual,
- context.logfile.getvalue()
- )
+ """).format(expected, actual, context.logfile.getvalue())
)
def expect_pager(context, expected, timeout):
- expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
- context.conf['pager_boundary'], expected), timeout=timeout)
+ expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout)
def run_cli(context, run_args=None, exclude_args=None):
@@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None):
else:
rendered_args.append(key)
- if conf.get('host', None):
- add_arg('host', '-h', conf['host'])
- if conf.get('user', None):
- add_arg('user', '-u', conf['user'])
- if conf.get('pass', None):
- add_arg('pass', '-p', conf['pass'])
- if conf.get('port', None):
- add_arg('port', '-P', str(conf['port']))
- if conf.get('dbname', None):
- add_arg('dbname', '-D', conf['dbname'])
- if conf.get('defaults-file', None):
- add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
- if conf.get('myclirc', None):
- add_arg('myclirc', '--myclirc', conf['myclirc'])
- if conf.get('login_path'):
- add_arg('login_path', '--login-path', conf['login_path'])
+ if conf.get("host", None):
+ add_arg("host", "-h", conf["host"])
+ if conf.get("user", None):
+ add_arg("user", "-u", conf["user"])
+ if conf.get("pass", None):
+ add_arg("pass", "-p", conf["pass"])
+ if conf.get("port", None):
+ add_arg("port", "-P", str(conf["port"]))
+ if conf.get("dbname", None):
+ add_arg("dbname", "-D", conf["dbname"])
+ if conf.get("defaults-file", None):
+ add_arg("defaults_file", "--defaults-file", conf["defaults-file"])
+ if conf.get("myclirc", None):
+ add_arg("myclirc", "--myclirc", conf["myclirc"])
+ if conf.get("login_path"):
+ add_arg("login_path", "--login-path", conf["login_path"])
for arg_name, arg_value in conf.items():
- if arg_name.startswith('-'):
+ if arg_name.startswith("-"):
add_arg(arg_name, arg_name, arg_value)
try:
- cli_cmd = context.conf['cli_command']
+ cli_cmd = context.conf["cli_command"]
except KeyError:
- cli_cmd = (
- '{0!s} -c "'
- 'import coverage ; '
- 'coverage.process_startup(); '
- 'import mycli.main; '
- 'mycli.main.cli()'
- '"'
- ).format(sys.executable)
+ cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format(
+ sys.executable
+ )
cmd_parts = [cli_cmd] + rendered_args
- cmd = ' '.join(cmd_parts)
+ cmd = " ".join(cmd_parts)
context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
context.logfile = StringIO()
context.cli.logfile = context.logfile
context.exit_sent = False
- context.currentdb = context.conf['dbname']
+ context.currentdb = context.conf["dbname"]
def wait_prompt(context, prompt=None):
"""Make sure prompt is displayed."""
if prompt is None:
- user = context.conf['user']
- host = context.conf['host']
+ user = context.conf["user"]
+ host = context.conf["host"]
dbname = context.currentdb
- prompt = '{0}@{1}:{2}>'.format(
- user, host, dbname),
+ prompt = ("{0}@{1}:{2}>".format(user, host, dbname),)
expect_exact(context, prompt, timeout=5)
context.atprompt = True
diff --git a/test/myclirc b/test/myclirc
index 7d96c45..58f7279 100644
--- a/test/myclirc
+++ b/test/myclirc
@@ -153,6 +153,7 @@ output.null = "#808080"
# Favorite queries.
[favorite_queries]
check = 'select "✔"'
+foo_args = 'SELECT $1, "$2", "$3"'
# Use the -d option to reference a DSN.
# Special characters in passwords and other strings can be escaped with URL encoding.
diff --git a/test/test_clistyle.py b/test/test_clistyle.py
index f82cdf0..ab40444 100644
--- a/test/test_clistyle.py
+++ b/test/test_clistyle.py
@@ -1,4 +1,5 @@
"""Test the mycli.clistyle module."""
+
import pytest
from pygments.style import Style
@@ -10,9 +11,9 @@ from mycli.clistyle import style_factory
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory():
"""Test that a Pygments Style class is created."""
- header = 'bold underline #ansired'
- cli_style = {'Token.Output.Header': header}
- style = style_factory('default', cli_style)
+ header = "bold underline #ansired"
+ cli_style = {"Token.Output.Header": header}
+ style = style_factory("default", cli_style)
assert isinstance(style(), Style)
assert Token.Output.Header in style.styles
@@ -22,6 +23,6 @@ def test_style_factory():
@pytest.mark.skip(reason="incompatible with new prompt toolkit")
def test_style_factory_unknown_name():
"""Test that an unrecognized name will not throw an error."""
- style = style_factory('foobar', {})
+ style = style_factory("foobar", {})
assert isinstance(style(), Style)
diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py
index 318b632..3104065 100644
--- a/test/test_completion_engine.py
+++ b/test/test_completion_engine.py
@@ -8,494 +8,528 @@ def sorted_dicts(dicts):
def test_select_suggests_cols_with_visible_table_scope():
- suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_select_suggests_cols_with_qualified_table_scope():
- suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [('sch', 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM tabl WHERE ',
- 'SELECT * FROM tabl WHERE (',
- 'SELECT * FROM tabl WHERE foo = ',
- 'SELECT * FROM tabl WHERE bar OR ',
- 'SELECT * FROM tabl WHERE foo = 1 AND ',
- 'SELECT * FROM tabl WHERE (bar > 10 AND ',
- 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
- 'SELECT * FROM tabl WHERE 10 < ',
- 'SELECT * FROM tabl WHERE foo BETWEEN ',
- 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
-])
+ suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [("sch", "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE (",
+ "SELECT * FROM tabl WHERE foo = ",
+ "SELECT * FROM tabl WHERE bar OR ",
+ "SELECT * FROM tabl WHERE foo = 1 AND ",
+ "SELECT * FROM tabl WHERE (bar > 10 AND ",
+ "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
+ "SELECT * FROM tabl WHERE 10 < ",
+ "SELECT * FROM tabl WHERE foo BETWEEN ",
+ "SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
+ ],
+)
def test_where_suggests_columns_functions(expression):
suggestions = suggest_type(expression, expression)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM tabl WHERE foo IN (',
- 'SELECT * FROM tabl WHERE foo IN (bar, ',
-])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM tabl WHERE foo IN (",
+ "SELECT * FROM tabl WHERE foo IN (bar, ",
+ ],
+)
def test_where_in_suggests_columns(expression):
suggestions = suggest_type(expression, expression)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_where_equals_any_suggests_columns_or_keywords():
- text = 'SELECT * FROM tabl WHERE foo = ANY('
+ text = "SELECT * FROM tabl WHERE foo = ANY("
suggestions = suggest_type(text, text)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'}])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_lparen_suggests_cols():
- suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
- assert suggestion == [
- {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+ suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols1():
- suggestion = suggest_type(
- 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
- assert suggestion == [
- {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+ suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_operand_inside_function_suggests_cols2():
- suggestion = suggest_type(
- 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
- assert suggestion == [
- {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+ suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
def test_select_suggests_cols_and_funcs():
- suggestions = suggest_type('SELECT ', 'SELECT ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': []},
- {'type': 'column', 'tables': []},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM ',
- 'INSERT INTO ',
- 'COPY ',
- 'UPDATE ',
- 'DESCRIBE ',
- 'DESC ',
- 'EXPLAIN ',
- 'SELECT * FROM foo JOIN ',
-])
+ suggestions = suggest_type("SELECT ", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": []},
+ {"type": "column", "tables": []},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM ",
+ "INSERT INTO ",
+ "COPY ",
+ "UPDATE ",
+ "DESCRIBE ",
+ "DESC ",
+ "EXPLAIN ",
+ "SELECT * FROM foo JOIN ",
+ ],
+)
def test_expression_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM sch.',
- 'INSERT INTO sch.',
- 'COPY sch.',
- 'UPDATE sch.',
- 'DESCRIBE sch.',
- 'DESC sch.',
- 'EXPLAIN sch.',
- 'SELECT * FROM foo JOIN sch.',
-])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM sch.",
+ "INSERT INTO sch.",
+ "COPY sch.",
+ "UPDATE sch.",
+ "DESCRIBE sch.",
+ "DESC sch.",
+ "EXPLAIN sch.",
+ "SELECT * FROM foo JOIN sch.",
+ ],
+)
def test_expression_suggests_qualified_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': 'sch'},
- {'type': 'view', 'schema': 'sch'}])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}])
def test_truncate_suggests_tables_and_schemas():
- suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'schema'}])
+ suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}])
def test_truncate_suggests_qualified_tables():
- suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': 'sch'}])
+ suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}])
def test_distinct_suggests_cols():
- suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
- assert suggestions == [{'type': 'column', 'tables': []}]
+ suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
+ assert suggestions == [{"type": "column", "tables": []}]
def test_col_comma_suggests_cols():
- suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tbl']},
- {'type': 'column', 'tables': [(None, 'tbl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tbl"]},
+ {"type": "column", "tables": [(None, "tbl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_table_comma_suggests_tables_and_schemas():
- suggestions = suggest_type('SELECT a, b FROM tbl1, ',
- 'SELECT a, b FROM tbl1, ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_into_suggests_tables_and_schemas():
- suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
- assert sorted_dicts(suggestion) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
+ assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_insert_into_lparen_suggests_cols():
- suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
- assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+ suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_partial_text_suggests_cols():
- suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
- assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+ suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_insert_into_lparen_comma_suggests_cols():
- suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
- assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+ suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
def test_partially_typed_col_name_suggests_col_names():
- suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
- 'SELECT * FROM tabl WHERE col_n')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['tabl']},
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
- suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'tabl', None)]},
- {'type': 'table', 'schema': 'tabl'},
- {'type': 'view', 'schema': 'tabl'},
- {'type': 'function', 'schema': 'tabl'}])
+ suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "table", "schema": "tabl"},
+ {"type": "view", "schema": "tabl"},
+ {"type": "function", "schema": "tabl"},
+ ]
+ )
def test_dot_suggests_cols_of_an_alias():
- suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
- 'SELECT t1.')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': 't1'},
- {'type': 'view', 'schema': 't1'},
- {'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
- {'type': 'function', 'schema': 't1'}])
+ suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": "t1"},
+ {"type": "view", "schema": "t1"},
+ {"type": "column", "tables": [(None, "tabl1", "t1")]},
+ {"type": "function", "schema": "t1"},
+ ]
+ )
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
- suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
- 'SELECT t1.a, t2.')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
- {'type': 'table', 'schema': 't2'},
- {'type': 'view', 'schema': 't2'},
- {'type': 'function', 'schema': 't2'}])
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM (',
- 'SELECT * FROM foo WHERE EXISTS (',
- 'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
- 'SELECT 1 AS',
-])
+ suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl2", "t2")]},
+ {"type": "table", "schema": "t2"},
+ {"type": "view", "schema": "t2"},
+ {"type": "function", "schema": "t2"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (",
+ "SELECT * FROM foo WHERE EXISTS (",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (",
+ "SELECT 1 AS",
+ ],
+)
def test_sub_select_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
- assert suggestion == [{'type': 'keyword'}]
+ assert suggestion == [{"type": "keyword"}]
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM (S',
- 'SELECT * FROM foo WHERE EXISTS (S',
- 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
-])
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (S",
+ "SELECT * FROM foo WHERE EXISTS (S",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
+ ],
+)
def test_sub_select_partial_text_suggests_keyword(expression):
suggestion = suggest_type(expression, expression)
- assert suggestion == [{'type': 'keyword'}]
+ assert suggestion == [{"type": "keyword"}]
def test_outer_table_reference_in_exists_subquery_suggests_columns():
- q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
+ q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert suggestions == [
- {'type': 'column', 'tables': [(None, 'foo', 'f')]},
- {'type': 'table', 'schema': 'f'},
- {'type': 'view', 'schema': 'f'},
- {'type': 'function', 'schema': 'f'}]
-
-
-@pytest.mark.parametrize('expression', [
- 'SELECT * FROM (SELECT * FROM ',
- 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
- 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
-])
+ {"type": "column", "tables": [(None, "foo", "f")]},
+ {"type": "table", "schema": "f"},
+ {"type": "view", "schema": "f"},
+ {"type": "function", "schema": "f"},
+ ]
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (SELECT * FROM ",
+ "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
+ ],
+)
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
- assert sorted_dicts(suggestion) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_sub_select_col_name_completion():
- suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
- 'SELECT * FROM (SELECT ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['abc']},
- {'type': 'column', 'tables': [(None, 'abc', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["abc"]},
+ {"type": "column", "tables": [(None, "abc", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
- suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
- 'SELECT * FROM (SELECT a, ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'abc', None)]},
- {'type': 'function', 'schema': []}])
+ suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}]
+ )
def test_sub_select_dot_col_name_completion():
- suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
- 'SELECT * FROM (SELECT t.')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'tabl', 't')]},
- {'type': 'table', 'schema': 't'},
- {'type': 'view', 'schema': 't'},
- {'type': 'function', 'schema': 't'}])
-
-
-@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
-@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
+ suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", "t")]},
+ {"type": "table", "schema": "t"},
+ {"type": "view", "schema": "t"},
+ {"type": "function", "schema": "t"},
+ ]
+ )
+
+
+@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
+@pytest.mark.parametrize("tbl_alias", ["", "foo"])
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
- text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
+ text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
suggestion = suggest_type(text, text)
- assert sorted_dicts(suggestion) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
-@pytest.mark.parametrize('sql', [
- 'SELECT * FROM abc a JOIN def d ON a.',
- 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
-])
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
+ ],
+)
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'abc', 'a')]},
- {'type': 'table', 'schema': 'a'},
- {'type': 'view', 'schema': 'a'},
- {'type': 'function', 'schema': 'a'}])
-
-
-@pytest.mark.parametrize('sql', [
- 'SELECT * FROM abc a JOIN def d ON a.id = d.',
- 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
-])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "abc", "a")]},
+ {"type": "table", "schema": "a"},
+ {"type": "view", "schema": "a"},
+ {"type": "function", "schema": "a"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
+ ],
+)
def test_join_alias_dot_suggests_cols2(sql):
suggestions = suggest_type(sql, sql)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'def', 'd')]},
- {'type': 'table', 'schema': 'd'},
- {'type': 'view', 'schema': 'd'},
- {'type': 'function', 'schema': 'd'}])
-
-
-@pytest.mark.parametrize('sql', [
- 'select a.x, b.y from abc a join bcd b on ',
- 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
-])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "def", "d")]},
+ {"type": "table", "schema": "d"},
+ {"type": "view", "schema": "d"},
+ {"type": "function", "schema": "d"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
+ ],
+)
def test_on_suggests_aliases(sql):
suggestions = suggest_type(sql, sql)
- assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
-@pytest.mark.parametrize('sql', [
- 'select abc.x, bcd.y from abc join bcd on ',
- 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
-])
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
+ ],
+)
def test_on_suggests_tables(sql):
suggestions = suggest_type(sql, sql)
- assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
-@pytest.mark.parametrize('sql', [
- 'select a.x, b.y from abc a join bcd b on a.id = ',
- 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
-])
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on a.id = ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
+ ],
+)
def test_on_suggests_aliases_right_side(sql):
suggestions = suggest_type(sql, sql)
- assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
-@pytest.mark.parametrize('sql', [
- 'select abc.x, bcd.y from abc join bcd on ',
- 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
-])
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
+ ],
+)
def test_on_suggests_tables_right_side(sql):
suggestions = suggest_type(sql, sql)
- assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
-@pytest.mark.parametrize('col_list', ['', 'col1, '])
+@pytest.mark.parametrize("col_list", ["", "col1, "])
def test_join_using_suggests_common_columns(col_list):
- text = 'select * from abc inner join def using (' + col_list
- assert suggest_type(text, text) == [
- {'type': 'column',
- 'tables': [(None, 'abc', None), (None, 'def', None)],
- 'drop_unique': True}]
-
-@pytest.mark.parametrize('sql', [
- 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
- 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
-])
+ text = "select * from abc inner join def using (" + col_list
+ assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.",
+ ],
+)
def test_two_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'column', 'tables': [(None, 'ghi', 'g')]},
- {'type': 'table', 'schema': 'g'},
- {'type': 'view', 'schema': 'g'},
- {'type': 'function', 'schema': 'g'}])
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "ghi", "g")]},
+ {"type": "table", "schema": "g"},
+ {"type": "view", "schema": "g"},
+ {"type": "function", "schema": "g"},
+ ]
+ )
+
def test_2_statements_2nd_current():
- suggestions = suggest_type('select * from a; select * from ',
- 'select * from a; select * from ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
-
- suggestions = suggest_type('select * from a; select from b',
- 'select * from a; select ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['b']},
- {'type': 'column', 'tables': [(None, 'b', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
+
+ suggestions = suggest_type("select * from a; select from b", "select * from a; select ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
# Should work even if first statement is invalid
- suggestions = suggest_type('select * from; select * from ',
- 'select * from; select * from ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ suggestions = suggest_type("select * from; select * from ", "select * from; select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_2_statements_1st_current():
- suggestions = suggest_type('select * from ; select * from b',
- 'select * from ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
-
- suggestions = suggest_type('select from a; select * from b',
- 'select ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['a']},
- {'type': 'column', 'tables': [(None, 'a', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("select * from ; select * from b", "select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
+
+ suggestions = suggest_type("select from a; select * from b", "select ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["a"]},
+ {"type": "column", "tables": [(None, "a", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_3_statements_2nd_current():
- suggestions = suggest_type('select * from a; select * from ; select * from c',
- 'select * from a; select * from ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
-
- suggestions = suggest_type('select * from a; select from b; select * from c',
- 'select * from a; select ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'alias', 'aliases': ['b']},
- {'type': 'column', 'tables': [(None, 'b', None)]},
- {'type': 'function', 'schema': []},
- {'type': 'keyword'},
- ])
+ suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
+
+ suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
def test_create_db_with_template():
- suggestions = suggest_type('create database foo with template ',
- 'create database foo with template ')
+ suggestions = suggest_type("create database foo with template ", "create database foo with template ")
- assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
-@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
+@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
- assert sorted_dicts(suggestions) == \
- sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}])
def test_specials_not_included_after_initial_token():
- suggestions = suggest_type('create table foo (dt d',
- 'create table foo (dt d')
+ suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
- assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
def test_drop_schema_qualified_table_suggests_only_tables():
- text = 'DROP TABLE schema_name.table_name'
+ text = "DROP TABLE schema_name.table_name"
suggestions = suggest_type(text, text)
- assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
+ assert suggestions == [{"type": "table", "schema": "schema_name"}]
-@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
+@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
def test_handle_pre_completion_comma_gracefully(text):
suggestions = suggest_type(text, text)
@@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text):
def test_cross_join():
- text = 'select * from v1 cross join v2 JOIN v1.id, '
+ text = "select * from v1 cross join v2 JOIN v1.id, "
suggestions = suggest_type(text, text)
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
-@pytest.mark.parametrize('expression', [
- 'SELECT 1 AS ',
- 'SELECT 1 FROM tabl AS ',
-])
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT 1 AS ",
+ "SELECT 1 FROM tabl AS ",
+ ],
+)
def test_after_as(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set()
-@pytest.mark.parametrize('expression', [
- '\\. ',
- 'select 1; \\. ',
- 'select 1;\\. ',
- 'select 1 ; \\. ',
- 'source ',
- 'truncate table test; source ',
- 'truncate table test ; source ',
- 'truncate table test;source ',
-])
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "\\. ",
+ "select 1; \\. ",
+ "select 1;\\. ",
+ "select 1 ; \\. ",
+ "source ",
+ "truncate table test; source ",
+ "truncate table test ; source ",
+ "truncate table test;source ",
+ ],
+)
def test_source_is_file(expression):
suggestions = suggest_type(expression, expression)
- assert suggestions == [{'type': 'file_name'}]
+ assert suggestions == [{"type": "file_name"}]
-@pytest.mark.parametrize("expression", [
- "\\f ",
-])
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "\\f ",
+ ],
+)
def test_favorite_name_suggestion(expression):
suggestions = suggest_type(expression, expression)
- assert suggestions == [{'type': 'favoritequery'}]
+ assert suggestions == [{"type": "favoritequery"}]
def test_order_by():
- text = 'select * from foo order by '
+ text = "select * from foo order by "
suggestions = suggest_type(text, text)
- assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
+ assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}]
def test_quoted_where():
text = "'where i=';"
suggestions = suggest_type(text, text)
- assert suggestions == [{'type': 'keyword'}]
+ assert suggestions == [{"type": "keyword"}]
diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py
index 31359cf..6f192d0 100644
--- a/test/test_completion_refresher.py
+++ b/test/test_completion_refresher.py
@@ -6,6 +6,7 @@ from unittest.mock import Mock, patch
@pytest.fixture
def refresher():
from mycli.completion_refresher import CompletionRefresher
+
return CompletionRefresher()
@@ -18,8 +19,7 @@ def test_ctor(refresher):
"""
assert len(refresher.refreshers) > 0
actual_handlers = list(refresher.refreshers.keys())
- expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
- 'special_commands', 'show_commands', 'keywords']
+ expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"]
assert expected_handlers == actual_handlers
@@ -32,12 +32,12 @@ def test_refresh_called_once(refresher):
callbacks = Mock()
sqlexecute = Mock()
- with patch.object(refresher, '_bg_refresh') as bg_refresh:
+ with patch.object(refresher, "_bg_refresh") as bg_refresh:
actual = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual) == 1
assert len(actual[0]) == 4
- assert actual[0][3] == 'Auto-completion refresh started in the background.'
+ assert actual[0][3] == "Auto-completion refresh started in the background."
bg_refresh.assert_called_with(sqlexecute, callbacks, {})
@@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher):
time.sleep(1) # Wait for the thread to work.
assert len(actual1) == 1
assert len(actual1[0]) == 4
- assert actual1[0][3] == 'Auto-completion refresh started in the background.'
+ assert actual1[0][3] == "Auto-completion refresh started in the background."
actual2 = refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
assert len(actual2) == 1
assert len(actual2[0]) == 4
- assert actual2[0][3] == 'Auto-completion refresh restarted.'
+ assert actual2[0][3] == "Auto-completion refresh restarted."
def test_refresh_with_callbacks(refresher):
@@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher):
sqlexecute_class = Mock()
sqlexecute = Mock()
- with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
+ with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(sqlexecute, callbacks)
time.sleep(1) # Wait for the thread to work.
- assert (callbacks[0].call_count == 1)
+ assert callbacks[0].call_count == 1
diff --git a/test/test_config.py b/test/test_config.py
index 7f2b244..859ca02 100644
--- a/test/test_config.py
+++ b/test/test_config.py
@@ -1,4 +1,5 @@
"""Unit tests for the mycli.config module."""
+
from io import BytesIO, StringIO, TextIOWrapper
import os
import struct
@@ -6,21 +7,26 @@ import sys
import tempfile
import pytest
-from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
- read_and_decrypt_mylogin_cnf, read_config_file,
- str_to_bool, strip_matching_quotes)
+from mycli.config import (
+ get_mylogin_cnf_path,
+ open_mylogin_cnf,
+ read_and_decrypt_mylogin_cnf,
+ read_config_file,
+ str_to_bool,
+ strip_matching_quotes,
+)
-LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
- 'mylogin.cnf'))
+LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf"))
def open_bmylogin_cnf(name):
"""Open contents of *name* in a BytesIO buffer."""
- with open(name, 'rb') as f:
+ with open(name, "rb") as f:
buf = BytesIO()
buf.write(f.read())
return buf
+
def test_read_mylogin_cnf():
"""Tests that a login path file can be read and decrypted."""
mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
@@ -28,7 +34,7 @@ def test_read_mylogin_cnf():
assert isinstance(mylogin_cnf, TextIOWrapper)
contents = mylogin_cnf.read()
- for word in ('[test]', 'user', 'password', 'host', 'port'):
+ for word in ("[test]", "user", "password", "host", "port"):
assert word in contents
@@ -46,7 +52,7 @@ def test_corrupted_login_key():
buf.seek(4)
# Write null bytes over half the login key
- buf.write(b'\0\0\0\0\0\0\0\0\0\0')
+ buf.write(b"\0\0\0\0\0\0\0\0\0\0")
buf.seek(0)
mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
@@ -63,58 +69,58 @@ def test_corrupted_pad():
# Skip option group
len_buf = buf.read(4)
- cipher_len, = struct.unpack("<i", len_buf)
+ (cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len)
# Corrupt the pad for the user line
len_buf = buf.read(4)
- cipher_len, = struct.unpack("<i", len_buf)
+ (cipher_len,) = struct.unpack("<i", len_buf)
buf.read(cipher_len - 1)
- buf.write(b'\0')
+ buf.write(b"\0")
buf.seek(0)
mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
contents = mylogin_cnf.read()
- for word in ('[test]', 'password', 'host', 'port'):
+ for word in ("[test]", "password", "host", "port"):
assert word in contents
- assert 'user' not in contents
+ assert "user" not in contents
def test_get_mylogin_cnf_path():
"""Tests that the path for .mylogin.cnf is detected."""
original_env = None
- if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
- original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
- is_windows = sys.platform == 'win32'
+ if "MYSQL_TEST_LOGIN_FILE" in os.environ:
+ original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
+ is_windows = sys.platform == "win32"
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
- os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+ os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
if login_cnf_path is not None:
- assert login_cnf_path.endswith('.mylogin.cnf')
+ assert login_cnf_path.endswith(".mylogin.cnf")
if is_windows is True:
- assert 'MySQL' in login_cnf_path
+ assert "MySQL" in login_cnf_path
else:
- home_dir = os.path.expanduser('~')
+ home_dir = os.path.expanduser("~")
assert login_cnf_path.startswith(home_dir)
def test_alternate_get_mylogin_cnf_path():
"""Tests that the alternate path for .mylogin.cnf is detected."""
original_env = None
- if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
- original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
+ if "MYSQL_TEST_LOGIN_FILE" in os.environ:
+ original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE")
_, temp_path = tempfile.mkstemp()
- os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
+ os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path
login_cnf_path = get_mylogin_cnf_path()
if original_env is not None:
- os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+ os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env
assert temp_path == login_cnf_path
@@ -124,17 +130,17 @@ def test_str_to_bool():
assert str_to_bool(False) is False
assert str_to_bool(True) is True
- assert str_to_bool('False') is False
- assert str_to_bool('True') is True
- assert str_to_bool('TRUE') is True
- assert str_to_bool('1') is True
- assert str_to_bool('0') is False
- assert str_to_bool('on') is True
- assert str_to_bool('off') is False
- assert str_to_bool('off') is False
+ assert str_to_bool("False") is False
+ assert str_to_bool("True") is True
+ assert str_to_bool("TRUE") is True
+ assert str_to_bool("1") is True
+ assert str_to_bool("0") is False
+ assert str_to_bool("on") is True
+ assert str_to_bool("off") is False
+ assert str_to_bool("off") is False
with pytest.raises(ValueError):
- str_to_bool('foo')
+ str_to_bool("foo")
with pytest.raises(TypeError):
str_to_bool(None)
@@ -143,19 +149,19 @@ def test_str_to_bool():
def test_read_config_file_list_values_default():
"""Test that reading a config file uses list_values by default."""
- f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f)
- assert config['main']['weather'] == u"cloudy with a chance of meatballs"
+ assert config["main"]["weather"] == "cloudy with a chance of meatballs"
def test_read_config_file_list_values_off():
"""Test that you can disable list_values when reading a config file."""
- f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n")
config = read_config_file(f, list_values=False)
- assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
+ assert config["main"]["weather"] == "'cloudy with a chance of meatballs'"
def test_strip_quotes_with_matching_quotes():
@@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes():
def test_strip_quotes_with_empty_string():
"""Test that an empty string is handled during unquoting."""
- assert '' == strip_matching_quotes('')
+ assert "" == strip_matching_quotes("")
def test_strip_quotes_with_none():
diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py
index 21e389c..aee6e05 100644
--- a/test/test_dbspecial.py
+++ b/test/test_dbspecial.py
@@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime
def test_u_suggests_databases():
- suggestions = suggest_type('\\u ', '\\u ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'database'}])
+ suggestions = suggest_type("\\u ", "\\u ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
def test_describe_table():
- suggestions = suggest_type('\\dt', '\\dt ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ suggestions = suggest_type("\\dt", "\\dt ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_list_or_show_create_tables():
- suggestions = suggest_type('\\dt+', '\\dt+ ')
- assert sorted_dicts(suggestions) == sorted_dicts([
- {'type': 'table', 'schema': []},
- {'type': 'view', 'schema': []},
- {'type': 'schema'}])
+ suggestions = suggest_type("\\dt+", "\\dt+ ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}])
def test_format_uptime():
seconds = 59
- assert '59 sec' == format_uptime(seconds)
+ assert "59 sec" == format_uptime(seconds)
seconds = 120
- assert '2 min 0 sec' == format_uptime(seconds)
+ assert "2 min 0 sec" == format_uptime(seconds)
seconds = 54890
- assert '15 hours 14 min 50 sec' == format_uptime(seconds)
+ assert "15 hours 14 min 50 sec" == format_uptime(seconds)
seconds = 598244
- assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
+ assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
seconds = 522600
- assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
+ assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)
diff --git a/test/test_main.py b/test/test_main.py
index 589d6cd..b0f8d4c 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -13,52 +13,62 @@ from textwrap import dedent
from collections import namedtuple
from tempfile import NamedTemporaryFile
-from textwrap import dedent
test_dir = os.path.abspath(os.path.dirname(__file__))
project_dir = os.path.dirname(test_dir)
-default_config_file = os.path.join(project_dir, 'test', 'myclirc')
-login_path_file = os.path.join(test_dir, 'mylogin.cnf')
-
-os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
-CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
- '--password', PASSWORD, '--myclirc', default_config_file,
- '--defaults-file', default_config_file,
- 'mycli_test_db']
+default_config_file = os.path.join(project_dir, "test", "myclirc")
+login_path_file = os.path.join(test_dir, "mylogin.cnf")
+
+os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file
+CLI_ARGS = [
+ "--user",
+ USER,
+ "--host",
+ HOST,
+ "--port",
+ PORT,
+ "--password",
+ PASSWORD,
+ "--myclirc",
+ default_config_file,
+ "--defaults-file",
+ default_config_file,
+ "mycli_test_db",
+]
@dbtest
def test_execute_arg(executor):
- run(executor, 'create table test (a text)')
+ run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
- sql = 'select * from test;'
+ sql = "select * from test;"
runner = CliRunner()
- result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
assert result.exit_code == 0
- assert 'abc' in result.output
+ assert "abc" in result.output
- result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
+ result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
assert result.exit_code == 0
- assert 'abc' in result.output
+ assert "abc" in result.output
- expected = 'a\nabc\n'
+ expected = "a\nabc\n"
assert expected in result.output
@dbtest
def test_execute_arg_with_table(executor):
- run(executor, 'create table test (a text)')
+ run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
- sql = 'select * from test;'
+ sql = "select * from test;"
runner = CliRunner()
- result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
- expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
+ expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
assert result.exit_code == 0
assert expected in result.output
@@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor):
@dbtest
def test_execute_arg_with_csv(executor):
- run(executor, 'create table test (a text)')
+ run(executor, "create table test (a text)")
run(executor, 'insert into test values("abc")')
- sql = 'select * from test;'
+ sql = "select * from test;"
runner = CliRunner()
- result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
expected = '"a"\n"abc"\n'
assert result.exit_code == 0
@@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor):
@dbtest
def test_batch_mode(executor):
- run(executor, '''create table test(a text)''')
- run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
- sql = (
- 'select count(*) from test;\n'
- 'select * from test limit 1;'
- )
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
assert result.exit_code == 0
- assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
+ assert "count(*)\n3\na\nabc\n" in "".join(result.output)
@dbtest
def test_batch_mode_table(executor):
- run(executor, '''create table test(a text)''')
- run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
- sql = (
- 'select count(*) from test;\n'
- 'select * from test limit 1;'
- )
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
runner = CliRunner()
- result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
+ result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
- expected = (dedent("""\
+ expected = dedent("""\
+----------+
| count(*) |
+----------+
@@ -118,7 +122,7 @@ def test_batch_mode_table(executor):
| a |
+-----+
| abc |
- +-----+"""))
+ +-----+""")
assert result.exit_code == 0
assert expected in result.output
@@ -126,14 +130,13 @@ def test_batch_mode_table(executor):
@dbtest
def test_batch_mode_csv(executor):
- run(executor, '''create table test(a text, b text)''')
- run(executor,
- '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
+ run(executor, """create table test(a text, b text)""")
+ run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
- sql = 'select * from test;'
+ sql = "select * from test;"
runner = CliRunner()
- result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
+ result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
@@ -150,15 +153,15 @@ def test_help_strings_end_with_periods():
"""Make sure click options have help text that end with a period."""
for param in cli.params:
if isinstance(param, click.core.Option):
- assert hasattr(param, 'help')
- assert param.help.endswith('.')
+ assert hasattr(param, "help")
+ assert param.help.endswith(".")
def test_command_descriptions_end_with_periods():
"""Make sure that mycli commands' descriptions end with a period."""
MyCli()
for _, command in SPECIAL_COMMANDS.items():
- assert command[3].endswith('.')
+ assert command[3].endswith(".")
def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
@@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
clickoutput = ""
m = MyCli(myclirc=default_config_file)
- class TestOutput():
+ class TestOutput:
def get_size(self):
- size = namedtuple('Size', 'rows columns')
+ size = namedtuple("Size", "rows columns")
size.columns, size.rows = terminal_size
return size
- class TestExecute():
- host = 'test'
- user = 'test'
- dbname = 'test'
- server_info = ServerInfo.from_version_string('unknown')
+ class TestExecute:
+ host = "test"
+ user = "test"
+ dbname = "test"
+ server_info = ServerInfo.from_version_string("unknown")
port = 0
def server_type(self):
- return ['test']
+ return ["test"]
- class PromptBuffer():
+ class PromptBuffer:
output = TestOutput()
m.prompt_app = PromptBuffer()
@@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
global clickoutput
clickoutput += s + "\n"
- monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
- monkeypatch.setattr(click, 'secho', secho)
+ monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
+ monkeypatch.setattr(click, "secho", secho)
m.output(testdata)
if clickoutput.endswith("\n"):
clickoutput = clickoutput[:-1]
@@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
def test_conditional_pager(monkeypatch):
- testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
- " ")
+ testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ")
# User didn't set pager, output doesn't fit screen -> pager
- output(
- monkeypatch,
- terminal_size=(5, 10),
- testdata=testdata,
- explicit_pager=False,
- expect_pager=True
- )
+ output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True)
# User didn't set pager, output fits screen -> no pager
- output(
- monkeypatch,
- terminal_size=(20, 20),
- testdata=testdata,
- explicit_pager=False,
- expect_pager=False
- )
+ output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False)
# User manually configured pager, output doesn't fit screen -> pager
- output(
- monkeypatch,
- terminal_size=(5, 10),
- testdata=testdata,
- explicit_pager=True,
- expect_pager=True
- )
+ output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True)
# User manually configured pager, output fit screen -> pager
- output(
- monkeypatch,
- terminal_size=(20, 20),
- testdata=testdata,
- explicit_pager=True,
- expect_pager=True
- )
+ output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True)
- SPECIAL_COMMANDS['nopager'].handler()
- output(
- monkeypatch,
- terminal_size=(5, 10),
- testdata=testdata,
- explicit_pager=False,
- expect_pager=False
- )
- SPECIAL_COMMANDS['pager'].handler('')
+ SPECIAL_COMMANDS["nopager"].handler()
+ output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False)
+ SPECIAL_COMMANDS["pager"].handler("")
def test_reserved_space_is_integer(monkeypatch):
"""Make sure that reserved space is returned as an integer."""
+
def stub_terminal_size():
return (5, 5)
with monkeypatch.context() as m:
- m.setattr(shutil, 'get_terminal_size', stub_terminal_size)
+ m.setattr(shutil, "get_terminal_size", stub_terminal_size)
mycli = MyCli()
assert isinstance(mycli.get_reserved_space(), int)
@@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch):
def test_list_dsn():
runner = CliRunner()
# keep Windows from locking the file with delete=False
- with NamedTemporaryFile(mode="w",delete=False) as myclirc:
- myclirc.write(dedent("""\
+ with NamedTemporaryFile(mode="w", delete=False) as myclirc:
+ myclirc.write(
+ dedent("""\
[alias_dsn]
test = mysql://test/test
- """))
+ """)
+ )
myclirc.flush()
- args = ['--list-dsn', '--myclirc', myclirc.name]
+ args = ["--list-dsn", "--myclirc", myclirc.name]
result = runner.invoke(cli, args=args)
assert result.output == "test\n"
- result = runner.invoke(cli, args=args + ['--verbose'])
+ result = runner.invoke(cli, args=args + ["--verbose"])
assert result.output == "test : mysql://test/test\n"
-
+
# delete=False means we should try to clean up
try:
if os.path.exists(myclirc.name):
@@ -287,41 +262,41 @@ def test_list_dsn():
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
-
-
def test_prettify_statement():
- statement = 'SELECT 1'
+ statement = "SELECT 1"
m = MyCli()
pretty_statement = m.handle_prettify_binding(statement)
- assert pretty_statement == 'SELECT\n 1;'
+ assert pretty_statement == "SELECT\n 1;"
def test_unprettify_statement():
- statement = 'SELECT\n 1'
+ statement = "SELECT\n 1"
m = MyCli()
unpretty_statement = m.handle_unprettify_binding(statement)
- assert unpretty_statement == 'SELECT 1;'
+ assert unpretty_statement == "SELECT 1;"
def test_list_ssh_config():
runner = CliRunner()
# keep Windows from locking the file with delete=False
- with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
- ssh_config.write(dedent("""\
+ with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
+ ssh_config.write(
+ dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
- """))
+ """)
+ )
ssh_config.flush()
- args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
+ args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name]
result = runner.invoke(cli, args=args)
assert "test\n" in result.output
- result = runner.invoke(cli, args=args + ['--verbose'])
+ result = runner.invoke(cli, args=args + ["--verbose"])
assert "test : test.example.com\n" in result.output
-
+
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@@ -343,7 +318,7 @@ def test_dsn(monkeypatch):
pass
class MockMyCli:
- config = {'alias_dsn': {}}
+ config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@@ -357,97 +332,109 @@ def test_dsn(monkeypatch):
pass
import mycli.main
- monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+
+ monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# When a user supplies a DSN as database argument to mycli,
# use these values.
- result = runner.invoke(mycli.main.cli, args=[
- "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
- )
+ result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
- assert \
- MockMyCli.connect_args["user"] == "dsn_user" and \
- MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
- MockMyCli.connect_args["host"] == "dsn_host" and \
- MockMyCli.connect_args["port"] == 1 and \
- MockMyCli.connect_args["database"] == "dsn_database"
+ assert (
+ MockMyCli.connect_args["user"] == "dsn_user"
+ and MockMyCli.connect_args["passwd"] == "dsn_passwd"
+ and MockMyCli.connect_args["host"] == "dsn_host"
+ and MockMyCli.connect_args["port"] == 1
+ and MockMyCli.connect_args["database"] == "dsn_database"
+ )
MockMyCli.connect_args = None
# When a use supplies a DSN as database argument to mycli,
# and used command line arguments, use the command line
# arguments.
- result = runner.invoke(mycli.main.cli, args=[
- "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
- "--user", "arg_user",
- "--password", "arg_password",
- "--host", "arg_host",
- "--port", "3",
- "--database", "arg_database",
- ])
+ result = runner.invoke(
+ mycli.main.cli,
+ args=[
+ "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
+ "--user",
+ "arg_user",
+ "--password",
+ "arg_password",
+ "--host",
+ "arg_host",
+ "--port",
+ "3",
+ "--database",
+ "arg_database",
+ ],
+ )
assert result.exit_code == 0, result.output + " " + str(result.exception)
- assert \
- MockMyCli.connect_args["user"] == "arg_user" and \
- MockMyCli.connect_args["passwd"] == "arg_password" and \
- MockMyCli.connect_args["host"] == "arg_host" and \
- MockMyCli.connect_args["port"] == 3 and \
- MockMyCli.connect_args["database"] == "arg_database"
-
- MockMyCli.config = {
- 'alias_dsn': {
- 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
- }
- }
+ assert (
+ MockMyCli.connect_args["user"] == "arg_user"
+ and MockMyCli.connect_args["passwd"] == "arg_password"
+ and MockMyCli.connect_args["host"] == "arg_host"
+ and MockMyCli.connect_args["port"] == 3
+ and MockMyCli.connect_args["database"] == "arg_database"
+ )
+
+ MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn),
# use these values.
- result = runner.invoke(cli, args=['--dsn', 'test'])
+ result = runner.invoke(cli, args=["--dsn", "test"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
- assert \
- MockMyCli.connect_args["user"] == "alias_dsn_user" and \
- MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
- MockMyCli.connect_args["host"] == "alias_dsn_host" and \
- MockMyCli.connect_args["port"] == 4 and \
- MockMyCli.connect_args["database"] == "alias_dsn_database"
-
- MockMyCli.config = {
- 'alias_dsn': {
- 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
- }
- }
+ assert (
+ MockMyCli.connect_args["user"] == "alias_dsn_user"
+ and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd"
+ and MockMyCli.connect_args["host"] == "alias_dsn_host"
+ and MockMyCli.connect_args["port"] == 4
+ and MockMyCli.connect_args["database"] == "alias_dsn_database"
+ )
+
+ MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}}
MockMyCli.connect_args = None
# When a user uses a DSN from the configuration file (alias_dsn)
# and used command line arguments, use the command line arguments.
- result = runner.invoke(cli, args=[
- '--dsn', 'test', '',
- "--user", "arg_user",
- "--password", "arg_password",
- "--host", "arg_host",
- "--port", "5",
- "--database", "arg_database",
- ])
+ result = runner.invoke(
+ cli,
+ args=[
+ "--dsn",
+ "test",
+ "",
+ "--user",
+ "arg_user",
+ "--password",
+ "arg_password",
+ "--host",
+ "arg_host",
+ "--port",
+ "5",
+ "--database",
+ "arg_database",
+ ],
+ )
assert result.exit_code == 0, result.output + " " + str(result.exception)
- assert \
- MockMyCli.connect_args["user"] == "arg_user" and \
- MockMyCli.connect_args["passwd"] == "arg_password" and \
- MockMyCli.connect_args["host"] == "arg_host" and \
- MockMyCli.connect_args["port"] == 5 and \
- MockMyCli.connect_args["database"] == "arg_database"
+ assert (
+ MockMyCli.connect_args["user"] == "arg_user"
+ and MockMyCli.connect_args["passwd"] == "arg_password"
+ and MockMyCli.connect_args["host"] == "arg_host"
+ and MockMyCli.connect_args["port"] == 5
+ and MockMyCli.connect_args["database"] == "arg_database"
+ )
# Use a DSN without password
- result = runner.invoke(mycli.main.cli, args=[
- "mysql://dsn_user@dsn_host:6/dsn_database"]
- )
+ result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"])
assert result.exit_code == 0, result.output + " " + str(result.exception)
- assert \
- MockMyCli.connect_args["user"] == "dsn_user" and \
- MockMyCli.connect_args["passwd"] is None and \
- MockMyCli.connect_args["host"] == "dsn_host" and \
- MockMyCli.connect_args["port"] == 6 and \
- MockMyCli.connect_args["database"] == "dsn_database"
+ assert (
+ MockMyCli.connect_args["user"] == "dsn_user"
+ and MockMyCli.connect_args["passwd"] is None
+ and MockMyCli.connect_args["host"] == "dsn_host"
+ and MockMyCli.connect_args["port"] == 6
+ and MockMyCli.connect_args["database"] == "dsn_database"
+ )
def test_ssh_config(monkeypatch):
@@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch):
pass
class MockMyCli:
- config = {'alias_dsn': {}}
+ config = {"alias_dsn": {}}
def __init__(self, **args):
self.logger = Logger()
@@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch):
pass
import mycli.main
- monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+
+ monkeypatch.setattr(mycli.main, "MyCli", MockMyCli)
runner = CliRunner()
# Setup temporary configuration
# keep Windows from locking the file with delete=False
- with NamedTemporaryFile(mode="w",delete=False) as ssh_config:
- ssh_config.write(dedent("""\
+ with NamedTemporaryFile(mode="w", delete=False) as ssh_config:
+ ssh_config.write(
+ dedent("""\
Host test
Hostname test.example.com
User joe
Port 22222
IdentityFile ~/.ssh/gateway
- """))
+ """)
+ )
ssh_config.flush()
# When a user supplies a ssh config.
- result = runner.invoke(mycli.main.cli, args=[
- "--ssh-config-path",
- ssh_config.name,
- "--ssh-config-host",
- "test"
- ])
- assert result.exit_code == 0, result.output + \
- " " + str(result.exception)
- assert \
- MockMyCli.connect_args["ssh_user"] == "joe" and \
- MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
- MockMyCli.connect_args["ssh_port"] == 22222 and \
- MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser(
- "~") + "/.ssh/gateway"
+ result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert (
+ MockMyCli.connect_args["ssh_user"] == "joe"
+ and MockMyCli.connect_args["ssh_host"] == "test.example.com"
+ and MockMyCli.connect_args["ssh_port"] == 22222
+ and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway"
+ )
# When a user supplies a ssh config host as argument to mycli,
# and used command line arguments, use the command line
# arguments.
- result = runner.invoke(mycli.main.cli, args=[
- "--ssh-config-path",
- ssh_config.name,
- "--ssh-config-host",
- "test",
- "--ssh-user", "arg_user",
- "--ssh-host", "arg_host",
- "--ssh-port", "3",
- "--ssh-key-filename", "/path/to/key"
- ])
- assert result.exit_code == 0, result.output + \
- " " + str(result.exception)
- assert \
- MockMyCli.connect_args["ssh_user"] == "arg_user" and \
- MockMyCli.connect_args["ssh_host"] == "arg_host" and \
- MockMyCli.connect_args["ssh_port"] == 3 and \
- MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
-
+ result = runner.invoke(
+ mycli.main.cli,
+ args=[
+ "--ssh-config-path",
+ ssh_config.name,
+ "--ssh-config-host",
+ "test",
+ "--ssh-user",
+ "arg_user",
+ "--ssh-host",
+ "arg_host",
+ "--ssh-port",
+ "3",
+ "--ssh-key-filename",
+ "/path/to/key",
+ ],
+ )
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert (
+ MockMyCli.connect_args["ssh_user"] == "arg_user"
+ and MockMyCli.connect_args["ssh_host"] == "arg_host"
+ and MockMyCli.connect_args["ssh_port"] == 3
+ and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
+ )
+
# delete=False means we should try to clean up
try:
if os.path.exists(ssh_config.name):
@@ -542,9 +533,7 @@ def test_init_command_arg(executor):
init_command = "set sql_select_limit=1000"
sql = 'show variables like "sql_select_limit";'
runner = CliRunner()
- result = runner.invoke(
- cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
- )
+ result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
expected = "sql_select_limit\t1000\n"
assert result.exit_code == 0
@@ -553,18 +542,13 @@ def test_init_command_arg(executor):
@dbtest
def test_init_command_multiple_arg(executor):
- init_command = 'set sql_select_limit=2000; set max_join_size=20000'
- sql = (
- 'show variables like "sql_select_limit";\n'
- 'show variables like "max_join_size"'
- )
+ init_command = "set sql_select_limit=2000; set max_join_size=20000"
+ sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"'
runner = CliRunner()
- result = runner.invoke(
- cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
- )
+ result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql)
- expected_sql_select_limit = 'sql_select_limit\t2000\n'
- expected_max_join_size = 'max_join_size\t20000\n'
+ expected_sql_select_limit = "sql_select_limit\t2000\n"
+ expected_max_join_size = "max_join_size\t20000\n"
assert result.exit_code == 0
assert expected_sql_select_limit in result.output
diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py
index 0bc3bf8..31ac165 100644
--- a/test/test_naive_completion.py
+++ b/test/test_naive_completion.py
@@ -6,56 +6,48 @@ from prompt_toolkit.document import Document
@pytest.fixture
def completer():
import mycli.sqlcompleter as sqlcompleter
+
return sqlcompleter.SQLCompleter(smart_completion=False)
@pytest.fixture
def complete_event():
from unittest.mock import Mock
+
return Mock()
def test_empty_string_completion(completer, complete_event):
- text = ''
+ text = ""
position = 0
- result = list(completer.get_completions(
- Document(text=text, cursor_position=position),
- complete_event))
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_select_keyword_completion(completer, complete_event):
- text = 'SEL'
- position = len('SEL')
- result = list(completer.get_completions(
- Document(text=text, cursor_position=position),
- complete_event))
- assert result == list([Completion(text='SELECT', start_position=-3)])
+ text = "SEL"
+ position = len("SEL")
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
+ assert result == list([Completion(text="SELECT", start_position=-3)])
def test_function_name_completion(completer, complete_event):
- text = 'SELECT MA'
- position = len('SELECT MA')
- result = list(completer.get_completions(
- Document(text=text, cursor_position=position),
- complete_event))
+ text = "SELECT MA"
+ position = len("SELECT MA")
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert sorted(x.text for x in result) == ["MASTER", "MAX"]
def test_column_name_completion(completer, complete_event):
- text = 'SELECT FROM users'
- position = len('SELECT ')
- result = list(completer.get_completions(
- Document(text=text, cursor_position=position),
- complete_event))
+ text = "SELECT FROM users"
+ position = len("SELECT ")
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(map(Completion, completer.all_completions))
def test_special_name_completion(completer, complete_event):
- text = '\\'
- position = len('\\')
- result = set(completer.get_completions(
- Document(text=text, cursor_position=position),
- complete_event))
+ text = "\\"
+ position = len("\\")
+ result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
# Special commands will NOT be suggested during naive completion mode.
assert result == set()
diff --git a/test/test_parseutils.py b/test/test_parseutils.py
index 920a08d..0925299 100644
--- a/test/test_parseutils.py
+++ b/test/test_parseutils.py
@@ -1,67 +1,72 @@
import pytest
from mycli.packages.parseutils import (
- extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
- is_dropping_database)
+ extract_tables,
+ query_starts_with,
+ queries_start_with,
+ is_destructive,
+ query_has_where_clause,
+ is_dropping_database,
+)
def test_empty_string():
- tables = extract_tables('')
+ tables = extract_tables("")
assert tables == []
def test_simple_select_single_table():
- tables = extract_tables('select * from abc')
- assert tables == [(None, 'abc', None)]
+ tables = extract_tables("select * from abc")
+ assert tables == [(None, "abc", None)]
def test_simple_select_single_table_schema_qualified():
- tables = extract_tables('select * from abc.def')
- assert tables == [('abc', 'def', None)]
+ tables = extract_tables("select * from abc.def")
+ assert tables == [("abc", "def", None)]
def test_simple_select_multiple_tables():
- tables = extract_tables('select * from abc, def')
- assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+ tables = extract_tables("select * from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_multiple_tables_schema_qualified():
- tables = extract_tables('select * from abc.def, ghi.jkl')
- assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
+ tables = extract_tables("select * from abc.def, ghi.jkl")
+ assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
def test_simple_select_with_cols_single_table():
- tables = extract_tables('select a,b from abc')
- assert tables == [(None, 'abc', None)]
+ tables = extract_tables("select a,b from abc")
+ assert tables == [(None, "abc", None)]
def test_simple_select_with_cols_single_table_schema_qualified():
- tables = extract_tables('select a,b from abc.def')
- assert tables == [('abc', 'def', None)]
+ tables = extract_tables("select a,b from abc.def")
+ assert tables == [("abc", "def", None)]
def test_simple_select_with_cols_multiple_tables():
- tables = extract_tables('select a,b from abc, def')
- assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+ tables = extract_tables("select a,b from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_simple_select_with_cols_multiple_tables_with_schema():
- tables = extract_tables('select a,b from abc.def, def.ghi')
- assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
+ tables = extract_tables("select a,b from abc.def, def.ghi")
+ assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
def test_select_with_hanging_comma_single_table():
- tables = extract_tables('select a, from abc')
- assert tables == [(None, 'abc', None)]
+ tables = extract_tables("select a, from abc")
+ assert tables == [(None, "abc", None)]
def test_select_with_hanging_comma_multiple_tables():
- tables = extract_tables('select a, from abc, def')
- assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+ tables = extract_tables("select a, from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
def test_select_with_hanging_period_multiple_tables():
- tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
- assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
+ tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
+ assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
def test_simple_insert_single_table():
@@ -69,97 +74,80 @@ def test_simple_insert_single_table():
# sqlparse mistakenly assigns an alias to the table
# assert tables == [(None, 'abc', None)]
- assert tables == [(None, 'abc', 'abc')]
+ assert tables == [(None, "abc", "abc")]
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
- assert tables == [('abc', 'def', None)]
+ assert tables == [("abc", "def", None)]
def test_simple_update_table():
- tables = extract_tables('update abc set id = 1')
- assert tables == [(None, 'abc', None)]
+ tables = extract_tables("update abc set id = 1")
+ assert tables == [(None, "abc", None)]
def test_simple_update_table_with_schema():
- tables = extract_tables('update abc.def set id = 1')
- assert tables == [('abc', 'def', None)]
+ tables = extract_tables("update abc.def set id = 1")
+ assert tables == [("abc", "def", None)]
def test_join_table():
- tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
- assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
+ tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
+ assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
def test_join_table_schema_qualified():
- tables = extract_tables(
- 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
- assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
+ tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
+ assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
def test_join_as_table():
- tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
- assert tables == [(None, 'my_table', 'm')]
+ tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
+ assert tables == [(None, "my_table", "m")]
def test_query_starts_with():
- query = 'USE test;'
- assert query_starts_with(query, ('use', )) is True
+ query = "USE test;"
+ assert query_starts_with(query, ("use",)) is True
- query = 'DROP DATABASE test;'
- assert query_starts_with(query, ('use', )) is False
+ query = "DROP DATABASE test;"
+ assert query_starts_with(query, ("use",)) is False
def test_query_starts_with_comment():
- query = '# comment\nUSE test;'
- assert query_starts_with(query, ('use', )) is True
+ query = "# comment\nUSE test;"
+ assert query_starts_with(query, ("use",)) is True
def test_queries_start_with():
- sql = (
- '# comment\n'
- 'show databases;'
- 'use foo;'
- )
- assert queries_start_with(sql, ('show', 'select')) is True
- assert queries_start_with(sql, ('use', 'drop')) is True
- assert queries_start_with(sql, ('delete', 'update')) is False
+ sql = "# comment\n" "show databases;" "use foo;"
+ assert queries_start_with(sql, ("show", "select")) is True
+ assert queries_start_with(sql, ("use", "drop")) is True
+ assert queries_start_with(sql, ("delete", "update")) is False
def test_is_destructive():
- sql = (
- 'use test;\n'
- 'show databases;\n'
- 'drop database foo;'
- )
+ sql = "use test;\n" "show databases;\n" "drop database foo;"
assert is_destructive(sql) is True
def test_is_destructive_update_with_where_clause():
- sql = (
- 'use test;\n'
- 'show databases;\n'
- 'UPDATE test SET x = 1 WHERE id = 1;'
- )
+ sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;"
assert is_destructive(sql) is False
def test_is_destructive_update_without_where_clause():
- sql = (
- 'use test;\n'
- 'show databases;\n'
- 'UPDATE test SET x = 1;'
- )
+ sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;"
assert is_destructive(sql) is True
@pytest.mark.parametrize(
- ('sql', 'has_where_clause'),
+ ("sql", "has_where_clause"),
[
- ('update test set dummy = 1;', False),
- ('update test set dummy = 1 where id = 1);', True),
+ ("update test set dummy = 1;", False),
+ ("update test set dummy = 1 where id = 1);", True),
],
)
def test_query_has_where_clause(sql, has_where_clause):
@@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause):
@pytest.mark.parametrize(
- ('sql', 'dbname', 'is_dropping'),
+ ("sql", "dbname", "is_dropping"),
[
- ('select bar from foo', 'foo', False),
- ('drop database "foo";', '`foo`', True),
- ('drop schema foo', 'foo', True),
- ('drop schema foo', 'bar', False),
- ('drop database bar', 'foo', False),
- ('drop database foo', None, False),
- ('drop database foo; create database foo', 'foo', False),
- ('drop database foo; create database bar', 'foo', True),
- ('select bar from foo; drop database bazz', 'foo', False),
- ('select bar from foo; drop database bazz', 'bazz', True),
- ('-- dropping database \n '
- 'drop -- really dropping \n '
- 'schema abc -- now it is dropped',
- 'abc',
- True)
- ]
+ ("select bar from foo", "foo", False),
+ ('drop database "foo";', "`foo`", True),
+ ("drop schema foo", "foo", True),
+ ("drop schema foo", "bar", False),
+ ("drop database bar", "foo", False),
+ ("drop database foo", None, False),
+ ("drop database foo; create database foo", "foo", False),
+ ("drop database foo; create database bar", "foo", True),
+ ("select bar from foo; drop database bazz", "foo", False),
+ ("select bar from foo; drop database bazz", "bazz", True),
+ ("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True),
+ ],
)
def test_is_dropping_database(sql, dbname, is_dropping):
assert is_dropping_database(sql, dbname) == is_dropping
diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py
index 2373fac..625e022 100644
--- a/test/test_prompt_utils.py
+++ b/test/test_prompt_utils.py
@@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query
def test_confirm_destructive_query_notty():
- stdin = click.get_text_stream('stdin')
+ stdin = click.get_text_stream("stdin")
assert stdin.isatty() is False
- sql = 'drop database foo;'
+ sql = "drop database foo;"
assert confirm_destructive_query(sql) is None
diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py
index 30b15ac..8ad40a4 100644
--- a/test/test_smart_completion_public_schema_only.py
+++ b/test/test_smart_completion_public_schema_only.py
@@ -43,49 +43,35 @@ def complete_event():
def test_special_name_completion(completer, complete_event):
text = "\\d"
position = len("\\d")
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert result == [Completion(text="\\dt", start_position=-2)]
def test_empty_string_completion(completer, complete_event):
text = ""
position = 0
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
- assert (
- list(map(Completion, completer.keywords + completer.special_commands)) == result
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
+ assert list(map(Completion, completer.keywords + completer.special_commands)) == result
def test_select_keyword_completion(completer, complete_event):
text = "SEL"
position = len("SEL")
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list([Completion(text="SELECT", start_position=-3)])
def test_select_star(completer, complete_event):
text = "SELECT * "
position = len(text)
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(map(Completion, completer.keywords))
def test_table_completion(completer, complete_event):
text = "SELECT * FROM "
position = len(text)
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="users", start_position=0),
@@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event):
def test_function_name_completion(completer, complete_event):
text = "SELECT MA"
position = len("SELECT MA")
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="MAX", start_position=-2),
@@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event):
"""
text = "SELECT from users"
position = len("SELECT ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
"""
text = "SELECT MAX( from users"
position = len("SELECT MAX(")
- result = completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
+ result = completer.get_completions(Document(text=text, cursor_position=position), complete_event)
assert list(result) == list(
[
Completion(text="*", start_position=0),
@@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event):
"""
text = "SELECT users. from users"
position = len("SELECT users.")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u. from users u"
position = len("SELECT u.")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event):
"""
text = "SELECT id, from users u"
position = len("SELECT id, ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event):
"""
text = "SELECT u.id, u. from users u"
position = len("SELECT u.id, u.")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
"""
text = "SELECT users.id, users. from users u"
position = len("SELECT users.id, users.")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event):
def test_suggested_aliases_after_on(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event):
def test_suggested_aliases_after_on_right_side(completer, complete_event):
text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="u", start_position=0),
@@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event):
def test_suggested_tables_after_on(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON "
position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event):
def test_suggested_tables_after_on_right_side(completer, complete_event):
text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
- position = len(
- "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
- )
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ")
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event):
def test_table_names_after_from(completer, complete_event):
text = "SELECT * FROM "
position = len("SELECT * FROM ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="users", start_position=0),
@@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event):
def test_auto_escaped_col_names(completer, complete_event):
text = "SELECT from `select`"
position = len("SELECT ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == [
Completion(text="*", start_position=0),
Completion(text="id", start_position=0),
Completion(text="`insert`", start_position=0),
Completion(text="`ABC`", start_position=0),
- ] + list(map(Completion, completer.functions)) + [
- Completion(text="select", start_position=0)
- ] + list(map(Completion, completer.keywords))
+ ] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list(
+ map(Completion, completer.keywords)
+ )
def test_un_escaped_table_names(completer, complete_event):
text = "SELECT from réveillé"
position = len("SELECT ")
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
assert result == list(
[
Completion(text="*", start_position=0),
@@ -464,10 +392,6 @@ def dummy_list_path(dir_name):
)
def test_file_name_completion(completer, complete_event, text, expected):
position = len(text)
- result = list(
- completer.get_completions(
- Document(text=text, cursor_position=position), complete_event
- )
- )
+ result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event))
expected = list((Completion(txt, pos) for txt, pos in expected))
assert result == expected
diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py
index d0ca45f..bea5620 100644
--- a/test/test_special_iocommands.py
+++ b/test/test_special_iocommands.py
@@ -17,11 +17,11 @@ def test_set_get_pager():
assert mycli.packages.special.is_pager_enabled()
mycli.packages.special.set_pager_enabled(False)
assert not mycli.packages.special.is_pager_enabled()
- mycli.packages.special.set_pager('less')
- assert os.environ['PAGER'] == "less"
+ mycli.packages.special.set_pager("less")
+ assert os.environ["PAGER"] == "less"
mycli.packages.special.set_pager(False)
- assert os.environ['PAGER'] == "less"
- del os.environ['PAGER']
+ assert os.environ["PAGER"] == "less"
+ del os.environ["PAGER"]
mycli.packages.special.set_pager(False)
mycli.packages.special.disable_pager()
assert not mycli.packages.special.is_pager_enabled()
@@ -42,45 +42,44 @@ def test_set_get_expanded_output():
def test_editor_command():
- assert mycli.packages.special.editor_command(r'hello\e')
- assert mycli.packages.special.editor_command(r'\ehello')
- assert not mycli.packages.special.editor_command(r'hello')
+ assert mycli.packages.special.editor_command(r"hello\e")
+ assert mycli.packages.special.editor_command(r"\ehello")
+ assert not mycli.packages.special.editor_command(r"hello")
- assert mycli.packages.special.get_filename(r'\e filename') == "filename"
+ assert mycli.packages.special.get_filename(r"\e filename") == "filename"
- os.environ['EDITOR'] = 'true'
- os.environ['VISUAL'] = 'true'
+ os.environ["EDITOR"] = "true"
+ os.environ["VISUAL"] = "true"
# Set the editor to Notepad on Windows
- if os.name != 'nt':
- mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
+ if os.name != "nt":
+ mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1"
else:
- pytest.skip('Skipping on Windows platform.')
-
+ pytest.skip("Skipping on Windows platform.")
def test_tee_command():
- mycli.packages.special.write_tee(u"hello world") # write without file set
+ mycli.packages.special.write_tee("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
- mycli.packages.special.execute(None, u"tee " + f.name)
- mycli.packages.special.write_tee(u"hello world")
- if os.name=='nt':
+ mycli.packages.special.execute(None, "tee " + f.name)
+ mycli.packages.special.write_tee("hello world")
+ if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
- mycli.packages.special.execute(None, u"tee -o " + f.name)
- mycli.packages.special.write_tee(u"hello world")
+ mycli.packages.special.execute(None, "tee -o " + f.name)
+ mycli.packages.special.write_tee("hello world")
f.seek(0)
- if os.name=='nt':
+ if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
- mycli.packages.special.execute(None, u"notee")
- mycli.packages.special.write_tee(u"hello world")
+ mycli.packages.special.execute(None, "notee")
+ mycli.packages.special.write_tee("hello world")
f.seek(0)
- if os.name=='nt':
+ if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
@@ -92,52 +91,49 @@ def test_tee_command():
os.remove(f.name)
except Exception as e:
print(f"An error occurred while attempting to delete the file: {e}")
-
def test_tee_command_error():
with pytest.raises(TypeError):
- mycli.packages.special.execute(None, 'tee')
+ mycli.packages.special.execute(None, "tee")
with pytest.raises(OSError):
with tempfile.NamedTemporaryFile() as f:
os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
- mycli.packages.special.execute(None, 'tee {}'.format(f.name))
+ mycli.packages.special.execute(None, "tee {}".format(f.name))
@dbtest
-
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query():
with db_connection().cursor() as cur:
- query = u'select "✔"'
- mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
- assert next(mycli.packages.special.execute(
- cur, u'\\f check'))[0] == "> " + query
+ query = 'select "✔"'
+ mycli.packages.special.execute(cur, "\\fs check {0}".format(query))
+ assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query
def test_once_command():
with pytest.raises(TypeError):
- mycli.packages.special.execute(None, u"\\once")
+ mycli.packages.special.execute(None, "\\once")
with pytest.raises(OSError):
- mycli.packages.special.execute(None, u"\\once /proc/access-denied")
+ mycli.packages.special.execute(None, "\\once /proc/access-denied")
- mycli.packages.special.write_once(u"hello world") # write without file set
+ mycli.packages.special.write_once("hello world") # write without file set
# keep Windows from locking the file with delete=False
with tempfile.NamedTemporaryFile(delete=False) as f:
- mycli.packages.special.execute(None, u"\\once " + f.name)
- mycli.packages.special.write_once(u"hello world")
- if os.name=='nt':
+ mycli.packages.special.execute(None, "\\once " + f.name)
+ mycli.packages.special.write_once("hello world")
+ if os.name == "nt":
assert f.read() == b"hello world\r\n"
else:
assert f.read() == b"hello world\n"
- mycli.packages.special.execute(None, u"\\once -o " + f.name)
- mycli.packages.special.write_once(u"hello world line 1")
- mycli.packages.special.write_once(u"hello world line 2")
+ mycli.packages.special.execute(None, "\\once -o " + f.name)
+ mycli.packages.special.write_once("hello world line 1")
+ mycli.packages.special.write_once("hello world line 2")
f.seek(0)
- if os.name=='nt':
+ if os.name == "nt":
assert f.read() == b"hello world line 1\r\nhello world line 2\r\n"
else:
assert f.read() == b"hello world line 1\nhello world line 2\n"
@@ -151,52 +147,47 @@ def test_once_command():
def test_pipe_once_command():
with pytest.raises(IOError):
- mycli.packages.special.execute(None, u"\\pipe_once")
+ mycli.packages.special.execute(None, "\\pipe_once")
with pytest.raises(OSError):
- mycli.packages.special.execute(
- None, u"\\pipe_once /proc/access-denied")
+ mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied")
- if os.name == 'nt':
- mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
- mycli.packages.special.write_once(u"hello world")
+ if os.name == "nt":
+ mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"')
+ mycli.packages.special.write_once("hello world")
mycli.packages.special.unset_pipe_once_if_written()
else:
- mycli.packages.special.execute(None, u"\\pipe_once wc")
- mycli.packages.special.write_once(u"hello world")
- mycli.packages.special.unset_pipe_once_if_written()
- # how to assert on wc output?
+ with tempfile.NamedTemporaryFile() as f:
+ mycli.packages.special.execute(None, "\\pipe_once tee " + f.name)
+ mycli.packages.special.write_pipe_once("hello world")
+ mycli.packages.special.unset_pipe_once_if_written()
+ f.seek(0)
+ assert f.read() == b"hello world\n"
def test_parseargfile():
"""Test that parseargfile expands the user directory."""
- expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
- 'mode': 'a'}
-
- if os.name=='nt':
- assert expected == mycli.packages.special.iocommands.parseargfile(
- '~\\filename')
+ expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"}
+
+ if os.name == "nt":
+ assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename")
else:
- assert expected == mycli.packages.special.iocommands.parseargfile(
- '~/filename')
-
- expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
- 'mode': 'w'}
- if os.name=='nt':
- assert expected == mycli.packages.special.iocommands.parseargfile(
- '-o ~\\filename')
+ assert expected == mycli.packages.special.iocommands.parseargfile("~/filename")
+
+ expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"}
+ if os.name == "nt":
+ assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename")
else:
- assert expected == mycli.packages.special.iocommands.parseargfile(
- '-o ~/filename')
+ assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename")
def test_parseargfile_no_file():
"""Test that parseargfile raises a TypeError if there is no filename."""
with pytest.raises(TypeError):
- mycli.packages.special.iocommands.parseargfile('')
+ mycli.packages.special.iocommands.parseargfile("")
with pytest.raises(TypeError):
- mycli.packages.special.iocommands.parseargfile('-o ')
+ mycli.packages.special.iocommands.parseargfile("-o ")
@dbtest
@@ -205,11 +196,9 @@ def test_watch_query_iteration():
the desired query and returns the given results."""
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
- expected_title = '> {0!s}'.format(query)
+ expected_title = "> {0!s}".format(query)
with db_connection().cursor() as cur:
- result = next(mycli.packages.special.iocommands.watch_query(
- arg=query, cur=cur
- ))
+ result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur))
assert result[0] == expected_title
assert result[2][0] == expected_value
@@ -230,14 +219,12 @@ def test_watch_query_full():
wait_interval = 1
expected_value = "1"
query = "SELECT {0!s}".format(expected_value)
- expected_title = '> {0!s}'.format(query)
+ expected_title = "> {0!s}".format(query)
expected_results = 4
ctrl_c_process = send_ctrl_c(wait_interval)
with db_connection().cursor() as cur:
results = list(
- result for result in mycli.packages.special.iocommands.watch_query(
- arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
- )
+ result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur)
)
ctrl_c_process.join(1)
assert len(results) == expected_results
@@ -247,14 +234,12 @@ def test_watch_query_full():
@dbtest
-@patch('click.clear')
+@patch("click.clear")
def test_watch_query_clear(clear_mock):
"""Test that the screen is cleared with the -c flag of `watch` command
before execute the query."""
with db_connection().cursor() as cur:
- watch_gen = mycli.packages.special.iocommands.watch_query(
- arg='0.1 -c select 1;', cur=cur
- )
+ watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur)
assert not clear_mock.called
next(watch_gen)
assert clear_mock.called
@@ -271,19 +256,20 @@ def test_watch_query_bad_arguments():
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
with pytest.raises(ProgrammingError):
- next(watch_query('a select 1;', cur=cur))
+ next(watch_query("a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
- next(watch_query('-a select 1;', cur=cur))
+ next(watch_query("-a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
- next(watch_query('1 -a select 1;', cur=cur))
+ next(watch_query("1 -a select 1;", cur=cur))
with pytest.raises(ProgrammingError):
- next(watch_query('-c -a select 1;', cur=cur))
+ next(watch_query("-c -a select 1;", cur=cur))
@dbtest
-@patch('click.clear')
+@patch("click.clear")
def test_watch_query_interval_clear(clear_mock):
"""Test `watch` command with interval and clear flag."""
+
def test_asserts(gen):
clear_mock.reset_mock()
start = time()
@@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock):
seconds = 1.0
watch_query = mycli.packages.special.iocommands.watch_query
with db_connection().cursor() as cur:
- test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
- cur=cur))
- test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
- cur=cur))
+ test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur))
+ test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur))
def test_split_sql_by_delimiter():
- for delimiter_str in (';', '$', '😀'):
+ for delimiter_str in (";", "$", "😀"):
mycli.packages.special.set_delimiter(delimiter_str)
sql_input = "select 1{} select \ufffc2".format(delimiter_str)
- queries = (
- "select 1",
- "select \ufffc2"
- )
- for query, parsed_query in zip(
- queries, mycli.packages.special.split_queries(sql_input)):
- assert(query == parsed_query)
+ queries = ("select 1", "select \ufffc2")
+ for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
+ assert query == parsed_query
def test_switch_delimiter_within_query():
- mycli.packages.special.set_delimiter(';')
+ mycli.packages.special.set_delimiter(";")
sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
- queries = (
- "select 1",
- "delimiter $$ select 2 $$ select 3 $$",
- "select 2",
- "select 3"
- )
- for query, parsed_query in zip(
- queries,
- mycli.packages.special.split_queries(sql_input)):
- assert(query == parsed_query)
+ queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3")
+ for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)):
+ assert query == parsed_query
def test_set_delimiter():
-
- for delim in ('foo', 'bar'):
+ for delim in ("foo", "bar"):
mycli.packages.special.set_delimiter(delim)
assert mycli.packages.special.get_current_delimiter() == delim
def teardown_function():
- mycli.packages.special.set_delimiter(';')
+ mycli.packages.special.set_delimiter(";")
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
index ca186bc..17e082b 100644
--- a/test/test_sqlexecute.py
+++ b/test/test_sqlexecute.py
@@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies
from .utils import run, dbtest, set_expanded_output, is_expanded_output
-def assert_result_equal(result, title=None, rows=None, headers=None,
- status=None, auto_status=True, assert_contains=False):
+def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False):
"""Assert that an sqlexecute.run() result matches the expected values."""
if status is None and auto_status and rows:
- status = '{} row{} in set'.format(
- len(rows), 's' if len(rows) > 1 else '')
- fields = {'title': title, 'rows': rows, 'headers': headers,
- 'status': status}
+ status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
+ fields = {"title": title, "rows": rows, "headers": headers, "status": status}
if assert_contains:
# Do a loose match on the results using the *in* operator.
@@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None,
@dbtest
def test_conn(executor):
- run(executor, '''create table test(a text)''')
- run(executor, '''insert into test values('abc')''')
- results = run(executor, '''select * from test''')
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+ results = run(executor, """select * from test""")
- assert_result_equal(results, headers=['a'], rows=[('abc',)])
+ assert_result_equal(results, headers=["a"], rows=[("abc",)])
@dbtest
def test_bools(executor):
- run(executor, '''create table test(a boolean)''')
- run(executor, '''insert into test values(True)''')
- results = run(executor, '''select * from test''')
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(True)""")
+ results = run(executor, """select * from test""")
- assert_result_equal(results, headers=['a'], rows=[(1,)])
+ assert_result_equal(results, headers=["a"], rows=[(1,)])
@dbtest
def test_binary(executor):
- run(executor, '''create table bt(geom linestring NOT NULL)''')
- run(executor, "INSERT INTO bt VALUES "
- "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
- results = run(executor, '''select * from bt''')
-
- geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
- b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
- b'\xac\xdeC@')
+ run(executor, """create table bt(geom linestring NOT NULL)""")
+ run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
+ results = run(executor, """select * from bt""")
+
+ geom = (
+ b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n"
+ b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9"
+ b"\xac\xdeC@"
+ )
- assert_result_equal(results, headers=['geom'], rows=[(geom,)])
+ assert_result_equal(results, headers=["geom"], rows=[(geom,)])
@dbtest
@@ -63,49 +61,48 @@ def test_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
- assert set(executor.tables()) == set([('a',), ('b',)])
- assert set(executor.table_columns()) == set(
- [('a', 'x'), ('a', 'y'), ('b', 'z')])
+ assert set(executor.tables()) == set([("a",), ("b",)])
+ assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
@dbtest
def test_database_list(executor):
databases = executor.databases()
- assert 'mycli_test_db' in databases
+ assert "mycli_test_db" in databases
@dbtest
def test_invalid_syntax(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
- run(executor, 'invalid syntax!')
- assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+ run(executor, "invalid syntax!")
+ assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
def test_invalid_column_name(executor):
with pytest.raises(pymysql.err.OperationalError) as excinfo:
- run(executor, 'select invalid command')
+ run(executor, "select invalid command")
assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
@dbtest
def test_unicode_support_in_output(executor):
run(executor, "create table unicodechars(t text)")
- run(executor, u"insert into unicodechars (t) values ('é')")
+ run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
- results = run(executor, u"select * from unicodechars")
- assert_result_equal(results, headers=['t'], rows=[(u'é',)])
+ results = run(executor, "select * from unicodechars")
+ assert_result_equal(results, headers=["t"], rows=[("é",)])
@dbtest
def test_multiple_queries_same_line(executor):
results = run(executor, "select 'foo'; select 'bar'")
- expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
- 'status': '1 row in set'},
- {'title': None, 'headers': ['bar'], 'rows': [('bar',)],
- 'status': '1 row in set'}]
+ expected = [
+ {"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"},
+ {"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"},
+ ]
assert expected == results
@@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor):
def test_multiple_queries_same_line_syntaxerror(executor):
with pytest.raises(pymysql.ProgrammingError) as excinfo:
run(executor, "select 'foo'; invalid syntax")
- assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+ assert "You have an error in your SQL syntax;" in str(excinfo.value)
@dbtest
@@ -125,15 +122,13 @@ def test_favorite_query(executor):
run(executor, "insert into test values('def')")
results = run(executor, "\\fs test-a select * from test where a like 'a%'")
- assert_result_equal(results, status='Saved.')
+ assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-a")
- assert_result_equal(results,
- title="> select * from test where a like 'a%'",
- headers=['a'], rows=[('abc',)], auto_status=False)
+ assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False)
results = run(executor, "\\fd test-a")
- assert_result_equal(results, status='test-a: Deleted')
+ assert_result_equal(results, status="test-a: Deleted")
@dbtest
@@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor):
run(executor, "insert into test values('abc')")
run(executor, "insert into test values('def')")
- results = run(executor,
- "\\fs test-ad select * from test where a like 'a%'; "
- "select * from test where a like 'd%'")
- assert_result_equal(results, status='Saved.')
+ results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'")
+ assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ad")
- expected = [{'title': "> select * from test where a like 'a%'",
- 'headers': ['a'], 'rows': [('abc',)], 'status': None},
- {'title': "> select * from test where a like 'd%'",
- 'headers': ['a'], 'rows': [('def',)], 'status': None}]
+ expected = [
+ {"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None},
+ {"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None},
+ ]
assert expected == results
results = run(executor, "\\fd test-ad")
- assert_result_equal(results, status='test-ad: Deleted')
+ assert_result_equal(results, status="test-ad: Deleted")
@dbtest
@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right")
def test_favorite_query_expanded_output(executor):
set_expanded_output(False)
- run(executor, '''create table test(a text)''')
- run(executor, '''insert into test values('abc')''')
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
results = run(executor, "\\fs test-ae select * from test")
- assert_result_equal(results, status='Saved.')
+ assert_result_equal(results, status="Saved.")
results = run(executor, "\\f test-ae \\G")
assert is_expanded_output() is True
- assert_result_equal(results, title='> select * from test',
- headers=['a'], rows=[('abc',)], auto_status=False)
+ assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False)
set_expanded_output(False)
results = run(executor, "\\fd test-ae")
- assert_result_equal(results, status='test-ae: Deleted')
+ assert_result_equal(results, status="test-ae: Deleted")
@dbtest
def test_special_command(executor):
- results = run(executor, '\\?')
- assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
- headers='Command', assert_contains=True,
- auto_status=False)
+ results = run(executor, "\\?")
+ assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False)
@dbtest
def test_cd_command_without_a_folder_name(executor):
- results = run(executor, 'system cd')
- assert_result_equal(results, status='No folder name was provided.')
+ results = run(executor, "system cd")
+ assert_result_equal(results, status="No folder name was provided.")
@dbtest
def test_system_command_not_found(executor):
- results = run(executor, 'system xyz')
- if os.name=='nt':
- assert_result_equal(results, status='OSError: The system cannot find the file specified',
- assert_contains=True)
+ results = run(executor, "system xyz")
+ if os.name == "nt":
+ assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True)
else:
- assert_result_equal(results, status='OSError: No such file or directory',
- assert_contains=True)
+ assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True)
@dbtest
def test_system_command_output(executor):
eol = os.linesep
test_dir = os.path.abspath(os.path.dirname(__file__))
- test_file_path = os.path.join(test_dir, 'test.txt')
- results = run(executor, 'system cat {0}'.format(test_file_path))
- assert_result_equal(results, status=f'mycli rocks!{eol}')
+ test_file_path = os.path.join(test_dir, "test.txt")
+ results = run(executor, "system cat {0}".format(test_file_path))
+ assert_result_equal(results, status=f"mycli rocks!{eol}")
@dbtest
def test_cd_command_current_dir(executor):
test_path = os.path.abspath(os.path.dirname(__file__))
- run(executor, 'system cd {0}'.format(test_path))
+ run(executor, "system cd {0}".format(test_path))
assert os.getcwd() == test_path
@dbtest
def test_unicode_support(executor):
- results = run(executor, u"SELECT '日本語' AS japanese;")
- assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
+ results = run(executor, "SELECT '日本語' AS japanese;")
+ assert_result_equal(results, headers=["japanese"], rows=[("日本語",)])
@dbtest
def test_timestamp_null(executor):
- run(executor, '''create table ts_null(a timestamp null)''')
- run(executor, '''insert into ts_null values(null)''')
- results = run(executor, '''select * from ts_null''')
- assert_result_equal(results, headers=['a'],
- rows=[(None,)])
+ run(executor, """create table ts_null(a timestamp null)""")
+ run(executor, """insert into ts_null values(null)""")
+ results = run(executor, """select * from ts_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_datetime_null(executor):
- run(executor, '''create table dt_null(a datetime null)''')
- run(executor, '''insert into dt_null values(null)''')
- results = run(executor, '''select * from dt_null''')
- assert_result_equal(results, headers=['a'],
- rows=[(None,)])
+ run(executor, """create table dt_null(a datetime null)""")
+ run(executor, """insert into dt_null values(null)""")
+ results = run(executor, """select * from dt_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_date_null(executor):
- run(executor, '''create table date_null(a date null)''')
- run(executor, '''insert into date_null values(null)''')
- results = run(executor, '''select * from date_null''')
- assert_result_equal(results, headers=['a'], rows=[(None,)])
+ run(executor, """create table date_null(a date null)""")
+ run(executor, """insert into date_null values(null)""")
+ results = run(executor, """select * from date_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_time_null(executor):
- run(executor, '''create table time_null(a time null)''')
- run(executor, '''insert into time_null values(null)''')
- results = run(executor, '''select * from time_null''')
- assert_result_equal(results, headers=['a'], rows=[(None,)])
+ run(executor, """create table time_null(a time null)""")
+ run(executor, """insert into time_null values(null)""")
+ results = run(executor, """select * from time_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
@dbtest
def test_multiple_results(executor):
- query = '''CREATE PROCEDURE dmtest()
+ query = """CREATE PROCEDURE dmtest()
BEGIN
SELECT 1;
SELECT 2;
- END'''
+ END"""
executor.conn.cursor().execute(query)
- results = run(executor, 'call dmtest;')
+ results = run(executor, "call dmtest;")
expected = [
- {'title': None, 'rows': [(1,)], 'headers': ['1'],
- 'status': '1 row in set'},
- {'title': None, 'rows': [(2,)], 'headers': ['2'],
- 'status': '1 row in set'}
+ {"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"},
+ {"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"},
]
assert results == expected
@pytest.mark.parametrize(
- 'version_string, species, parsed_version_string, version',
+ "version_string, species, parsed_version_string, version",
(
- ('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100),
- ('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200),
- ('5.7.32-35', 'Percona', '5.7.32', 50732),
- ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
- ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
- ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
- ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
- ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
- ('unexpected version string', None, '', 0),
- ('', None, '', 0),
- (None, None, '', 0),
- )
+ ("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100),
+ ("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200),
+ ("5.7.32-35", "Percona", "5.7.32", 50732),
+ ("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732),
+ ("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
+ ("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508),
+ ("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016),
+ ("5.1.5a-alpha", "MySQL", "5.1.5", 50105),
+ ("unexpected version string", None, "", 0),
+ ("", None, "", 0),
+ (None, None, "", 0),
+ ),
)
def test_version_parsing(version_string, species, parsed_version_string, version):
server_info = ServerInfo.from_version_string(version_string)
diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py
index bdc1dbf..45e97af 100644
--- a/test/test_tabular_output.py
+++ b/test/test_tabular_output.py
@@ -2,8 +2,6 @@
from textwrap import dedent
-from mycli.packages.tabular_output import sql_format
-from cli_helpers.tabular_output import TabularOutputFormatter
from .utils import USER, PASSWORD, HOST, PORT, dbtest
@@ -23,20 +21,17 @@ def mycli():
@dbtest
def test_sql_output(mycli):
"""Test the sql output adapter."""
- headers = ['letters', 'number', 'optional', 'float', 'binary']
+ headers = ["letters", "number", "optional", "float", "binary"]
class FakeCursor(object):
def __init__(self):
- self.data = [
- ('abc', 1, None, 10.0, b'\xAA'),
- ('d', 456, '1', 0.5, b'\xAA\xBB')
- ]
+ self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")]
self.description = [
(None, FIELD_TYPE.VARCHAR),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.LONG),
(None, FIELD_TYPE.FLOAT),
- (None, FIELD_TYPE.BLOB)
+ (None, FIELD_TYPE.BLOB),
]
def __iter__(self):
@@ -52,12 +47,11 @@ def test_sql_output(mycli):
return self.description
# Test sql-update output format
- assert list(mycli.change_table_format("sql-update")) == \
- [(None, None, None, 'Changed table format to sql-update')]
+ assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
actual = "\n".join(output)
- assert actual == dedent('''\
+ assert actual == dedent("""\
UPDATE `DUAL` SET
`number` = 1
, `optional` = NULL
@@ -69,13 +63,12 @@ def test_sql_output(mycli):
, `optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
- WHERE `letters` = 'd';''')
+ WHERE `letters` = 'd';""")
# Test sql-update-2 output format
- assert list(mycli.change_table_format("sql-update-2")) == \
- [(None, None, None, 'Changed table format to sql-update-2')]
+ assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
- assert "\n".join(output) == dedent('''\
+ assert "\n".join(output) == dedent("""\
UPDATE `DUAL` SET
`optional` = NULL
, `float` = 10.0e0
@@ -85,34 +78,31 @@ def test_sql_output(mycli):
`optional` = '1'
, `float` = 0.5e0
, `binary` = X'aabb'
- WHERE `letters` = 'd' AND `number` = 456;''')
+ WHERE `letters` = 'd' AND `number` = 456;""")
# Test sql-insert output format (without table name)
- assert list(mycli.change_table_format("sql-insert")) == \
- [(None, None, None, 'Changed table format to sql-insert')]
+ assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = ""
output = mycli.format_output(None, FakeCursor(), headers)
- assert "\n".join(output) == dedent('''\
+ assert "\n".join(output) == dedent("""\
INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
- ;''')
+ ;""")
# Test sql-insert output format (with table name)
- assert list(mycli.change_table_format("sql-insert")) == \
- [(None, None, None, 'Changed table format to sql-insert')]
+ assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `table`"
output = mycli.format_output(None, FakeCursor(), headers)
- assert "\n".join(output) == dedent('''\
+ assert "\n".join(output) == dedent("""\
INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
- ;''')
+ ;""")
# Test sql-insert output format (with database + table name)
- assert list(mycli.change_table_format("sql-insert")) == \
- [(None, None, None, 'Changed table format to sql-insert')]
+ assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")]
mycli.formatter.query = "SELECT * FROM `database`.`table`"
output = mycli.format_output(None, FakeCursor(), headers)
- assert "\n".join(output) == dedent('''\
+ assert "\n".join(output) == dedent("""\
INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES
('abc', 1, NULL, 10.0e0, X'aa')
, ('d', 456, '1', 0.5e0, X'aabb')
- ;''')
+ ;""")
diff --git a/test/utils.py b/test/utils.py
index ab12248..383f502 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -9,20 +9,18 @@ import pytest
from mycli.main import special
-PASSWORD = os.getenv('PYTEST_PASSWORD')
-USER = os.getenv('PYTEST_USER', 'root')
-HOST = os.getenv('PYTEST_HOST', 'localhost')
-PORT = int(os.getenv('PYTEST_PORT', 3306))
-CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
-SSH_USER = os.getenv('PYTEST_SSH_USER', None)
-SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
-SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
+PASSWORD = os.getenv("PYTEST_PASSWORD")
+USER = os.getenv("PYTEST_USER", "root")
+HOST = os.getenv("PYTEST_HOST", "localhost")
+PORT = int(os.getenv("PYTEST_PORT", 3306))
+CHARSET = os.getenv("PYTEST_CHARSET", "utf8")
+SSH_USER = os.getenv("PYTEST_SSH_USER", None)
+SSH_HOST = os.getenv("PYTEST_SSH_HOST", None)
+SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22)
def db_connection(dbname=None):
- conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
- password=PASSWORD, charset=CHARSET,
- local_infile=False)
+ conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False)
conn.autocommit = True
return conn
@@ -30,20 +28,18 @@ def db_connection(dbname=None):
try:
db_connection()
CAN_CONNECT_TO_DB = True
-except:
+except Exception:
CAN_CONNECT_TO_DB = False
-dbtest = pytest.mark.skipif(
- not CAN_CONNECT_TO_DB,
- reason="Need a mysql instance at localhost accessible by user 'root'")
+dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'")
def create_db(dbname):
with db_connection().cursor() as cur:
try:
- cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
- cur.execute('''CREATE DATABASE mycli_test_db''')
- except:
+ cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""")
+ cur.execute("""CREATE DATABASE mycli_test_db""")
+ except Exception:
pass
@@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True):
for title, rows, headers, status in executor.run(sql):
rows = list(rows) if (rows_as_list and rows) else rows
- result.append({'title': title, 'rows': rows, 'headers': headers,
- 'status': status})
+ result.append({"title": title, "rows": rows, "headers": headers, "status": status})
return result
@@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds):
Returns the `multiprocessing.Process` created.
"""
- ctrl_c_process = multiprocessing.Process(
- target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
- )
+ ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds))
ctrl_c_process.start()
return ctrl_c_process