diff options
Diffstat (limited to 'test')
46 files changed, 4366 insertions, 0 deletions
diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/__init__.py diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..1325596 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,29 @@ +import pytest +from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, + db_connection, SSH_USER, SSH_HOST, SSH_PORT) +import mycli.sqlexecute + + +@pytest.fixture(scope="function") +def connection(): + create_db('mycli_test_db') + connection = db_connection('mycli_test_db') + yield connection + + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return mycli.sqlexecute.SQLExecute( + database='mycli_test_db', user=USER, + host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, + local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, + ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + ) diff --git a/test/features/__init__.py b/test/features/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/features/__init__.py diff --git a/test/features/auto_vertical.feature b/test/features/auto_vertical.feature new file mode 100644 index 0000000..aa95718 --- /dev/null +++ b/test/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature new file mode 100644 index 0000000..a12e899 --- /dev/null +++ b/test/features/basic_commands.feature @@ -0,0 +1,19 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 0000000..b06935e --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,35 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature new file mode 100644 index 0000000..f4a7a7f --- /dev/null +++ b/test/features/crud_database.feature @@ -0,0 +1,30 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we create database + then we see database created + when we drop database + then we confirm the destructive warning + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from quoted test database + When we connect to quoted test database + then we see database connected + + Scenario: create and drop default database + When we create database + then we see database created + when we connect to tmp database + then we see database connected + when we drop database + then we confirm the destructive warning + then we see database dropped and no default database diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature new file mode 100644 index 0000000..3384efd --- /dev/null +++ b/test/features/crud_table.feature @@ -0,0 +1,49 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we update table + then we see record updated + when we select from table + then we see data selected + when we delete from table + then we confirm the destructive warning + then we see record deleted + when we drop table + then we confirm the destructive warning + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: select null values + When we connect to test database + then we see database connected + when we select null + then we see null selected + + Scenario: confirm destructive query + When we query "create table foo(x integer);" + and we query "delete from foo;" + and we answer the destructive warning with "y" + then we see text "Your call!" + + Scenario: decline destructive query + When we query "delete from foo;" + and we answer the destructive warning with "n" + then we see text "Wise choice!" + + Scenario: no destructive warning if disabled in config + When we run dbcli with --no-warn + and we query "create table blabla(x integer);" + and we query "delete from blabla;" + Then we see text "Query OK" + + Scenario: confirm destructive query with invalid response + When we query "delete from foo;" + then we answer the destructive warning with invalid "1" and see text "is not a valid boolean" diff --git a/test/features/db_utils.py b/test/features/db_utils.py new file mode 100644 index 0000000..be550e9 --- /dev/null +++ b/test/features/db_utils.py @@ -0,0 +1,93 @@ +import pymysql + + +def create_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): + """Create test database. + + :param hostname: string + :param port: int + :param username: string + :param password: string + :param dbname: string + :return: + + """ + cn = pymysql.connect( + host=hostname, + port=port, + user=username, + password=password, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + with cn.cursor() as cr: + cr.execute('drop database if exists ' + dbname) + cr.execute('create database ' + dbname) + + cn.close() + + cn = create_cn(hostname, port, password, username, dbname) + return cn + + +def create_cn(hostname, port, password, username, dbname): + """Open connection to database. + + :param hostname: + :param port: + :param password: + :param username: + :param dbname: string + :return: psycopg2.connection + + """ + cn = pymysql.connect( + host=hostname, + port=port, + user=username, + password=password, + db=dbname, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + return cn + + +def drop_db(hostname='localhost', port=3306, username=None, + password=None, dbname=None): + """Drop database. + + :param hostname: string + :param port: int + :param username: string + :param password: string + :param dbname: string + + """ + cn = pymysql.connect( + host=hostname, + port=port, + user=username, + password=password, + db=dbname, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + with cn.cursor() as cr: + cr.execute('drop database if exists ' + dbname) + + close_cn(cn) + + +def close_cn(cn=None): + """Close connection. + + :param connection: pymysql.connection + + """ + if cn: + cn.close() diff --git a/test/features/environment.py b/test/features/environment.py new file mode 100644 index 0000000..1ea0f08 --- /dev/null +++ b/test/features/environment.py @@ -0,0 +1,176 @@ +import os +import shutil +import sys +from tempfile import mkstemp + +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect + +from steps.wrappers import run_cli, wait_prompt + +test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') + + +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + + +def before_all(context): + """Set env parameters.""" + os.environ['LINES'] = "100" + os.environ['COLUMNS'] = "100" + os.environ['EDITOR'] = 'ex' + os.environ['LC_ALL'] = 'en_US.UTF-8' + os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' + os.environ['MYCLI_HISTFILE'] = os.devnull + + test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + login_path_file = os.path.join(test_dir, 'mylogin.cnf') +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, + '.coveragerc') + + context.exit_sent = False + + vi = '_'.join([str(x) for x in sys.version_info[:3]]) + db_name = get_db_name_from_context(context) + db_name_full = '{0}_{1}'.format(db_name, vi) + + # Store get params from config/environment variables + context.conf = { + 'host': context.config.userdata.get( + 'my_test_host', + os.getenv('PYTEST_HOST', 'localhost') + ), + 'port': context.config.userdata.get( + 'my_test_port', + int(os.getenv('PYTEST_PORT', '3306')) + ), + 'user': context.config.userdata.get( + 'my_test_user', + os.getenv('PYTEST_USER', 'root') + ), + 'pass': context.config.userdata.get( + 'my_test_pass', + os.getenv('PYTEST_PASSWORD', None) + ), + 'cli_command': context.config.userdata.get( + 'my_cli_command', None) or + sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + 'dbname': db_name, + 'dbname_tmp': db_name_full + '_tmp', + 'vi': vi, + 'pager_boundary': '---boundary---', + } + + _, my_cnf = mkstemp() + with open(my_cnf, 'w') as f: + f.write( + '[client]\n' + 'pager={0} {1} {2}\n'.format( + sys.executable, os.path.join(context.package_root, + 'test/features/wrappager.py'), + context.conf['pager_boundary']) + ) + context.conf['defaults-file'] = my_cnf + context.conf['myclirc'] = os.path.join(context.package_root, 'test', + 'myclirc') + + context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], + context.conf['user'], + context.conf['pass'], + context.conf['dbname']) + + context.fixture_data = fixutils.read_fixture_files() + + +def after_all(context): + """Unset env parameters.""" + dbutils.close_cn(context.cn) + dbutils.drop_db(context.conf['host'], context.conf['port'], + context.conf['user'], context.conf['pass'], + context.conf['dbname']) + + # Restore env vars. + #for k, v in context.pgenv.items(): + # if k in os.environ and v is None: + # del os.environ[k] + # elif v: + # os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def before_scenario(context, arg): + with open(test_log_file, 'w') as f: + f.write('') + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) + + +def after_scenario(context, _): + """Cleans up after each test complete.""" + with open(test_log_file) as f: + for line in f: + if 'error' in line.lower(): + raise RuntimeError(f'Error in log file: {line}') + + if hasattr(context, 'cli') and not context.exit_sent: + # Quit nicely. + if not context.atprompt: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + context.cli.expect_exact( + '{0}@{1}:{2}>'.format( + user, host, dbname + ), + timeout=5 + ) + context.cli.sendcontrol('c') + context.cli.sendcontrol('d') + context.cli.expect_exact(pexpect.EOF, timeout=5) + + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. + # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + +# TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import ipdb; ipdb.set_trace() diff --git a/test/features/fixture_data/help.txt b/test/features/fixture_data/help.txt new file mode 100644 index 0000000..deb499a --- /dev/null +++ b/test/features/fixture_data/help.txt @@ -0,0 +1,24 @@ ++--------------------------+-----------------------------------------------+ +| Command | Description | +|--------------------------+-----------------------------------------------| +| \# | Refresh auto-completions. | +| \? | Show Help. | +| \c[onnect] database_name | Change to a new database. | +| \d [pattern] | List or describe tables, views and sequences. | +| \dT[S+] [pattern] | List data types | +| \df[+] [pattern] | List functions. | +| \di[+] [pattern] | List indexes. | +| \dn[+] [pattern] | List schemas. | +| \ds[+] [pattern] | List sequences. | +| \dt[+] [pattern] | List tables. | +| \du[+] [pattern] | List roles. | +| \dv[+] [pattern] | List views. | +| \e [file] | Edit the query with external editor. | +| \l | List databases. | +| \n[+] [name] | List or execute named queries. | +| \nd [name [query]] | Delete a named query. | +| \ns name query | Save a named query. | +| \refresh | Refresh auto-completions. | +| \timing | Toggle timing of commands. | +| \x | Toggle expanded output. | ++--------------------------+-----------------------------------------------+ diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt new file mode 100644 index 0000000..2c06d5d --- /dev/null +++ b/test/features/fixture_data/help_commands.txt @@ -0,0 +1,31 @@ ++-------------+----------------------------+------------------------------------------------------------+
+| Command | Shortcut | Description |
++-------------+----------------------------+------------------------------------------------------------+
+| \G | \G | Display current query results vertically. |
+| \clip | \clip | Copy query to the system clipboard. |
+| \dt | \dt[+] [table] | List or describe tables. |
+| \e | \e | Edit command with editor (uses $EDITOR). |
+| \f | \f [name [args..]] | List or execute favorite queries. |
+| \fd | \fd [name] | Delete a favorite query. |
+| \fs | \fs name query | Save a favorite query. |
+| \l | \l | List databases. |
+| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
+| \pipe_once | \| command | Send next result to a subprocess. |
+| \timing | \t | Toggle timing of commands. |
+| connect | \r | Reconnect to the database. Optional database argument. |
+| exit | \q | Exit. |
+| help | \? | Show this help. |
+| nopager | \n | Disable pager, print to stdout. |
+| notee | notee | Stop writing results to an output file. |
+| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
+| prompt | \R | Change prompt format. |
+| quit | \q | Quit. |
+| rehash | \# | Refresh auto-completions. |
+| source | \. filename | Execute commands from file. |
+| status | \s | Get status information from the server. |
+| system | system [command] | Execute a system shell commmand. |
+| tableformat | \T | Change the table format used to output results. |
+| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
+| use | \u | Change to a new database. |
+| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
++-------------+----------------------------+------------------------------------------------------------+
diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py new file mode 100644 index 0000000..f85e0f6 --- /dev/null +++ b/test/features/fixture_utils.py @@ -0,0 +1,29 @@ +import os +import io + + +def read_fixture_lines(filename): + """Read lines of text from file. + + :param filename: string name + :return: list of strings + + """ + lines = [] + for line in open(filename): + lines.append(line.strip()) + return lines + + +def read_fixture_files(): + """Read all files inside fixture_data directory.""" + fixture_dict = {} + + current_dir = os.path.dirname(__file__) + fixture_dir = os.path.join(current_dir, 'fixture_data/') + for filename in os.listdir(fixture_dir): + if filename not in ['.', '..']: + fullname = os.path.join(fixture_dir, filename) + fixture_dict[filename] = read_fixture_lines(fullname) + + return fixture_dict diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature new file mode 100644 index 0000000..95366eb --- /dev/null +++ b/test/features/iocommands.feature @@ -0,0 +1,47 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type "select * from abc" in the editor + and we exit the editor + then we see dbcli prompt + and we see "select * from abc" in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we select "select 123456" + and we wait for prompt + and we notee output + and we wait for prompt + then we see 123456 in tee output + + Scenario: set delimiter + When we query "delimiter $" + then delimiter is set to "$" + + Scenario: set delimiter twice + When we query "delimiter $" + and we query "delimiter ]]" + then delimiter is set to "]]" + + Scenario: set delimiter and query on same line + When we query "select 123; delimiter $ select 456 $ delimiter %" + then we see result "123" + and we see result "456" + and delimiter is set to "%" + + Scenario: send output to file + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "system cat /tmp/output1.sql" + then we see result "123" + + Scenario: send output to file two times + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "\o /tmp/output2.sql" + and we query "select 456" + and we query "system cat /tmp/output2.sql" + then we see result "456" +
\ No newline at end of file diff --git a/test/features/named_queries.feature b/test/features/named_queries.feature new file mode 100644 index 0000000..5e681ec --- /dev/null +++ b/test/features/named_queries.feature @@ -0,0 +1,24 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we use a named query + then we see the named query executed + when we delete a named query + then we see the named query deleted + + Scenario: save, use and delete named queries with parameters + When we connect to test database + then we see database connected + when we save a named query with parameters + then we see the named query saved + when we use named query with parameters + then we see the named query with parameters executed + when we use named query with too few parameters + then we see the named query with parameters fail with missing parameters + when we use named query with too many parameters + then we see the named query with parameters fail with extra parameters diff --git a/test/features/specials.feature b/test/features/specials.feature new file mode 100644 index 0000000..bb36757 --- /dev/null +++ b/test/features/specials.feature @@ -0,0 +1,7 @@ +Feature: Special commands + + @wip + Scenario: run refresh command + When we refresh completions + and we wait for prompt + then we see completions refresh started diff --git a/test/features/steps/__init__.py b/test/features/steps/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/features/steps/__init__.py diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py new file mode 100644 index 0000000..e1cb26f --- /dev/null +++ b/test/features/steps/auto_vertical.py @@ -0,0 +1,46 @@ +from textwrap import dedent + +from behave import then, when + +import wrappers +from utils import parse_cli_args_to_dict + + +@when('we run dbcli with {arg}') +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) + + +@when('we execute a small query') +def step_execute_small_query(context): + context.cli.sendline('select 1') + + +@when('we execute a large query') +def step_execute_large_query(context): + context.cli.sendline( + 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + + +@then('we see small results in horizontal format') +def step_see_small_results(context): + wrappers.expect_pager(context, dedent("""\ + +---+\r + | 1 |\r + +---+\r + | 1 |\r + +---+\r + \r + """), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see large results in vertical format') +def step_see_large_results(context): + rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] + expected = ('***************************[ 1. row ]' + '***************************\r\n' + + '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + + wrappers.expect_pager(context, expected, timeout=10) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py new file mode 100644 index 0000000..425ef67 --- /dev/null +++ b/test/features/steps/basic_commands.py @@ -0,0 +1,100 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when +from textwrap import dedent +import tempfile +import wrappers + + +@when('we run dbcli') +def step_run_cli(context): + wrappers.run_cli(context) + + +@when('we wait for prompt') +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """Send Ctrl + D to hopefully exit.""" + context.cli.sendcontrol('d') + context.exit_sent = True + + +@when('we send "\?" command') +def step_send_help(context): + """Send \? + + to see help. + + """ + context.cli.sendline('\\?') + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we send source command') +def step_send_source_command(context): + with tempfile.NamedTemporaryFile() as f: + f.write(b'\?') + f.flush() + context.cli.sendline('\. {0}'.format(f.name)) + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we run query to check application_name') +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" + ) + + +@then(u'we see found') +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf['pager_boundary'] + '\r' + dedent(''' + +-------+\r + | found |\r + +-------+\r + | found |\r + +-------+\r + \r + ''') + context.conf['pager_boundary'], + timeout=5 + ) + + +@then(u'we confirm the destructive warning') +def step_confirm_destructive_command(context): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline('y') + + +@when(u'we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + + +@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + wrappers.expect_exact(context, text, timeout=2) + # we must exit the Click loop, or the feature will hang + context.cli.sendline('n') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 0000000..e16dd86 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,71 @@ +import io +import os +import shlex + +from behave import when, then +import pexpect + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. + # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py new file mode 100644 index 0000000..841f37d --- /dev/null +++ b/test/features/steps/crud_database.py @@ -0,0 +1,115 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import pexpect + +import wrappers +from behave import when, then + + +@when('we create database') +def step_db_create(context): + """Send create database.""" + context.cli.sendline('create database {0};'.format( + context.conf['dbname_tmp'])) + + context.response = { + 'database_name': context.conf['dbname_tmp'] + } + + +@when('we drop database') +def step_db_drop(context): + """Send drop database.""" + context.cli.sendline('drop database {0};'.format( + context.conf['dbname_tmp'])) + + +@when('we connect to test database') +def step_db_connect_test(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use {0};'.format(db_name)) + + +@when('we connect to quoted test database') +def step_db_connect_quoted_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use `{0}`;'.format(db_name)) + + +@when('we connect to tmp database') +def step_db_connect_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname_tmp'] + context.currentdb = db_name + context.cli.sendline('use {0}'.format(db_name)) + + +@when('we connect to dbserver') +def step_db_connect_dbserver(context): + """Send connect to database.""" + context.currentdb = 'mysql' + context.cli.sendline('use mysql') + + +@then('dbcli exits') +def step_wait_exit(context): + """Make sure the cli exits.""" + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then('we see dbcli prompt') +def step_see_prompt(context): + """Wait to see the prompt.""" + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) + + +@then('we see help output') +def step_see_help(context): + for expected_line in context.fixture_data['help_commands.txt']: + wrappers.expect_exact(context, expected_line, timeout=1) + + +@then('we see database created') +def step_see_db_created(context): + """Wait to see create database output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see database dropped') +def step_see_db_dropped(context): + """Wait to see drop database output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see database dropped and no default database') +def step_see_db_dropped_no_default(context): + """Wait to see drop database output.""" + user = context.conf['user'] + host = context.conf['host'] + database = '(none)' + context.currentdb = None + + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) + + +@then('we see database connected') +def step_see_db_connected(context): + """Wait to see drop database output.""" + wrappers.expect_exact( + context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, '"', timeout=2) + wrappers.expect_exact(context, ' as user "{0}"'.format( + context.conf['user']), timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py new file mode 100644 index 0000000..f715f0c --- /dev/null +++ b/test/features/steps/crud_table.py @@ -0,0 +1,112 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then +from textwrap import dedent + + +@when('we create table') +def step_create_table(context): + """Send create table.""" + context.cli.sendline('create table a(x text);') + + +@when('we insert into table') +def step_insert_into_table(context): + """Send insert into table.""" + context.cli.sendline('''insert into a(x) values('xxx');''') + + +@when('we update table') +def step_update_table(context): + """Send insert into table.""" + context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + + +@when('we select from table') +def step_select_from_table(context): + """Send select from table.""" + context.cli.sendline('select * from a;') + + +@when('we delete from table') +def step_delete_from_table(context): + """Send deete from table.""" + context.cli.sendline('''delete from a where x = 'yyy';''') + + +@when('we drop table') +def step_drop_table(context): + """Send drop table.""" + context.cli.sendline('drop table a;') + + +@then('we see table created') +def step_see_table_created(context): + """Wait to see create table output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see record inserted') +def step_see_record_inserted(context): + """Wait to see insert output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see record updated') +def step_see_record_updated(context): + """Wait to see update output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see data selected') +def step_see_data_selected(context): + """Wait to see select output.""" + wrappers.expect_pager( + context, dedent("""\ + +-----+\r + | x |\r + +-----+\r + | yyy |\r + +-----+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see record deleted') +def step_see_data_deleted(context): + """Wait to see delete output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see table dropped') +def step_see_table_dropped(context): + """Wait to see drop output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@when('we select null') +def step_select_null(context): + """Send select null.""" + context.cli.sendline('select null;') + + +@then('we see null selected') +def step_see_null_selected(context): + """Wait to see null output.""" + wrappers.expect_pager( + context, dedent("""\ + +--------+\r + | NULL |\r + +--------+\r + | <null> |\r + +--------+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py new file mode 100644 index 0000000..bbabf43 --- /dev/null +++ b/test/features/steps/iocommands.py @@ -0,0 +1,105 @@ +import os +import wrappers + +from behave import when, then +from textwrap import dedent + + +@when('we start external editor providing a file name') +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline('\e {0}'.format( + os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we type "{query}" in the editor') +def step_edit_type_sql(context, query): + context.cli.sendline('i') + context.cli.sendline(query) + context.cli.sendline('.') + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we exit the editor') +def step_edit_quit(context): + context.cli.sendline('x') + wrappers.expect_exact(context, "written", timeout=2) + + +@then('we see "{query}" in prompt') +def step_edit_done_sql(context, query): + for match in query.split(' '): + wrappers.expect_exact(context, match, timeout=5) + # Cleanup the command line. + context.cli.sendcontrol('c') + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + + +@when(u'we tee output') +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline('tee {0}'.format( + os.path.basename(context.tee_file_name))) + + +@when(u'we select "select {param}"') +def step_query_select_number(context, param): + context.cli.sendline(u'select {}'.format(param)) + wrappers.expect_pager(context, dedent(u"""\ + +{dashes}+\r + | {param} |\r + +{dashes}+\r + | {param} |\r + +{dashes}+\r + \r + """.format(param=param, dashes='-' * (len(param) + 2)) + ), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then(u'we see result "{result}"') +def step_see_result(context, result): + wrappers.expect_exact( + context, + u"| {} |".format(result), + timeout=2 + ) + + +@when(u'we query "{query}"') +def step_query(context, query): + context.cli.sendline(query) + + +@when(u'we notee output') +def step_notee_output(context): + context.cli.sendline('notee') + + +@then(u'we see 123456 in tee output') +def step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert '123456' in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + + +@then(u'delimiter is set to "{delimiter}"') +def delimiter_is_set(context, delimiter): + wrappers.expect_exact( + context, + u'Changed delimiter to {}'.format(delimiter), + timeout=2 + ) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py new file mode 100644 index 0000000..bc1f866 --- /dev/null +++ b/test/features/steps/named_queries.py @@ -0,0 +1,90 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then + + +@when('we save a named query') +def step_save_named_query(context): + """Send \fs command.""" + context.cli.sendline('\\fs foo SELECT 12345') + + +@when('we use a named query') +def step_use_named_query(context): + """Send \f command.""" + context.cli.sendline('\\f foo') + + +@when('we delete a named query') +def step_delete_named_query(context): + """Send \fd command.""" + context.cli.sendline('\\fd foo') + + +@then('we see the named query saved') +def step_see_named_query_saved(context): + """Wait to see query saved.""" + wrappers.expect_exact(context, 'Saved.', timeout=2) + + +@then('we see the named query executed') +def step_see_named_query_executed(context): + """Wait to see select output.""" + wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + + +@then('we see the named query deleted') +def step_see_named_query_deleted(context): + """Wait to see query deleted.""" + wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + + +@when('we save a named query with parameters') +def step_save_named_query_with_parameters(context): + """Send \fs command for query with parameters.""" + context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') + + +@when('we use named query with parameters') +def step_use_named_query_with_parameters(context): + """Send \f command with parameters.""" + context.cli.sendline('\\f foo_args 101 second "third value"') + + +@then('we see the named query with parameters executed') +def step_see_named_query_with_parameters_executed(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'SELECT 101, "second", "third value"', timeout=2) + + +@when('we use named query with too few parameters') +def step_use_named_query_with_too_few_parameters(context): + """Send \f command with missing parameters.""" + context.cli.sendline('\\f foo_args 101') + + +@then('we see the named query with parameters fail with missing parameters') +def step_see_named_query_with_parameters_fail_with_missing_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'missing substitution for $2 in query:', timeout=2) + + +@when('we use named query with too many parameters') +def step_use_named_query_with_too_many_parameters(context): + """Send \f command with extra parameters.""" + context.cli.sendline('\\f foo_args 101 102 103 104') + + +@then('we see the named query with parameters fail with extra parameters') +def step_see_named_query_with_parameters_fail_with_extra_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'query does not have substitution parameter $4:', timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py new file mode 100644 index 0000000..e8b99e3 --- /dev/null +++ b/test/features/steps/specials.py @@ -0,0 +1,27 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then + + +@when('we refresh completions') +def step_refresh_completions(context): + """Send refresh command.""" + context.cli.sendline('rehash') + + +@then('we see text "{text}"') +def step_see_text(context, text): + """Wait to see given text message.""" + wrappers.expect_exact(context, text, timeout=2) + +@then('we see completions refresh started') +def step_see_refresh_started(context): + """Wait to see refresh output.""" + wrappers.expect_exact( + context, 'Auto-completion refresh started in the background.', timeout=2) diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 0000000..1ae63d2 --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py new file mode 100644 index 0000000..6408f23 --- /dev/null +++ b/test/features/steps/wrappers.py @@ -0,0 +1,117 @@ +import re +import pexpect +import sys +import textwrap + + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. + actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', + '', context.cli.before) + raise Exception( + textwrap.dedent('''\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + ''').format( + expected, + actual, + context.logfile.getvalue() + ) + ) + + +def expect_pager(context, expected, timeout): + expect_exact(context, "{0}\r\n{1}{0}\r\n".format( + context.conf['pager_boundary'], expected), timeout=timeout) + + +def run_cli(context, run_args=None, exclude_args=None): + """Run the process using pexpect.""" + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + + try: + cli_cmd = context.conf['cli_command'] + except KeyError: + cli_cmd = ( + '{0!s} -c "' + 'import coverage ; ' + 'coverage.process_startup(); ' + 'import mycli.main; ' + 'mycli.main.cli()' + '"' + ).format(sys.executable) + + cmd_parts = [cli_cmd] + rendered_args + cmd = ' '.join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = context.conf['dbname'] + + +def wait_prompt(context, prompt=None): + """Make sure prompt is displayed.""" + if prompt is None: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + prompt = '{0}@{1}:{2}>'.format( + user, host, dbname), + expect_exact(context, prompt, timeout=5) + context.atprompt = True diff --git a/test/features/wrappager.py b/test/features/wrappager.py new file mode 100755 index 0000000..51d4909 --- /dev/null +++ b/test/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/test/myclirc b/test/myclirc new file mode 100644 index 0000000..261bee6 --- /dev/null +++ b/test/myclirc @@ -0,0 +1,12 @@ +# vi: ft=dosini + +# This file is loaded after mycli/myclirc and should override only those +# variables needed for testing. +# To see what every variable does see mycli/myclirc + +[main] + +log_file = ~/.mycli.test.log +log_level = DEBUG +prompt = '\t \u@\h:\d> ' +less_chatty = True diff --git a/test/mylogin.cnf b/test/mylogin.cnf Binary files differnew file mode 100644 index 0000000..1363cc3 --- /dev/null +++ b/test/mylogin.cnf diff --git a/test/test.txt b/test/test.txt new file mode 100644 index 0000000..8d8b211 --- /dev/null +++ b/test/test.txt @@ -0,0 +1 @@ +mycli rocks! diff --git a/test/test_clistyle.py b/test/test_clistyle.py new file mode 100644 index 0000000..f82cdf0 --- /dev/null +++ b/test/test_clistyle.py @@ -0,0 +1,27 @@ +"""Test the mycli.clistyle module.""" +import pytest + +from pygments.style import Style +from pygments.token import Token + +from mycli.clistyle import style_factory + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory(): + """Test that a Pygments Style class is created.""" + header = 'bold underline #ansired' + cli_style = {'Token.Output.Header': header} + style = style_factory('default', cli_style) + + assert isinstance(style(), Style) + assert Token.Output.Header in style.styles + assert header == style.styles[Token.Output.Header] + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory_unknown_name(): + """Test that an unrecognized name will not throw an error.""" + style = style_factory('foobar', {}) + + assert isinstance(style(), Style) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py new file mode 100644 index 0000000..318b632 --- /dev/null +++ b/test/test_completion_engine.py @@ -0,0 +1,555 @@ +from mycli.packages.completion_engine import suggest_type +import pytest + + +def sorted_dicts(dicts): + """input is a list of dicts.""" + return sorted(tuple(x.items()) for x in dicts) + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [('sch', 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM tabl WHERE ', + 'SELECT * FROM tabl WHERE (', + 'SELECT * FROM tabl WHERE foo = ', + 'SELECT * FROM tabl WHERE bar OR ', + 'SELECT * FROM tabl WHERE foo = 1 AND ', + 'SELECT * FROM tabl WHERE (bar > 10 AND ', + 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', + 'SELECT * FROM tabl WHERE 10 < ', + 'SELECT * FROM tabl WHERE foo BETWEEN ', + 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', +]) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM tabl WHERE foo IN (', + 'SELECT * FROM tabl WHERE foo IN (bar, ', +]) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = 'SELECT * FROM tabl WHERE foo = ANY(' + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}]) + + +def test_lparen_suggests_cols(): + suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_operand_inside_function_suggests_cols1(): + suggestion = suggest_type( + 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_operand_inside_function_suggests_cols2(): + suggestion = suggest_type( + 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type('SELECT ', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': []}, + {'type': 'column', 'tables': []}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM ', + 'INSERT INTO ', + 'COPY ', + 'UPDATE ', + 'DESCRIBE ', + 'DESC ', + 'EXPLAIN ', + 'SELECT * FROM foo JOIN ', +]) +def test_expression_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM sch.', + 'INSERT INTO sch.', + 'COPY sch.', + 'UPDATE sch.', + 'DESCRIBE sch.', + 'DESC sch.', + 'EXPLAIN sch.', + 'SELECT * FROM foo JOIN sch.', +]) +def test_expression_suggests_qualified_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}, + {'type': 'view', 'schema': 'sch'}]) + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'schema'}]) + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}]) + + +def test_distinct_suggests_cols(): + suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') + assert suggestions == [{'type': 'column', 'tables': []}] + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tbl']}, + {'type': 'column', 'tables': [(None, 'tbl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type('SELECT a, b FROM tbl1, ', + 'SELECT a, b FROM tbl1, ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_insert_into_lparen_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', + 'SELECT * FROM tabl WHERE col_n') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'table', 'schema': 'tabl'}, + {'type': 'view', 'schema': 'tabl'}, + {'type': 'function', 'schema': 'tabl'}]) + + +def test_dot_suggests_cols_of_an_alias(): + suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', + 'SELECT t1.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 't1'}, + {'type': 'view', 'schema': 't1'}, + {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, + {'type': 'function', 'schema': 't1'}]) + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', + 'SELECT t1.a, t2.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, + {'type': 'table', 'schema': 't2'}, + {'type': 'view', 'schema': 't2'}, + {'type': 'function', 'schema': 't2'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (', + 'SELECT * FROM foo WHERE EXISTS (', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', + 'SELECT 1 AS', +]) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{'type': 'keyword'}] + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (S', + 'SELECT * FROM foo WHERE EXISTS (S', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', +]) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{'type': 'keyword'}] + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' + suggestions = suggest_type(q, q) + assert suggestions == [ + {'type': 'column', 'tables': [(None, 'foo', 'f')]}, + {'type': 'table', 'schema': 'f'}, + {'type': 'view', 'schema': 'f'}, + {'type': 'function', 'schema': 'f'}] + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (SELECT * FROM ', + 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', +]) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', + 'SELECT * FROM (SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['abc']}, + {'type': 'column', 'tables': [(None, 'abc', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', + 'SELECT * FROM (SELECT a, ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', None)]}, + {'type': 'function', 'schema': []}]) + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', + 'SELECT * FROM (SELECT t.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', 't')]}, + {'type': 'table', 'schema': 't'}, + {'type': 'view', 'schema': 't'}, + {'type': 'function', 'schema': 't'}]) + + +@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) +@pytest.mark.parametrize('tbl_alias', ['', 'foo']) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) + suggestion = suggest_type(text, text) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', +]) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', 'a')]}, + {'type': 'table', 'schema': 'a'}, + {'type': 'view', 'schema': 'a'}, + {'type': 'function', 'schema': 'a'}]) + + +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.id = d.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', +]) +def test_join_alias_dot_suggests_cols2(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'def', 'd')]}, + {'type': 'table', 'schema': 'd'}, + {'type': 'view', 'schema': 'd'}, + {'type': 'function', 'schema': 'd'}]) + + +@pytest.mark.parametrize('sql', [ + 'select a.x, b.y from abc a join bcd b on ', + 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', +]) +def test_on_suggests_aliases(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + + +@pytest.mark.parametrize('sql', [ + 'select abc.x, bcd.y from abc join bcd on ', + 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', +]) +def test_on_suggests_tables(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + + +@pytest.mark.parametrize('sql', [ + 'select a.x, b.y from abc a join bcd b on a.id = ', + 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', +]) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + + +@pytest.mark.parametrize('sql', [ + 'select abc.x, bcd.y from abc join bcd on ', + 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', +]) +def test_on_suggests_tables_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + + +@pytest.mark.parametrize('col_list', ['', 'col1, ']) +def test_join_using_suggests_common_columns(col_list): + text = 'select * from abc inner join def using (' + col_list + assert suggest_type(text, text) == [ + {'type': 'column', + 'tables': [(None, 'abc', None), (None, 'def', None)], + 'drop_unique': True}] + +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', +]) +def test_two_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, + {'type': 'table', 'schema': 'g'}, + {'type': 'view', 'schema': 'g'}, + {'type': 'function', 'schema': 'g'}]) + +def test_2_statements_2nd_current(): + suggestions = suggest_type('select * from a; select * from ', + 'select * from a; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select * from a; select from b', + 'select * from a; select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + # Should work even if first statement is invalid + suggestions = suggest_type('select * from; select * from ', + 'select * from; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_2_statements_1st_current(): + suggestions = suggest_type('select * from ; select * from b', + 'select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select from a; select * from b', + 'select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['a']}, + {'type': 'column', 'tables': [(None, 'a', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_3_statements_2nd_current(): + suggestions = suggest_type('select * from a; select * from ; select * from c', + 'select * from a; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select * from a; select from b; select * from c', + 'select * from a; select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_create_db_with_template(): + suggestions = suggest_type('create database foo with template ', + 'create database foo with template ') + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}]) + + +@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t']) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == \ + sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + + +def test_specials_not_included_after_initial_token(): + suggestions = suggest_type('create table foo (dt d', + 'create table foo (dt d') + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}]) + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = 'DROP TABLE schema_name.table_name' + suggestions = suggest_type(text, text) + assert suggestions == [{'type': 'table', 'schema': 'schema_name'}] + + +@pytest.mark.parametrize('text', [',', ' ,', 'sel ,']) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_cross_join(): + text = 'select * from v1 cross join v2 JOIN v1.id, ' + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT 1 AS ', + 'SELECT 1 FROM tabl AS ', +]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +@pytest.mark.parametrize('expression', [ + '\\. ', + 'select 1; \\. ', + 'select 1;\\. ', + 'select 1 ; \\. ', + 'source ', + 'truncate table test; source ', + 'truncate table test ; source ', + 'truncate table test;source ', +]) +def test_source_is_file(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'file_name'}] + + +@pytest.mark.parametrize("expression", [ + "\\f ", +]) +def test_favorite_name_suggestion(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'favoritequery'}] + + +def test_order_by(): + text = 'select * from foo order by ' + suggestions = suggest_type(text, text) + assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] + + +def test_quoted_where(): + text = "'where i=';" + suggestions = suggest_type(text, text) + assert suggestions == [{'type': 'keyword'}] diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py new file mode 100644 index 0000000..cdc2fb5 --- /dev/null +++ b/test/test_completion_refresher.py @@ -0,0 +1,88 @@ +import time +import pytest +from unittest.mock import Mock, patch + + +@pytest.fixture +def refresher(): + from mycli.completion_refresher import CompletionRefresher + return CompletionRefresher() + + +def test_ctor(refresher): + """Refresher object should contain a few handlers. + + :param refresher: + :return: + + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions', + 'special_commands', 'show_commands'] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + sqlexecute = Mock() + + with patch.object(refresher, '_bg_refresh') as bg_refresh: + actual = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == 'Auto-completion refresh started in the background.' + bg_refresh.assert_called_with(sqlexecute, callbacks, {}) + + +def test_refresh_called_twice(refresher): + """If refresh is called a second time, it should be restarted. + + :param refresher: + :return: + + """ + callbacks = Mock() + + sqlexecute = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == 'Auto-completion refresh started in the background.' + + actual2 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == 'Auto-completion refresh restarted.' + + +def test_refresh_with_callbacks(refresher): + """Callbacks must be called. + + :param refresher: + + """ + callbacks = [Mock()] + sqlexecute_class = Mock() + sqlexecute = Mock() + + with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class): + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert (callbacks[0].call_count == 1) diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 0000000..7f2b244 --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,196 @@ +"""Unit tests for the mycli.config module.""" +from io import BytesIO, StringIO, TextIOWrapper +import os +import struct +import sys +import tempfile +import pytest + +from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, + read_and_decrypt_mylogin_cnf, read_config_file, + str_to_bool, strip_matching_quotes) + +LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), + 'mylogin.cnf')) + + +def open_bmylogin_cnf(name): + """Open contents of *name* in a BytesIO buffer.""" + with open(name, 'rb') as f: + buf = BytesIO() + buf.write(f.read()) + return buf + +def test_read_mylogin_cnf(): + """Tests that a login path file can be read and decrypted.""" + mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) + + assert isinstance(mylogin_cnf, TextIOWrapper) + + contents = mylogin_cnf.read() + for word in ('[test]', 'user', 'password', 'host', 'port'): + assert word in contents + + +def test_decrypt_blank_mylogin_cnf(): + """Test that a blank login path file is handled correctly.""" + mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO()) + assert mylogin_cnf is None + + +def test_corrupted_login_key(): + """Test that a corrupted login path key is handled correctly.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the unused bytes + buf.seek(4) + + # Write null bytes over half the login key + buf.write(b'\0\0\0\0\0\0\0\0\0\0') + + buf.seek(0) + mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) + + assert mylogin_cnf is None + + +def test_corrupted_pad(): + """Tests that a login path file with a corrupted pad is partially read.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the login key + buf.seek(24) + + # Skip option group + len_buf = buf.read(4) + cipher_len, = struct.unpack("<i", len_buf) + buf.read(cipher_len) + + # Corrupt the pad for the user line + len_buf = buf.read(4) + cipher_len, = struct.unpack("<i", len_buf) + buf.read(cipher_len - 1) + buf.write(b'\0') + + buf.seek(0) + mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf)) + contents = mylogin_cnf.read() + for word in ('[test]', 'password', 'host', 'port'): + assert word in contents + assert 'user' not in contents + + +def test_get_mylogin_cnf_path(): + """Tests that the path for .mylogin.cnf is detected.""" + original_env = None + if 'MYSQL_TEST_LOGIN_FILE' in os.environ: + original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + is_windows = sys.platform == 'win32' + + login_cnf_path = get_mylogin_cnf_path() + + if original_env is not None: + os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + + if login_cnf_path is not None: + assert login_cnf_path.endswith('.mylogin.cnf') + + if is_windows is True: + assert 'MySQL' in login_cnf_path + else: + home_dir = os.path.expanduser('~') + assert login_cnf_path.startswith(home_dir) + + +def test_alternate_get_mylogin_cnf_path(): + """Tests that the alternate path for .mylogin.cnf is detected.""" + original_env = None + if 'MYSQL_TEST_LOGIN_FILE' in os.environ: + original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + + _, temp_path = tempfile.mkstemp() + os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path + + login_cnf_path = get_mylogin_cnf_path() + + if original_env is not None: + os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + + assert temp_path == login_cnf_path + + +def test_str_to_bool(): + """Tests that str_to_bool function converts values correctly.""" + + assert str_to_bool(False) is False + assert str_to_bool(True) is True + assert str_to_bool('False') is False + assert str_to_bool('True') is True + assert str_to_bool('TRUE') is True + assert str_to_bool('1') is True + assert str_to_bool('0') is False + assert str_to_bool('on') is True + assert str_to_bool('off') is False + assert str_to_bool('off') is False + + with pytest.raises(ValueError): + str_to_bool('foo') + + with pytest.raises(TypeError): + str_to_bool(None) + + +def test_read_config_file_list_values_default(): + """Test that reading a config file uses list_values by default.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f) + + assert config['main']['weather'] == u"cloudy with a chance of meatballs" + + +def test_read_config_file_list_values_off(): + """Test that you can disable list_values when reading a config file.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f, list_values=False) + + assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + + +def test_strip_quotes_with_matching_quotes(): + """Test that a string with matching quotes is unquoted.""" + + s = "May the force be with you." + assert s == strip_matching_quotes('"{}"'.format(s)) + assert s == strip_matching_quotes("'{}'".format(s)) + + +def test_strip_quotes_with_unmatching_quotes(): + """Test that a string with unmatching quotes is not unquoted.""" + + s = "May the force be with you." + assert '"' + s == strip_matching_quotes('"{}'.format(s)) + assert s + "'" == strip_matching_quotes("{}'".format(s)) + + +def test_strip_quotes_with_empty_string(): + """Test that an empty string is handled during unquoting.""" + + assert '' == strip_matching_quotes('') + + +def test_strip_quotes_with_none(): + """Test that None is handled during unquoting.""" + + assert None is strip_matching_quotes(None) + + +def test_strip_quotes_with_quotes(): + """Test that strings with quotes in them are handled during unquoting.""" + + s1 = 'Darth Vader said, "Luke, I am your father."' + assert s1 == strip_matching_quotes(s1) + + s2 = '"Darth Vader said, "Luke, I am your father.""' + assert s2[1:-1] == strip_matching_quotes(s2) diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py new file mode 100644 index 0000000..21e389c --- /dev/null +++ b/test/test_dbspecial.py @@ -0,0 +1,42 @@ +from mycli.packages.completion_engine import suggest_type +from .test_completion_engine import sorted_dicts +from mycli.packages.special.utils import format_uptime + + +def test_u_suggests_databases(): + suggestions = suggest_type('\\u ', '\\u ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'database'}]) + + +def test_describe_table(): + suggestions = suggest_type('\\dt', '\\dt ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_list_or_show_create_tables(): + suggestions = suggest_type('\\dt+', '\\dt+ ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_format_uptime(): + seconds = 59 + assert '59 sec' == format_uptime(seconds) + + seconds = 120 + assert '2 min 0 sec' == format_uptime(seconds) + + seconds = 54890 + assert '15 hours 14 min 50 sec' == format_uptime(seconds) + + seconds = 598244 + assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds) + + seconds = 522600 + assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds) diff --git a/test/test_main.py b/test/test_main.py new file mode 100644 index 0000000..64cba0a --- /dev/null +++ b/test/test_main.py @@ -0,0 +1,548 @@ +import os +import shutil + +import click +from click.testing import CliRunner + +from mycli.main import MyCli, cli, thanks_picker +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.sqlexecute import ServerInfo +from .utils import USER, HOST, PORT, PASSWORD, dbtest, run + +from textwrap import dedent +from collections import namedtuple + +from tempfile import NamedTemporaryFile +from textwrap import dedent + + +test_dir = os.path.abspath(os.path.dirname(__file__)) +project_dir = os.path.dirname(test_dir) +default_config_file = os.path.join(project_dir, 'test', 'myclirc') +login_path_file = os.path.join(test_dir, 'mylogin.cnf') + +os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, + '--password', PASSWORD, '--myclirc', default_config_file, + '--defaults-file', default_config_file, + 'mycli_test_db'] + + +@dbtest +def test_execute_arg(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + expected = 'a\nabc\n' + + assert expected in result.output + + +@dbtest +def test_execute_arg_with_table(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) + expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_arg_with_csv(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) + expected = '"a"\n"abc"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +@dbtest +def test_batch_mode(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert 'count(*)\n3\na\nabc\n' in "".join(result.output) + + +@dbtest +def test_batch_mode_table(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) + + expected = (dedent("""\ + +----------+ + | count(*) | + +----------+ + | 3 | + +----------+ + +-----+ + | a | + +-----+ + | abc | + +-----+""")) + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_batch_mode_csv(executor): + run(executor, '''create table test(a text, b text)''') + run(executor, + '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''') + + sql = 'select * from test;' + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) + + expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +def test_thanks_picker_utf8(): + name = thanks_picker() + assert name and isinstance(name, str) + + +def test_help_strings_end_with_periods(): + """Make sure click options have help text that end with a period.""" + for param in cli.params: + if isinstance(param, click.core.Option): + assert hasattr(param, 'help') + assert param.help.endswith('.') + + +def test_command_descriptions_end_with_periods(): + """Make sure that mycli commands' descriptions end with a period.""" + MyCli() + for _, command in SPECIAL_COMMANDS.items(): + assert command[3].endswith('.') + + +def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): + global clickoutput + clickoutput = "" + m = MyCli(myclirc=default_config_file) + + class TestOutput(): + def get_size(self): + size = namedtuple('Size', 'rows columns') + size.columns, size.rows = terminal_size + return size + + class TestExecute(): + host = 'test' + user = 'test' + dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') + port = 0 + + def server_type(self): + return ['test'] + + class PromptBuffer(): + output = TestOutput() + + m.prompt_app = PromptBuffer() + m.sqlexecute = TestExecute() + m.explicit_pager = explicit_pager + + def echo_via_pager(s): + assert expect_pager + global clickoutput + clickoutput += "".join(s) + + def secho(s): + assert not expect_pager + global clickoutput + clickoutput += s + "\n" + + monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager) + monkeypatch.setattr(click, 'secho', secho) + m.output(testdata) + if clickoutput.endswith("\n"): + clickoutput = clickoutput[:-1] + assert clickoutput == "\n".join(testdata) + + +def test_conditional_pager(monkeypatch): + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( + " ") + # User didn't set pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=True + ) + # User didn't set pager, output fits screen -> no pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + # User manually configured pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + # User manually configured pager, output fit screen -> pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + + SPECIAL_COMMANDS['nopager'].handler() + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + SPECIAL_COMMANDS['pager'].handler('') + + +def test_reserved_space_is_integer(): + """Make sure that reserved space is returned as an integer.""" + def stub_terminal_size(): + return (5, 5) + + old_func = shutil.get_terminal_size + + shutil.get_terminal_size = stub_terminal_size + mycli = MyCli() + assert isinstance(mycli.get_reserved_space(), int) + + shutil.get_terminal_size = old_func + + +def test_list_dsn(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as myclirc: + myclirc.write(dedent("""\ + [alias_dsn] + test = mysql://test/test + """)) + myclirc.flush() + args = ['--list-dsn', '--myclirc', myclirc.name] + result = runner.invoke(cli, args=args) + assert result.output == "test\n" + result = runner.invoke(cli, args=args + ['--verbose']) + assert result.output == "test : mysql://test/test\n" + + +def test_prettify_statement(): + statement = 'SELECT 1' + m = MyCli() + pretty_statement = m.handle_prettify_binding(statement) + assert pretty_statement == 'SELECT\n 1;' + + +def test_unprettify_statement(): + statement = 'SELECT\n 1' + m = MyCli() + unpretty_statement = m.handle_unprettify_binding(statement) + assert unpretty_statement == 'SELECT 1;' + + +def test_list_ssh_config(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + result = runner.invoke(cli, args=args) + assert "test\n" in result.output + result = runner.invoke(cli, args=args + ['--verbose']) + assert "test : test.example.com\n" in result.output + + +def test_dsn(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = {'alias_dsn': {}} + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # When a user supplies a DSN as database argument to mycli, + # use these values. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 1 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + MockMyCli.connect_args = None + + # When a use supplies a DSN as database argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "3", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 3 and \ + MockMyCli.connect_args["database"] == "arg_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn), + # use these values. + result = runner.invoke(cli, args=['--dsn', 'test']) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "alias_dsn_user" and \ + MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ + MockMyCli.connect_args["host"] == "alias_dsn_host" and \ + MockMyCli.connect_args["port"] == 4 and \ + MockMyCli.connect_args["database"] == "alias_dsn_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn) + # and used command line arguments, use the command line arguments. + result = runner.invoke(cli, args=[ + '--dsn', 'test', '', + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "5", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 5 and \ + MockMyCli.connect_args["database"] == "arg_database" + + # Use a DSN without password + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user@dsn_host:6/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] is None and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 6 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + +def test_ssh_config(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = {'alias_dsn': {}} + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # Setup temporary configuration + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + + # When a user supplies a ssh config. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "joe" and \ + MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ + MockMyCli.connect_args["ssh_port"] == 22222 and \ + MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + "HOME") + "/.ssh/gateway" + + # When a user supplies a ssh config host as argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", "arg_user", + "--ssh-host", "arg_host", + "--ssh-port", "3", + "--ssh-key-filename", "/path/to/key" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "arg_user" and \ + MockMyCli.connect_args["ssh_host"] == "arg_host" and \ + MockMyCli.connect_args["ssh_port"] == 3 and \ + MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + + +@dbtest +def test_init_command_arg(executor): + init_command = "set sql_select_limit=1000" + sql = 'show variables like "sql_select_limit";' + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ["--init-command", init_command], input=sql + ) + + expected = "sql_select_limit\t1000\n" + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_init_command_multiple_arg(executor): + init_command = 'set sql_select_limit=2000; set max_join_size=20000' + sql = ( + 'show variables like "sql_select_limit";\n' + 'show variables like "max_join_size"' + ) + runner = CliRunner() + result = runner.invoke( + cli, args=CLI_ARGS + ['--init-command', init_command], input=sql + ) + + expected_sql_select_limit = 'sql_select_limit\t2000\n' + expected_max_join_size = 'max_join_size\t20000\n' + + assert result.exit_code == 0 + assert expected_sql_select_limit in result.output + assert expected_max_join_size in result.output diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py new file mode 100644 index 0000000..32b2abd --- /dev/null +++ b/test/test_naive_completion.py @@ -0,0 +1,63 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document + + +@pytest.fixture +def completer(): + import mycli.sqlcompleter as sqlcompleter + return sqlcompleter.SQLCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = '' + position = 0 + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list(map(Completion, sorted(completer.all_completions))) + + +def test_select_keyword_completion(completer, complete_event): + text = 'SEL' + position = len('SEL') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([Completion(text='SELECT', start_position=-3)]) + + +def test_function_name_completion(completer, complete_event): + text = 'SELECT MA' + position = len('SELECT MA') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='MASTER', start_position=-2), + Completion(text='MAX', start_position=-2)]) + + +def test_column_name_completion(completer, complete_event): + text = 'SELECT FROM users' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list(map(Completion, sorted(completer.all_completions))) + + +def test_special_name_completion(completer, complete_event): + text = '\\' + position = len('\\') + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + # Special commands will NOT be suggested during naive completion mode. + assert result == set() diff --git a/test/test_parseutils.py b/test/test_parseutils.py new file mode 100644 index 0000000..920a08d --- /dev/null +++ b/test/test_parseutils.py @@ -0,0 +1,190 @@ +import pytest +from mycli.packages.parseutils import ( + extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, + is_dropping_database) + + +def test_empty_string(): + tables = extract_tables('') + assert tables == [] + + +def test_simple_select_single_table(): + tables = extract_tables('select * from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_single_table_schema_qualified(): + tables = extract_tables('select * from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_multiple_tables(): + tables = extract_tables('select * from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables('select * from abc.def, ghi.jkl') + assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables('select a,b from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables('select a,b from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables('select a,b from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_with_cols_multiple_tables_with_schema(): + tables = extract_tables('select a,b from abc.def, def.ghi') + assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables('select a, from abc') + assert tables == [(None, 'abc', None)] + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables('select a, from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') + assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # assert tables == [(None, 'abc', None)] + assert tables == [(None, 'abc', 'abc')] + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == [('abc', 'def', None)] + + +def test_simple_update_table(): + tables = extract_tables('update abc set id = 1') + assert tables == [(None, 'abc', None)] + + +def test_simple_update_table_with_schema(): + tables = extract_tables('update abc.def set id = 1') + assert tables == [('abc', 'def', None)] + + +def test_join_table(): + tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') + assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + + +def test_join_table_schema_qualified(): + tables = extract_tables( + 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + + +def test_join_as_table(): + tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + + +def test_query_starts_with(): + query = 'USE test;' + assert query_starts_with(query, ('use', )) is True + + query = 'DROP DATABASE test;' + assert query_starts_with(query, ('use', )) is False + + +def test_query_starts_with_comment(): + query = '# comment\nUSE test;' + assert query_starts_with(query, ('use', )) is True + + +def test_queries_start_with(): + sql = ( + '# comment\n' + 'show databases;' + 'use foo;' + ) + assert queries_start_with(sql, ('show', 'select')) is True + assert queries_start_with(sql, ('use', 'drop')) is True + assert queries_start_with(sql, ('delete', 'update')) is False + + +def test_is_destructive(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'drop database foo;' + ) + assert is_destructive(sql) is True + + +def test_is_destructive_update_with_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1 WHERE id = 1;' + ) + assert is_destructive(sql) is False + + +def test_is_destructive_update_without_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1;' + ) + assert is_destructive(sql) is True + + +@pytest.mark.parametrize( + ('sql', 'has_where_clause'), + [ + ('update test set dummy = 1;', False), + ('update test set dummy = 1 where id = 1);', True), + ], +) +def test_query_has_where_clause(sql, has_where_clause): + assert query_has_where_clause(sql) is has_where_clause + + +@pytest.mark.parametrize( + ('sql', 'dbname', 'is_dropping'), + [ + ('select bar from foo', 'foo', False), + ('drop database "foo";', '`foo`', True), + ('drop schema foo', 'foo', True), + ('drop schema foo', 'bar', False), + ('drop database bar', 'foo', False), + ('drop database foo', None, False), + ('drop database foo; create database foo', 'foo', False), + ('drop database foo; create database bar', 'foo', True), + ('select bar from foo; drop database bazz', 'foo', False), + ('select bar from foo; drop database bazz', 'bazz', True), + ('-- dropping database \n ' + 'drop -- really dropping \n ' + 'schema abc -- now it is dropped', + 'abc', + True) + ] +) +def test_is_dropping_database(sql, dbname, is_dropping): + assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_plan.wiki b/test/test_plan.wiki new file mode 100644 index 0000000..43e9083 --- /dev/null +++ b/test/test_plan.wiki @@ -0,0 +1,38 @@ += Gross Checks = + * [ ] Check connecting to a local database. + * [ ] Check connecting to a remote database. + * [ ] Check connecting to a database with a user/password. + * [ ] Check connecting to a non-existent database. + * [ ] Test changing the database. + + == PGExecute == + * [ ] Test successful execution given a cursor. + * [ ] Test unsuccessful execution with a syntax error. + * [ ] Test a series of executions with the same cursor without failure. + * [ ] Test a series of executions with the same cursor with failure. + * [ ] Test passing in a special command. + + == Naive Autocompletion == + * [ ] Input empty string, ask for completions - Everything. + * [ ] Input partial prefix, ask for completions - Stars with prefix. + * [ ] Input fully autocompleted string, ask for completions - Only full match + * [ ] Input non-existent prefix, ask for completions - nothing + * [ ] Input lowercase prefix - case insensitive completions + + == Smart Autocompletion == + * [ ] Input empty string and check if only keywords are returned. + * [ ] Input SELECT prefix and check if only columns and '*' are returned. + * [ ] Input SELECT blah - only keywords are returned. + * [ ] Input SELECT * FROM - Table names only + + == PGSpecial == + * [ ] Test \d + * [ ] Test \d tablename + * [ ] Test \d tablena* + * [ ] Test \d non-existent-tablename + * [ ] Test \d index + * [ ] Test \d sequence + * [ ] Test \d view + + == Exceptionals == + * [ ] Test the 'use' command to change db. diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py new file mode 100644 index 0000000..2373fac --- /dev/null +++ b/test/test_prompt_utils.py @@ -0,0 +1,11 @@ +import click + +from mycli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream('stdin') + assert stdin.isatty() is False + + sql = 'drop database foo;' + assert confirm_destructive_query(sql) is None diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..e7d460a --- /dev/null +++ b/test/test_smart_completion_public_schema_only.py @@ -0,0 +1,385 @@ +import pytest +from unittest.mock import patch +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +import mycli.packages.special.main as special + +metadata = { + 'users': ['id', 'email', 'first_name', 'last_name'], + 'orders': ['id', 'ordered_date', 'status'], + 'select': ['id', 'insert', 'ABC'], + 'réveillé': ['id', 'insert', 'ABC'] +} + + +@pytest.fixture +def completer(): + + import mycli.sqlcompleter as sqlcompleter + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables, columns = [], [] + + for table, cols in metadata.items(): + tables.append((table,)) + columns.extend([(table, col) for col in cols]) + + comp.set_dbname('test') + comp.extend_schemata('test') + comp.extend_relations(tables, kind='tables') + comp.extend_columns(columns, kind='tables') + comp.extend_special_commands(special.COMMANDS) + + return comp + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + return Mock() + + +def test_special_name_completion(completer, complete_event): + text = '\\d' + position = len('\\d') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert result == [Completion(text='\\dt', start_position=-2)] + + +def test_empty_string_completion(completer, complete_event): + text = '' + position = 0 + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert list(map(Completion, sorted(completer.keywords) + + sorted(completer.special_commands))) == result + + +def test_select_keyword_completion(completer, complete_event): + text = 'SEL' + position = len('SEL') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert list(result) == list([Completion(text='SELECT', start_position=-3)]) + + +def test_table_completion(completer, complete_event): + text = 'SELECT * FROM ' + position = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event) + assert list(result) == list([ + Completion(text='`réveillé`', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) + + +def test_function_name_completion(completer, complete_event): + text = 'SELECT MA' + position = len('SELECT MA') + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event) + assert list(result) == list([Completion(text='MAX', start_position=-2), + Completion(text='MASTER', start_position=-2), + ]) + + +def test_suggested_column_names(completer, complete_event): + """Suggest column and function names when selecting from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT from users' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text='users', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def test_suggested_column_names_in_function(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT MAX( from users' + position = len('SELECT MAX(') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert list(result) == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_column_names_with_table_dot(completer, complete_event): + """Suggest column names on table name and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT users. from users' + position = len('SELECT users.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT u. from users u' + position = len('SELECT u.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_multiple_column_names(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT id, from users u' + position = len('SELECT id, ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)] + + list(map(Completion, completer.functions)) + + [Completion(text='u', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def test_suggested_multiple_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT u.id, u. from users u' + position = len('SELECT u.id, u.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_multiple_column_names_with_dot(completer, complete_event): + """Suggest column names on table names and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT users.id, users. from users u' + position = len('SELECT users.id, users.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_aliases_after_on(completer, complete_event): + text = 'SELECT u.name, o.id FROM users u JOIN orders o ON ' + position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + + +def test_suggested_aliases_after_on_right_side(completer, complete_event): + text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ' + position = len( + 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + + +def test_suggested_tables_after_on(completer, complete_event): + text = 'SELECT users.name, orders.id FROM users JOIN orders ON ' + position = len('SELECT users.name, orders.id FROM users JOIN orders ON ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0)]) + + +def test_suggested_tables_after_on_right_side(completer, complete_event): + text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ' + position = len( + 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0)]) + + +def test_table_names_after_from(completer, complete_event): + text = 'SELECT * FROM ' + position = len('SELECT * FROM ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='`réveillé`', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) + + +def test_auto_escaped_col_names(completer, complete_event): + text = 'SELECT from `select`' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == [ + Completion(text='*', start_position=0), + Completion(text='`ABC`', start_position=0), + Completion(text='`insert`', start_position=0), + Completion(text='id', start_position=0), + ] + \ + list(map(Completion, completer.functions)) + \ + [Completion(text='`select`', start_position=0)] + \ + list(map(Completion, completer.keywords)) + + +def test_un_escaped_table_names(completer, complete_event): + text = 'SELECT from réveillé' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='`ABC`', start_position=0), + Completion(text='`insert`', start_position=0), + Completion(text='id', start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text='réveillé', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def dummy_list_path(dir_name): + dirs = { + '/': [ + 'dir1', + 'file1.sql', + 'file2.sql', + ], + '/dir1': [ + 'subdir1', + 'subfile1.sql', + 'subfile2.sql', + ], + '/dir1/subdir1': [ + 'lastfile.sql', + ], + } + return dirs.get(dir_name, []) + + +@patch('mycli.packages.filepaths.list_path', new=dummy_list_path) +@pytest.mark.parametrize('text,expected', [ + # ('source ', [('~', 0), + # ('/', 0), + # ('.', 0), + # ('..', 0)]), + ('source /', [('dir1', 0), + ('file1.sql', 0), + ('file2.sql', 0)]), + ('source /dir1/', [('subdir1', 0), + ('subfile1.sql', 0), + ('subfile2.sql', 0)]), + ('source /dir1/subdir1/', [('lastfile.sql', 0)]), +]) +def test_file_name_completion(completer, complete_event, text, expected): + position = len(text) + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + expected = list((Completion(txt, pos) for txt, pos in expected)) + assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py new file mode 100644 index 0000000..8b6be33 --- /dev/null +++ b/test/test_special_iocommands.py @@ -0,0 +1,287 @@ +import os +import stat +import tempfile +from time import time +from unittest.mock import patch + +import pytest +from pymysql import ProgrammingError + +import mycli.packages.special + +from .utils import dbtest, db_connection, send_ctrl_c + + +def test_set_get_pager(): + mycli.packages.special.set_pager_enabled(True) + assert mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager_enabled(False) + assert not mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager('less') + assert os.environ['PAGER'] == "less" + mycli.packages.special.set_pager(False) + assert os.environ['PAGER'] == "less" + del os.environ['PAGER'] + mycli.packages.special.set_pager(False) + mycli.packages.special.disable_pager() + assert not mycli.packages.special.is_pager_enabled() + + +def test_set_get_timing(): + mycli.packages.special.set_timing_enabled(True) + assert mycli.packages.special.is_timing_enabled() + mycli.packages.special.set_timing_enabled(False) + assert not mycli.packages.special.is_timing_enabled() + + +def test_set_get_expanded_output(): + mycli.packages.special.set_expanded_output(True) + assert mycli.packages.special.is_expanded_output() + mycli.packages.special.set_expanded_output(False) + assert not mycli.packages.special.is_expanded_output() + + +def test_editor_command(): + assert mycli.packages.special.editor_command(r'hello\e') + assert mycli.packages.special.editor_command(r'\ehello') + assert not mycli.packages.special.editor_command(r'hello') + + assert mycli.packages.special.get_filename(r'\e filename') == "filename" + + os.environ['EDITOR'] = 'true' + os.environ['VISUAL'] = 'true' + mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" + + +def test_tee_command(): + mycli.packages.special.write_tee(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"tee " + f.name) + mycli.packages.special.write_tee(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"tee -o " + f.name) + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"notee") + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + +def test_tee_command_error(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, 'tee') + + with pytest.raises(OSError): + with tempfile.NamedTemporaryFile() as f: + os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + mycli.packages.special.execute(None, 'tee {}'.format(f.name)) + + +@dbtest +def test_favorite_query(): + with db_connection().cursor() as cur: + query = u'select "✔"' + mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) + assert next(mycli.packages.special.execute( + cur, u'\\f check'))[0] == "> " + query + + +def test_once_command(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, u"\\once") + + with pytest.raises(OSError): + mycli.packages.special.execute(None, u"\\once /proc/access-denied") + + mycli.packages.special.write_once(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"\\once " + f.name) + mycli.packages.special.write_once(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"\\once -o " + f.name) + mycli.packages.special.write_once(u"hello world line 1") + mycli.packages.special.write_once(u"hello world line 2") + f.seek(0) + assert f.read() == b"hello world line 1\nhello world line 2\n" + + +def test_pipe_once_command(): + with pytest.raises(IOError): + mycli.packages.special.execute(None, u"\\pipe_once") + + with pytest.raises(OSError): + mycli.packages.special.execute( + None, u"\\pipe_once /proc/access-denied") + + mycli.packages.special.execute(None, u"\\pipe_once wc") + mycli.packages.special.write_once(u"hello world") + mycli.packages.special.unset_pipe_once_if_written() + # how to assert on wc output? + + +def test_parseargfile(): + """Test that parseargfile expands the user directory.""" + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'a'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '~/filename') + + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'w'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '-o ~/filename') + + +def test_parseargfile_no_file(): + """Test that parseargfile raises a TypeError if there is no filename.""" + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('') + + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('-o ') + + +@dbtest +def test_watch_query_iteration(): + """Test that a single iteration of the result of `watch_query` executes + the desired query and returns the given results.""" + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + with db_connection().cursor() as cur: + result = next(mycli.packages.special.iocommands.watch_query( + arg=query, cur=cur + )) + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +def test_watch_query_full(): + """Test that `watch_query`: + + * Returns the expected results. + * Executes the defined times inside the given interval, in this case with + a 0.3 seconds wait, it should execute 4 times inside a 1 seconds + interval. + * Stops at Ctrl-C + + """ + watch_seconds = 0.3 + wait_interval = 1 + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + expected_results = 4 + ctrl_c_process = send_ctrl_c(wait_interval) + with db_connection().cursor() as cur: + results = list( + result for result in mycli.packages.special.iocommands.watch_query( + arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur + ) + ) + ctrl_c_process.join(1) + assert len(results) == expected_results + for result in results: + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +@patch('click.clear') +def test_watch_query_clear(clear_mock): + """Test that the screen is cleared with the -c flag of `watch` command + before execute the query.""" + with db_connection().cursor() as cur: + watch_gen = mycli.packages.special.iocommands.watch_query( + arg='0.1 -c select 1;', cur=cur + ) + assert not clear_mock.called + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + + +@dbtest +def test_watch_query_bad_arguments(): + """Test different incorrect combinations of arguments for `watch` + command.""" + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + with pytest.raises(ProgrammingError): + next(watch_query('a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('1 -a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-c -a select 1;', cur=cur)) + + +@dbtest +@patch('click.clear') +def test_watch_query_interval_clear(clear_mock): + """Test `watch` command with interval and clear flag.""" + def test_asserts(gen): + clear_mock.reset_mock() + start = time() + next(gen) + assert clear_mock.called + next(gen) + exec_time = time() - start + assert exec_time > seconds and exec_time < (seconds + seconds) + + seconds = 1.0 + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), + cur=cur)) + test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), + cur=cur)) + + +def test_split_sql_by_delimiter(): + for delimiter_str in (';', '$', '😀'): + mycli.packages.special.set_delimiter(delimiter_str) + sql_input = "select 1{} select \ufffc2".format(delimiter_str) + queries = ( + "select 1", + "select \ufffc2" + ) + for query, parsed_query in zip( + queries, mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_switch_delimiter_within_query(): + mycli.packages.special.set_delimiter(';') + sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" + queries = ( + "select 1", + "delimiter $$ select 2 $$ select 3 $$", + "select 2", + "select 3" + ) + for query, parsed_query in zip( + queries, + mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_set_delimiter(): + + for delim in ('foo', 'bar'): + mycli.packages.special.set_delimiter(delim) + assert mycli.packages.special.get_current_delimiter() == delim + + +def teardown_function(): + mycli.packages.special.set_delimiter(';') diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py new file mode 100644 index 0000000..38ca5ef --- /dev/null +++ b/test/test_sqlexecute.py @@ -0,0 +1,295 @@ +import os + +import pytest +import pymysql + +from mycli.sqlexecute import ServerInfo, ServerSpecies +from .utils import run, dbtest, set_expanded_output, is_expanded_output + + +def assert_result_equal(result, title=None, rows=None, headers=None, + status=None, auto_status=True, assert_contains=False): + """Assert that an sqlexecute.run() result matches the expected values.""" + if status is None and auto_status and rows: + status = '{} row{} in set'.format( + len(rows), 's' if len(rows) > 1 else '') + fields = {'title': title, 'rows': rows, 'headers': headers, + 'status': status} + + if assert_contains: + # Do a loose match on the results using the *in* operator. + for key, field in fields.items(): + if field: + assert field in result[0][key] + else: + # Do an exact match on the fields. + assert result == [fields] + + +@dbtest +def test_conn(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[('abc',)]) + + +@dbtest +def test_bools(executor): + run(executor, '''create table test(a boolean)''') + run(executor, '''insert into test values(True)''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[(1,)]) + + +@dbtest +def test_binary(executor): + run(executor, '''create table bt(geom linestring NOT NULL)''') + run(executor, "INSERT INTO bt VALUES " + "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, '''select * from bt''') + + geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' + b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' + b'\xac\xdeC@') + + assert_result_equal(results, headers=['geom'], rows=[(geom,)]) + + +@dbtest +def test_table_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + + assert set(executor.tables()) == set([('a',), ('b',)]) + assert set(executor.table_columns()) == set( + [('a', 'x'), ('a', 'y'), ('b', 'z')]) + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert 'mycli_test_db' in databases + + +@dbtest +def test_invalid_syntax(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, 'invalid syntax!') + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + +@dbtest +def test_invalid_column_name(executor): + with pytest.raises(pymysql.err.OperationalError) as excinfo: + run(executor, 'select invalid command') + assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) + + +@dbtest +def test_unicode_support_in_output(executor): + run(executor, "create table unicodechars(t text)") + run(executor, u"insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + results = run(executor, u"select * from unicodechars") + assert_result_equal(results, headers=['t'], rows=[(u'é',)]) + + +@dbtest +def test_multiple_queries_same_line(executor): + results = run(executor, "select 'foo'; select 'bar'") + + expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], + 'status': '1 row in set'}, + {'title': None, 'headers': ['bar'], 'rows': [('bar',)], + 'status': '1 row in set'}] + assert expected == results + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, "select 'foo'; invalid syntax") + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + +@dbtest +def test_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-a select * from test where a like 'a%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-a") + assert_result_equal(results, + title="> select * from test where a like 'a%'", + headers=['a'], rows=[('abc',)], auto_status=False) + + results = run(executor, "\\fd test-a") + assert_result_equal(results, status='test-a: Deleted') + + +@dbtest +def test_favorite_query_multiple_statement(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, + "\\fs test-ad select * from test where a like 'a%'; " + "select * from test where a like 'd%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ad") + expected = [{'title': "> select * from test where a like 'a%'", + 'headers': ['a'], 'rows': [('abc',)], 'status': None}, + {'title': "> select * from test where a like 'd%'", + 'headers': ['a'], 'rows': [('def',)], 'status': None}] + assert expected == results + + results = run(executor, "\\fd test-ad") + assert_result_equal(results, status='test-ad: Deleted') + + +@dbtest +def test_favorite_query_expanded_output(executor): + set_expanded_output(False) + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + + results = run(executor, "\\fs test-ae select * from test") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ae \\G") + assert is_expanded_output() is True + assert_result_equal(results, title='> select * from test', + headers=['a'], rows=[('abc',)], auto_status=False) + + set_expanded_output(False) + + results = run(executor, "\\fd test-ae") + assert_result_equal(results, status='test-ae: Deleted') + + +@dbtest +def test_special_command(executor): + results = run(executor, '\\?') + assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), + headers='Command', assert_contains=True, + auto_status=False) + + +@dbtest +def test_cd_command_without_a_folder_name(executor): + results = run(executor, 'system cd') + assert_result_equal(results, status='No folder name was provided.') + + +@dbtest +def test_system_command_not_found(executor): + results = run(executor, 'system xyz') + assert_result_equal(results, status='OSError: No such file or directory', + assert_contains=True) + + +@dbtest +def test_system_command_output(executor): + test_dir = os.path.abspath(os.path.dirname(__file__)) + test_file_path = os.path.join(test_dir, 'test.txt') + results = run(executor, 'system cat {0}'.format(test_file_path)) + assert_result_equal(results, status='mycli rocks!\n') + + +@dbtest +def test_cd_command_current_dir(executor): + test_path = os.path.abspath(os.path.dirname(__file__)) + run(executor, 'system cd {0}'.format(test_path)) + assert os.getcwd() == test_path + + +@dbtest +def test_unicode_support(executor): + results = run(executor, u"SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) + + +@dbtest +def test_timestamp_null(executor): + run(executor, '''create table ts_null(a timestamp null)''') + run(executor, '''insert into ts_null values(null)''') + results = run(executor, '''select * from ts_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_datetime_null(executor): + run(executor, '''create table dt_null(a datetime null)''') + run(executor, '''insert into dt_null values(null)''') + results = run(executor, '''select * from dt_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_date_null(executor): + run(executor, '''create table date_null(a date null)''') + run(executor, '''insert into date_null values(null)''') + results = run(executor, '''select * from date_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_time_null(executor): + run(executor, '''create table time_null(a time null)''') + run(executor, '''insert into time_null values(null)''') + results = run(executor, '''select * from time_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_multiple_results(executor): + query = '''CREATE PROCEDURE dmtest() + BEGIN + SELECT 1; + SELECT 2; + END''' + executor.conn.cursor().execute(query) + + results = run(executor, 'call dmtest;') + expected = [ + {'title': None, 'rows': [(1,)], 'headers': ['1'], + 'status': '1 row in set'}, + {'title': None, 'rows': [(2,)], 'headers': ['2'], + 'status': '1 row in set'} + ] + assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.25-TiDB-v6.1.0','TiDB', '5.7.25', 50725), + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py new file mode 100644 index 0000000..c20c7de --- /dev/null +++ b/test/test_tabular_output.py @@ -0,0 +1,118 @@ +"""Test the sql output adapter.""" + +from textwrap import dedent + +from mycli.packages.tabular_output import sql_format +from cli_helpers.tabular_output import TabularOutputFormatter + +from .utils import USER, PASSWORD, HOST, PORT, dbtest + +import pytest +from mycli.main import MyCli + +from pymysql.constants import FIELD_TYPE + + +@pytest.fixture +def mycli(): + cli = MyCli() + cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None) + return cli + + +@dbtest +def test_sql_output(mycli): + """Test the sql output adapter.""" + headers = ['letters', 'number', 'optional', 'float', 'binary'] + + class FakeCursor(object): + def __init__(self): + self.data = [ + ('abc', 1, None, 10.0, b'\xAA'), + ('d', 456, '1', 0.5, b'\xAA\xBB') + ] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + (None, FIELD_TYPE.BLOB) + ] + + def __iter__(self): + return self + + def __next__(self): + if self.data: + return self.data.pop(0) + else: + raise StopIteration() + + def description(self): + return self.description + + # Test sql-update output format + assert list(mycli.change_table_format("sql-update")) == \ + [(None, None, None, 'Changed table format to sql-update')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + actual = "\n".join(output) + assert actual == dedent('''\ + UPDATE `DUAL` SET + `number` = 1 + , `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc'; + UPDATE `DUAL` SET + `number` = 456 + , `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd';''') + # Test sql-update-2 output format + assert list(mycli.change_table_format("sql-update-2")) == \ + [(None, None, None, 'Changed table format to sql-update-2')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + UPDATE `DUAL` SET + `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc' AND `number` = 1; + UPDATE `DUAL` SET + `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd' AND `number` = 456;''') + # Test sql-insert output format (without table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with database + table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `database`.`table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..ab12248 --- /dev/null +++ b/test/utils.py @@ -0,0 +1,94 @@ +import os +import time +import signal +import platform +import multiprocessing + +import pymysql +import pytest + +from mycli.main import special + +PASSWORD = os.getenv('PYTEST_PASSWORD') +USER = os.getenv('PYTEST_USER', 'root') +HOST = os.getenv('PYTEST_HOST', 'localhost') +PORT = int(os.getenv('PYTEST_PORT', 3306)) +CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') +SSH_USER = os.getenv('PYTEST_SSH_USER', None) +SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) +SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) + + +def db_connection(dbname=None): + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, + password=PASSWORD, charset=CHARSET, + local_infile=False) + conn.autocommit = True + return conn + + +try: + db_connection() + CAN_CONNECT_TO_DB = True +except: + CAN_CONNECT_TO_DB = False + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a mysql instance at localhost accessible by user 'root'") + + +def create_db(dbname): + with db_connection().cursor() as cur: + try: + cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''') + cur.execute('''CREATE DATABASE mycli_test_db''') + except: + pass + + +def run(executor, sql, rows_as_list=True): + """Return string output for the sql to be run.""" + result = [] + + for title, rows, headers, status in executor.run(sql): + rows = list(rows) if (rows_as_list and rows) else rows + result.append({'title': title, 'rows': rows, 'headers': headers, + 'status': status}) + + return result + + +def set_expanded_output(is_expanded): + """Pass-through for the tests.""" + return special.set_expanded_output(is_expanded) + + +def is_expanded_output(): + """Pass-through for the tests.""" + return special.is_expanded_output() + + +def send_ctrl_c_to_pid(pid, wait_seconds): + """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds` + seconds.""" + time.sleep(wait_seconds) + system_name = platform.system() + if system_name == "Windows": + os.kill(pid, signal.CTRL_C_EVENT) + else: + os.kill(pid, signal.SIGINT) + + +def send_ctrl_c(wait_seconds): + """Create a process that sends a Ctrl-C like signal to the current process + after `wait_seconds` seconds. + + Returns the `multiprocessing.Process` created. + + """ + ctrl_c_process = multiprocessing.Process( + target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) + ) + ctrl_c_process.start() + return ctrl_c_process |