summaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/__init__.py0
-rw-r--r--test/conftest.py29
-rw-r--r--test/features/__init__.py0
-rw-r--r--test/features/auto_vertical.feature12
-rw-r--r--test/features/basic_commands.feature19
-rw-r--r--test/features/connection.feature35
-rw-r--r--test/features/crud_database.feature30
-rw-r--r--test/features/crud_table.feature49
-rw-r--r--test/features/db_utils.py93
-rw-r--r--test/features/environment.py176
-rw-r--r--test/features/fixture_data/help.txt24
-rw-r--r--test/features/fixture_data/help_commands.txt31
-rw-r--r--test/features/fixture_utils.py29
-rw-r--r--test/features/iocommands.feature47
-rw-r--r--test/features/named_queries.feature24
-rw-r--r--test/features/specials.feature7
-rw-r--r--test/features/steps/__init__.py0
-rw-r--r--test/features/steps/auto_vertical.py46
-rw-r--r--test/features/steps/basic_commands.py100
-rw-r--r--test/features/steps/connection.py71
-rw-r--r--test/features/steps/crud_database.py115
-rw-r--r--test/features/steps/crud_table.py112
-rw-r--r--test/features/steps/iocommands.py105
-rw-r--r--test/features/steps/named_queries.py90
-rw-r--r--test/features/steps/specials.py27
-rw-r--r--test/features/steps/utils.py12
-rw-r--r--test/features/steps/wrappers.py117
-rwxr-xr-xtest/features/wrappager.py16
-rw-r--r--test/myclirc12
-rw-r--r--test/mylogin.cnfbin0 -> 156 bytes
-rw-r--r--test/test.txt1
-rw-r--r--test/test_clistyle.py27
-rw-r--r--test/test_completion_engine.py555
-rw-r--r--test/test_completion_refresher.py88
-rw-r--r--test/test_config.py196
-rw-r--r--test/test_dbspecial.py42
-rw-r--r--test/test_main.py548
-rw-r--r--test/test_naive_completion.py63
-rw-r--r--test/test_parseutils.py190
-rw-r--r--test/test_plan.wiki38
-rw-r--r--test/test_prompt_utils.py11
-rw-r--r--test/test_smart_completion_public_schema_only.py385
-rw-r--r--test/test_special_iocommands.py287
-rw-r--r--test/test_sqlexecute.py295
-rw-r--r--test/test_tabular_output.py118
-rw-r--r--test/utils.py94
46 files changed, 4366 insertions, 0 deletions
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/__init__.py
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..1325596
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,29 @@
+import pytest
+from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
+ db_connection, SSH_USER, SSH_HOST, SSH_PORT)
+import mycli.sqlexecute
+
+
+@pytest.fixture(scope="function")
+def connection():
+ create_db('mycli_test_db')
+ connection = db_connection('mycli_test_db')
+ yield connection
+
+ connection.close()
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return mycli.sqlexecute.SQLExecute(
+ database='mycli_test_db', user=USER,
+ host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
+ local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
+ ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
+ )
diff --git a/test/features/__init__.py b/test/features/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/features/__init__.py
diff --git a/test/features/auto_vertical.feature b/test/features/auto_vertical.feature
new file mode 100644
index 0000000..aa95718
--- /dev/null
+++ b/test/features/auto_vertical.feature
@@ -0,0 +1,12 @@
+Feature: auto_vertical mode:
+ on, off
+
+ Scenario: auto_vertical on with small query
+ When we run dbcli with --auto-vertical-output
+ and we execute a small query
+ then we see small results in horizontal format
+
+ Scenario: auto_vertical on with large query
+ When we run dbcli with --auto-vertical-output
+ and we execute a large query
+ then we see large results in vertical format
diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature
new file mode 100644
index 0000000..a12e899
--- /dev/null
+++ b/test/features/basic_commands.feature
@@ -0,0 +1,19 @@
+Feature: run the cli,
+ call the help command,
+ exit the cli
+
+ Scenario: run "\?" command
+ When we send "\?" command
+ then we see help output
+
+ Scenario: run source command
+ When we send source command
+ then we see help output
+
+ Scenario: check our application_name
+ When we run query to check application_name
+ then we see found
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
diff --git a/test/features/connection.feature b/test/features/connection.feature
new file mode 100644
index 0000000..b06935e
--- /dev/null
+++ b/test/features/connection.feature
@@ -0,0 +1,35 @@
+Feature: connect to a database:
+
+ @requires_local_db
+ Scenario: run mycli on localhost without port
+ When we run mycli with arguments "host=localhost" without arguments "port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli on TCP host without port
+ When we run mycli without arguments "port"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ Scenario: run mycli with port but without host
+ When we run mycli without arguments "host"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ @requires_local_db
+ Scenario: run mycli without host and port
+ When we run mycli without arguments "host port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli with my.cnf configuration
+ When we create my.cnf file
+ When we run mycli without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+ Scenario: run mycli with mylogin.cnf configuration
+ When we create mylogin.cnf file
+ When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+
diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature
new file mode 100644
index 0000000..f4a7a7f
--- /dev/null
+++ b/test/features/crud_database.feature
@@ -0,0 +1,30 @@
+Feature: manipulate databases:
+ create, drop, connect, disconnect
+
+ Scenario: create and drop temporary database
+ When we create database
+ then we see database created
+ when we drop database
+ then we confirm the destructive warning
+ then we see database dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from test database
+ When we connect to test database
+ then we see database connected
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from quoted test database
+ When we connect to quoted test database
+ then we see database connected
+
+ Scenario: create and drop default database
+ When we create database
+ then we see database created
+ when we connect to tmp database
+ then we see database connected
+ when we drop database
+ then we confirm the destructive warning
+ then we see database dropped and no default database
diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature
new file mode 100644
index 0000000..3384efd
--- /dev/null
+++ b/test/features/crud_table.feature
@@ -0,0 +1,49 @@
+Feature: manipulate tables:
+ create, insert, update, select, delete from, drop
+
+ Scenario: create, insert, select from, update, drop table
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we insert into table
+ then we see record inserted
+ when we update table
+ then we see record updated
+ when we select from table
+ then we see data selected
+ when we delete from table
+ then we confirm the destructive warning
+ then we see record deleted
+ when we drop table
+ then we confirm the destructive warning
+ then we see table dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: select null values
+ When we connect to test database
+ then we see database connected
+ when we select null
+ then we see null selected
+
+ Scenario: confirm destructive query
+ When we query "create table foo(x integer);"
+ and we query "delete from foo;"
+ and we answer the destructive warning with "y"
+ then we see text "Your call!"
+
+ Scenario: decline destructive query
+ When we query "delete from foo;"
+ and we answer the destructive warning with "n"
+ then we see text "Wise choice!"
+
+ Scenario: no destructive warning if disabled in config
+ When we run dbcli with --no-warn
+ and we query "create table blabla(x integer);"
+ and we query "delete from blabla;"
+ Then we see text "Query OK"
+
+ Scenario: confirm destructive query with invalid response
+ When we query "delete from foo;"
+ then we answer the destructive warning with invalid "1" and see text "is not a valid boolean"
diff --git a/test/features/db_utils.py b/test/features/db_utils.py
new file mode 100644
index 0000000..be550e9
--- /dev/null
+++ b/test/features/db_utils.py
@@ -0,0 +1,93 @@
+import pymysql
+
+
+def create_db(hostname='localhost', port=3306, username=None,
+ password=None, dbname=None):
+ """Create test database.
+
+ :param hostname: string
+ :param port: int
+ :param username: string
+ :param password: string
+ :param dbname: string
+ :return:
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ with cn.cursor() as cr:
+ cr.execute('drop database if exists ' + dbname)
+ cr.execute('create database ' + dbname)
+
+ cn.close()
+
+ cn = create_cn(hostname, port, password, username, dbname)
+ return cn
+
+
+def create_cn(hostname, port, password, username, dbname):
+ """Open connection to database.
+
+ :param hostname:
+ :param port:
+ :param password:
+ :param username:
+ :param dbname: string
+ :return: psycopg2.connection
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ db=dbname,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ return cn
+
+
+def drop_db(hostname='localhost', port=3306, username=None,
+ password=None, dbname=None):
+ """Drop database.
+
+ :param hostname: string
+ :param port: int
+ :param username: string
+ :param password: string
+ :param dbname: string
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ db=dbname,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ with cn.cursor() as cr:
+ cr.execute('drop database if exists ' + dbname)
+
+ close_cn(cn)
+
+
+def close_cn(cn=None):
+ """Close connection.
+
+ :param connection: pymysql.connection
+
+ """
+ if cn:
+ cn.close()
diff --git a/test/features/environment.py b/test/features/environment.py
new file mode 100644
index 0000000..1ea0f08
--- /dev/null
+++ b/test/features/environment.py
@@ -0,0 +1,176 @@
+import os
+import shutil
+import sys
+from tempfile import mkstemp
+
+import db_utils as dbutils
+import fixture_utils as fixutils
+import pexpect
+
+from steps.wrappers import run_cli, wait_prompt
+
+test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
+
+
+SELF_CONNECTING_FEATURES = (
+ 'test/features/connection.feature',
+)
+
+
+MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
+MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
+MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
+MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
+
+
+def get_db_name_from_context(context):
+ return context.config.userdata.get(
+ 'my_test_db', None
+ ) or "mycli_behave_tests"
+
+
+
+def before_all(context):
+ """Set env parameters."""
+ os.environ['LINES'] = "100"
+ os.environ['COLUMNS'] = "100"
+ os.environ['EDITOR'] = 'ex'
+ os.environ['LC_ALL'] = 'en_US.UTF-8'
+ os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
+ os.environ['MYCLI_HISTFILE'] = os.devnull
+
+ test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+ login_path_file = os.path.join(test_dir, 'mylogin.cnf')
+# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+
+ context.package_root = os.path.abspath(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
+ '.coveragerc')
+
+ context.exit_sent = False
+
+ vi = '_'.join([str(x) for x in sys.version_info[:3]])
+ db_name = get_db_name_from_context(context)
+ db_name_full = '{0}_{1}'.format(db_name, vi)
+
+ # Store get params from config/environment variables
+ context.conf = {
+ 'host': context.config.userdata.get(
+ 'my_test_host',
+ os.getenv('PYTEST_HOST', 'localhost')
+ ),
+ 'port': context.config.userdata.get(
+ 'my_test_port',
+ int(os.getenv('PYTEST_PORT', '3306'))
+ ),
+ 'user': context.config.userdata.get(
+ 'my_test_user',
+ os.getenv('PYTEST_USER', 'root')
+ ),
+ 'pass': context.config.userdata.get(
+ 'my_test_pass',
+ os.getenv('PYTEST_PASSWORD', None)
+ ),
+ 'cli_command': context.config.userdata.get(
+ 'my_cli_command', None) or
+ sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
+ 'dbname': db_name,
+ 'dbname_tmp': db_name_full + '_tmp',
+ 'vi': vi,
+ 'pager_boundary': '---boundary---',
+ }
+
+ _, my_cnf = mkstemp()
+ with open(my_cnf, 'w') as f:
+ f.write(
+ '[client]\n'
+ 'pager={0} {1} {2}\n'.format(
+ sys.executable, os.path.join(context.package_root,
+ 'test/features/wrappager.py'),
+ context.conf['pager_boundary'])
+ )
+ context.conf['defaults-file'] = my_cnf
+ context.conf['myclirc'] = os.path.join(context.package_root, 'test',
+ 'myclirc')
+
+ context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
+ context.conf['user'],
+ context.conf['pass'],
+ context.conf['dbname'])
+
+ context.fixture_data = fixutils.read_fixture_files()
+
+
+def after_all(context):
+ """Unset env parameters."""
+ dbutils.close_cn(context.cn)
+ dbutils.drop_db(context.conf['host'], context.conf['port'],
+ context.conf['user'], context.conf['pass'],
+ context.conf['dbname'])
+
+ # Restore env vars.
+ #for k, v in context.pgenv.items():
+ # if k in os.environ and v is None:
+ # del os.environ[k]
+ # elif v:
+ # os.environ[k] = v
+
+
+def before_step(context, _):
+ context.atprompt = False
+
+
+def before_scenario(context, arg):
+ with open(test_log_file, 'w') as f:
+ f.write('')
+ if arg.location.filename not in SELF_CONNECTING_FEATURES:
+ run_cli(context)
+ wait_prompt(context)
+
+ if os.path.exists(MY_CNF_PATH):
+ shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_PATH):
+ shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
+
+
+def after_scenario(context, _):
+ """Cleans up after each test complete."""
+ with open(test_log_file) as f:
+ for line in f:
+ if 'error' in line.lower():
+ raise RuntimeError(f'Error in log file: {line}')
+
+ if hasattr(context, 'cli') and not context.exit_sent:
+ # Quit nicely.
+ if not context.atprompt:
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ context.cli.expect_exact(
+ '{0}@{1}:{2}>'.format(
+ user, host, dbname
+ ),
+ timeout=5
+ )
+ context.cli.sendcontrol('c')
+ context.cli.sendcontrol('d')
+ context.cli.expect_exact(pexpect.EOF, timeout=5)
+
+ if os.path.exists(MY_CNF_BACKUP_PATH):
+ shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
+ shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
+ elif os.path.exists(MYLOGIN_CNF_PATH):
+ # This file was moved in `before_scenario`.
+ # If it exists now, it has been created during a test
+ os.remove(MYLOGIN_CNF_PATH)
+
+
+# TODO: uncomment to debug a failure
+# def after_step(context, step):
+# if step.status == "failed":
+# import ipdb; ipdb.set_trace()
diff --git a/test/features/fixture_data/help.txt b/test/features/fixture_data/help.txt
new file mode 100644
index 0000000..deb499a
--- /dev/null
+++ b/test/features/fixture_data/help.txt
@@ -0,0 +1,24 @@
++--------------------------+-----------------------------------------------+
+| Command | Description |
+|--------------------------+-----------------------------------------------|
+| \# | Refresh auto-completions. |
+| \? | Show Help. |
+| \c[onnect] database_name | Change to a new database. |
+| \d [pattern] | List or describe tables, views and sequences. |
+| \dT[S+] [pattern] | List data types |
+| \df[+] [pattern] | List functions. |
+| \di[+] [pattern] | List indexes. |
+| \dn[+] [pattern] | List schemas. |
+| \ds[+] [pattern] | List sequences. |
+| \dt[+] [pattern] | List tables. |
+| \du[+] [pattern] | List roles. |
+| \dv[+] [pattern] | List views. |
+| \e [file] | Edit the query with external editor. |
+| \l | List databases. |
+| \n[+] [name] | List or execute named queries. |
+| \nd [name [query]] | Delete a named query. |
+| \ns name query | Save a named query. |
+| \refresh | Refresh auto-completions. |
+| \timing | Toggle timing of commands. |
+| \x | Toggle expanded output. |
++--------------------------+-----------------------------------------------+
diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt
new file mode 100644
index 0000000..2c06d5d
--- /dev/null
+++ b/test/features/fixture_data/help_commands.txt
@@ -0,0 +1,31 @@
++-------------+----------------------------+------------------------------------------------------------+
+| Command | Shortcut | Description |
++-------------+----------------------------+------------------------------------------------------------+
+| \G | \G | Display current query results vertically. |
+| \clip | \clip | Copy query to the system clipboard. |
+| \dt | \dt[+] [table] | List or describe tables. |
+| \e | \e | Edit command with editor (uses $EDITOR). |
+| \f | \f [name [args..]] | List or execute favorite queries. |
+| \fd | \fd [name] | Delete a favorite query. |
+| \fs | \fs name query | Save a favorite query. |
+| \l | \l | List databases. |
+| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
+| \pipe_once | \| command | Send next result to a subprocess. |
+| \timing | \t | Toggle timing of commands. |
+| connect | \r | Reconnect to the database. Optional database argument. |
+| exit | \q | Exit. |
+| help | \? | Show this help. |
+| nopager | \n | Disable pager, print to stdout. |
+| notee | notee | Stop writing results to an output file. |
+| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
+| prompt | \R | Change prompt format. |
+| quit | \q | Quit. |
+| rehash | \# | Refresh auto-completions. |
+| source | \. filename | Execute commands from file. |
+| status | \s | Get status information from the server. |
+| system | system [command] | Execute a system shell commmand. |
+| tableformat | \T | Change the table format used to output results. |
+| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
+| use | \u | Change to a new database. |
+| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
++-------------+----------------------------+------------------------------------------------------------+
diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py
new file mode 100644
index 0000000..f85e0f6
--- /dev/null
+++ b/test/features/fixture_utils.py
@@ -0,0 +1,29 @@
+import os
+import io
+
+
+def read_fixture_lines(filename):
+ """Read lines of text from file.
+
+ :param filename: string name
+ :return: list of strings
+
+ """
+ lines = []
+ for line in open(filename):
+ lines.append(line.strip())
+ return lines
+
+
+def read_fixture_files():
+ """Read all files inside fixture_data directory."""
+ fixture_dict = {}
+
+ current_dir = os.path.dirname(__file__)
+ fixture_dir = os.path.join(current_dir, 'fixture_data/')
+ for filename in os.listdir(fixture_dir):
+ if filename not in ['.', '..']:
+ fullname = os.path.join(fixture_dir, filename)
+ fixture_dict[filename] = read_fixture_lines(fullname)
+
+ return fixture_dict
diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature
new file mode 100644
index 0000000..95366eb
--- /dev/null
+++ b/test/features/iocommands.feature
@@ -0,0 +1,47 @@
+Feature: I/O commands
+
+ Scenario: edit sql in file with external editor
+ When we start external editor providing a file name
+ and we type "select * from abc" in the editor
+ and we exit the editor
+ then we see dbcli prompt
+ and we see "select * from abc" in prompt
+
+ Scenario: tee output from query
+ When we tee output
+ and we wait for prompt
+ and we select "select 123456"
+ and we wait for prompt
+ and we notee output
+ and we wait for prompt
+ then we see 123456 in tee output
+
+ Scenario: set delimiter
+ When we query "delimiter $"
+ then delimiter is set to "$"
+
+ Scenario: set delimiter twice
+ When we query "delimiter $"
+ and we query "delimiter ]]"
+ then delimiter is set to "]]"
+
+ Scenario: set delimiter and query on same line
+ When we query "select 123; delimiter $ select 456 $ delimiter %"
+ then we see result "123"
+ and we see result "456"
+ and delimiter is set to "%"
+
+ Scenario: send output to file
+ When we query "\o /tmp/output1.sql"
+ and we query "select 123"
+ and we query "system cat /tmp/output1.sql"
+ then we see result "123"
+
+ Scenario: send output to file two times
+ When we query "\o /tmp/output1.sql"
+ and we query "select 123"
+ and we query "\o /tmp/output2.sql"
+ and we query "select 456"
+ and we query "system cat /tmp/output2.sql"
+ then we see result "456"
+ \ No newline at end of file
diff --git a/test/features/named_queries.feature b/test/features/named_queries.feature
new file mode 100644
index 0000000..5e681ec
--- /dev/null
+++ b/test/features/named_queries.feature
@@ -0,0 +1,24 @@
+Feature: named queries:
+ save, use and delete named queries
+
+ Scenario: save, use and delete named queries
+ When we connect to test database
+ then we see database connected
+ when we save a named query
+ then we see the named query saved
+ when we use a named query
+ then we see the named query executed
+ when we delete a named query
+ then we see the named query deleted
+
+ Scenario: save, use and delete named queries with parameters
+ When we connect to test database
+ then we see database connected
+ when we save a named query with parameters
+ then we see the named query saved
+ when we use named query with parameters
+ then we see the named query with parameters executed
+ when we use named query with too few parameters
+ then we see the named query with parameters fail with missing parameters
+ when we use named query with too many parameters
+ then we see the named query with parameters fail with extra parameters
diff --git a/test/features/specials.feature b/test/features/specials.feature
new file mode 100644
index 0000000..bb36757
--- /dev/null
+++ b/test/features/specials.feature
@@ -0,0 +1,7 @@
+Feature: Special commands
+
+ @wip
+ Scenario: run refresh command
+ When we refresh completions
+ and we wait for prompt
+ then we see completions refresh started
diff --git a/test/features/steps/__init__.py b/test/features/steps/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/features/steps/__init__.py
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
new file mode 100644
index 0000000..e1cb26f
--- /dev/null
+++ b/test/features/steps/auto_vertical.py
@@ -0,0 +1,46 @@
+from textwrap import dedent
+
+from behave import then, when
+
+import wrappers
+from utils import parse_cli_args_to_dict
+
+
+@when('we run dbcli with {arg}')
+def step_run_cli_with_arg(context, arg):
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
+
+
+@when('we execute a small query')
+def step_execute_small_query(context):
+ context.cli.sendline('select 1')
+
+
+@when('we execute a large query')
+def step_execute_large_query(context):
+ context.cli.sendline(
+ 'select {}'.format(','.join([str(n) for n in range(1, 50)])))
+
+
+@then('we see small results in horizontal format')
+def step_see_small_results(context):
+ wrappers.expect_pager(context, dedent("""\
+ +---+\r
+ | 1 |\r
+ +---+\r
+ | 1 |\r
+ +---+\r
+ \r
+ """), timeout=5)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then('we see large results in vertical format')
+def step_see_large_results(context):
+ rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
+ expected = ('***************************[ 1. row ]'
+ '***************************\r\n' +
+ '{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
+
+ wrappers.expect_pager(context, expected, timeout=10)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py
new file mode 100644
index 0000000..425ef67
--- /dev/null
+++ b/test/features/steps/basic_commands.py
@@ -0,0 +1,100 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+from behave import when
+from textwrap import dedent
+import tempfile
+import wrappers
+
+
+@when('we run dbcli')
+def step_run_cli(context):
+ wrappers.run_cli(context)
+
+
+@when('we wait for prompt')
+def step_wait_prompt(context):
+ wrappers.wait_prompt(context)
+
+
+@when('we send "ctrl + d"')
+def step_ctrl_d(context):
+ """Send Ctrl + D to hopefully exit."""
+ context.cli.sendcontrol('d')
+ context.exit_sent = True
+
+
+@when('we send "\?" command')
+def step_send_help(context):
+ """Send \?
+
+ to see help.
+
+ """
+ context.cli.sendline('\\?')
+ wrappers.expect_exact(
+ context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+
+
+@when(u'we send source command')
+def step_send_source_command(context):
+ with tempfile.NamedTemporaryFile() as f:
+ f.write(b'\?')
+ f.flush()
+ context.cli.sendline('\. {0}'.format(f.name))
+ wrappers.expect_exact(
+ context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+
+
+@when(u'we run query to check application_name')
+def step_check_application_name(context):
+ context.cli.sendline(
+ "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
+ )
+
+
+@then(u'we see found')
+def step_see_found(context):
+ wrappers.expect_exact(
+ context,
+ context.conf['pager_boundary'] + '\r' + dedent('''
+ +-------+\r
+ | found |\r
+ +-------+\r
+ | found |\r
+ +-------+\r
+ \r
+ ''') + context.conf['pager_boundary'],
+ timeout=5
+ )
+
+
+@then(u'we confirm the destructive warning')
+def step_confirm_destructive_command(context):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline('y')
+
+
+@when(u'we answer the destructive warning with "{confirmation}"')
+def step_confirm_destructive_command(context, confirmation):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline(confirmation)
+
+
+@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
+def step_confirm_destructive_command(context, confirmation, text):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline(confirmation)
+ wrappers.expect_exact(context, text, timeout=2)
+ # we must exit the Click loop, or the feature will hang
+ context.cli.sendline('n')
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
new file mode 100644
index 0000000..e16dd86
--- /dev/null
+++ b/test/features/steps/connection.py
@@ -0,0 +1,71 @@
+import io
+import os
+import shlex
+
+from behave import when, then
+import pexpect
+
+import wrappers
+from test.features.steps.utils import parse_cli_args_to_dict
+from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
+from test.utils import HOST, PORT, USER, PASSWORD
+from mycli.config import encrypt_mylogin_cnf
+
+
+TEST_LOGIN_PATH = 'test_login_path'
+
+
+@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
+@when('we run mycli without arguments "{excluded_args}"')
+def step_run_cli_without_args(context, excluded_args, exact_args=''):
+ wrappers.run_cli(
+ context,
+ run_args=parse_cli_args_to_dict(exact_args),
+ exclude_args=parse_cli_args_to_dict(excluded_args).keys()
+ )
+
+
+@then('status contains "{expression}"')
+def status_contains(context, expression):
+ wrappers.expect_exact(context, f'{expression}', timeout=5)
+
+ # Normally, the shutdown after scenario waits for the prompt.
+ # But we may have changed the prompt, depending on parameters,
+ # so let's wait for its last character
+ context.cli.expect_exact('>')
+ context.atprompt = True
+
+
+@when('we create my.cnf file')
+def step_create_my_cnf_file(context):
+ my_cnf = (
+ '[client]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MY_CNF_PATH, 'w') as f:
+ f.write(my_cnf)
+
+
+@when('we create mylogin.cnf file')
+def step_create_mylogin_cnf_file(context):
+ os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
+ mylogin_cnf = (
+ f'[{TEST_LOGIN_PATH}]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MYLOGIN_CNF_PATH, 'wb') as f:
+ input_file = io.StringIO(mylogin_cnf)
+ f.write(encrypt_mylogin_cnf(input_file).read())
+
+
+@then('we are logged in')
+def we_are_logged_in(context):
+ db_name = get_db_name_from_context(context)
+ context.cli.expect_exact(f'{db_name}>', timeout=5)
+ context.atprompt = True
diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py
new file mode 100644
index 0000000..841f37d
--- /dev/null
+++ b/test/features/steps/crud_database.py
@@ -0,0 +1,115 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import pexpect
+
+import wrappers
+from behave import when, then
+
+
+@when('we create database')
+def step_db_create(context):
+ """Send create database."""
+ context.cli.sendline('create database {0};'.format(
+ context.conf['dbname_tmp']))
+
+ context.response = {
+ 'database_name': context.conf['dbname_tmp']
+ }
+
+
+@when('we drop database')
+def step_db_drop(context):
+ """Send drop database."""
+ context.cli.sendline('drop database {0};'.format(
+ context.conf['dbname_tmp']))
+
+
+@when('we connect to test database')
+def step_db_connect_test(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname']
+ context.currentdb = db_name
+ context.cli.sendline('use {0};'.format(db_name))
+
+
+@when('we connect to quoted test database')
+def step_db_connect_quoted_tmp(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname']
+ context.currentdb = db_name
+ context.cli.sendline('use `{0}`;'.format(db_name))
+
+
+@when('we connect to tmp database')
+def step_db_connect_tmp(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname_tmp']
+ context.currentdb = db_name
+ context.cli.sendline('use {0}'.format(db_name))
+
+
+@when('we connect to dbserver')
+def step_db_connect_dbserver(context):
+ """Send connect to database."""
+ context.currentdb = 'mysql'
+ context.cli.sendline('use mysql')
+
+
+@then('dbcli exits')
+def step_wait_exit(context):
+ """Make sure the cli exits."""
+ wrappers.expect_exact(context, pexpect.EOF, timeout=5)
+
+
+@then('we see dbcli prompt')
+def step_see_prompt(context):
+ """Wait to see the prompt."""
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
+
+
+@then('we see help output')
+def step_see_help(context):
+ for expected_line in context.fixture_data['help_commands.txt']:
+ wrappers.expect_exact(context, expected_line, timeout=1)
+
+
+@then('we see database created')
+def step_see_db_created(context):
+ """Wait to see create database output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see database dropped')
+def step_see_db_dropped(context):
+ """Wait to see drop database output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@then('we see database dropped and no default database')
+def step_see_db_dropped_no_default(context):
+ """Wait to see drop database output."""
+ user = context.conf['user']
+ host = context.conf['host']
+ database = '(none)'
+ context.currentdb = None
+
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+ wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
+
+
+@then('we see database connected')
+def step_see_db_connected(context):
+ """Wait to see drop database output."""
+ wrappers.expect_exact(
+ context, 'You are now connected to database "', timeout=2)
+ wrappers.expect_exact(context, '"', timeout=2)
+ wrappers.expect_exact(context, ' as user "{0}"'.format(
+ context.conf['user']), timeout=2)
diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py
new file mode 100644
index 0000000..f715f0c
--- /dev/null
+++ b/test/features/steps/crud_table.py
@@ -0,0 +1,112 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+from textwrap import dedent
+
+
+@when('we create table')
+def step_create_table(context):
+ """Send create table."""
+ context.cli.sendline('create table a(x text);')
+
+
+@when('we insert into table')
+def step_insert_into_table(context):
+ """Send insert into table."""
+ context.cli.sendline('''insert into a(x) values('xxx');''')
+
+
+@when('we update table')
+def step_update_table(context):
+ """Send insert into table."""
+ context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
+
+
+@when('we select from table')
+def step_select_from_table(context):
+ """Send select from table."""
+ context.cli.sendline('select * from a;')
+
+
+@when('we delete from table')
+def step_delete_from_table(context):
+ """Send deete from table."""
+ context.cli.sendline('''delete from a where x = 'yyy';''')
+
+
+@when('we drop table')
+def step_drop_table(context):
+ """Send drop table."""
+ context.cli.sendline('drop table a;')
+
+
+@then('we see table created')
+def step_see_table_created(context):
+ """Wait to see create table output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@then('we see record inserted')
+def step_see_record_inserted(context):
+ """Wait to see insert output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see record updated')
+def step_see_record_updated(context):
+ """Wait to see update output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see data selected')
+def step_see_data_selected(context):
+ """Wait to see select output."""
+ wrappers.expect_pager(
+ context, dedent("""\
+ +-----+\r
+ | x |\r
+ +-----+\r
+ | yyy |\r
+ +-----+\r
+ \r
+ """), timeout=2)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then('we see record deleted')
+def step_see_data_deleted(context):
+ """Wait to see delete output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see table dropped')
+def step_see_table_dropped(context):
+ """Wait to see drop output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@when('we select null')
+def step_select_null(context):
+ """Send select null."""
+ context.cli.sendline('select null;')
+
+
+@then('we see null selected')
+def step_see_null_selected(context):
+ """Wait to see null output."""
+ wrappers.expect_pager(
+ context, dedent("""\
+ +--------+\r
+ | NULL |\r
+ +--------+\r
+ | <null> |\r
+ +--------+\r
+ \r
+ """), timeout=2)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py
new file mode 100644
index 0000000..bbabf43
--- /dev/null
+++ b/test/features/steps/iocommands.py
@@ -0,0 +1,105 @@
+import os
+import wrappers
+
+from behave import when, then
+from textwrap import dedent
+
+
+@when('we start external editor providing a file name')
+def step_edit_file(context):
+ """Edit file with external editor."""
+ context.editor_file_name = os.path.join(
+ context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
+ if os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.cli.sendline('\e {0}'.format(
+ os.path.basename(context.editor_file_name)))
+ wrappers.expect_exact(
+ context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
+ wrappers.expect_exact(context, '\r\n:', timeout=2)
+
+
+@when('we type "{query}" in the editor')
+def step_edit_type_sql(context, query):
+ context.cli.sendline('i')
+ context.cli.sendline(query)
+ context.cli.sendline('.')
+ wrappers.expect_exact(context, '\r\n:', timeout=2)
+
+
+@when('we exit the editor')
+def step_edit_quit(context):
+ context.cli.sendline('x')
+ wrappers.expect_exact(context, "written", timeout=2)
+
+
+@then('we see "{query}" in prompt')
+def step_edit_done_sql(context, query):
+ for match in query.split(' '):
+ wrappers.expect_exact(context, match, timeout=5)
+ # Cleanup the command line.
+ context.cli.sendcontrol('c')
+ # Cleanup the edited file.
+ if context.editor_file_name and os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+
+
+@when(u'we tee output')
+def step_tee_ouptut(context):
+ context.tee_file_name = os.path.join(
+ context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.cli.sendline('tee {0}'.format(
+ os.path.basename(context.tee_file_name)))
+
+
+@when(u'we select "select {param}"')
+def step_query_select_number(context, param):
+ context.cli.sendline(u'select {}'.format(param))
+ wrappers.expect_pager(context, dedent(u"""\
+ +{dashes}+\r
+ | {param} |\r
+ +{dashes}+\r
+ | {param} |\r
+ +{dashes}+\r
+ \r
+ """.format(param=param, dashes='-' * (len(param) + 2))
+ ), timeout=5)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then(u'we see result "{result}"')
+def step_see_result(context, result):
+ wrappers.expect_exact(
+ context,
+ u"| {} |".format(result),
+ timeout=2
+ )
+
+
+@when(u'we query "{query}"')
+def step_query(context, query):
+ context.cli.sendline(query)
+
+
+@when(u'we notee output')
+def step_notee_output(context):
+ context.cli.sendline('notee')
+
+
+@then(u'we see 123456 in tee output')
+def step_see_123456_in_ouput(context):
+ with open(context.tee_file_name) as f:
+ assert '123456' in f.read()
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+
+
+@then(u'delimiter is set to "{delimiter}"')
+def delimiter_is_set(context, delimiter):
+ wrappers.expect_exact(
+ context,
+ u'Changed delimiter to {}'.format(delimiter),
+ timeout=2
+ )
diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py
new file mode 100644
index 0000000..bc1f866
--- /dev/null
+++ b/test/features/steps/named_queries.py
@@ -0,0 +1,90 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+
+
+@when('we save a named query')
+def step_save_named_query(context):
+ """Send \fs command."""
+ context.cli.sendline('\\fs foo SELECT 12345')
+
+
+@when('we use a named query')
+def step_use_named_query(context):
+ """Send \f command."""
+ context.cli.sendline('\\f foo')
+
+
+@when('we delete a named query')
+def step_delete_named_query(context):
+ """Send \fd command."""
+ context.cli.sendline('\\fd foo')
+
+
+@then('we see the named query saved')
+def step_see_named_query_saved(context):
+ """Wait to see query saved."""
+ wrappers.expect_exact(context, 'Saved.', timeout=2)
+
+
+@then('we see the named query executed')
+def step_see_named_query_executed(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
+
+
+@then('we see the named query deleted')
+def step_see_named_query_deleted(context):
+ """Wait to see query deleted."""
+ wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
+
+
+@when('we save a named query with parameters')
+def step_save_named_query_with_parameters(context):
+ """Send \fs command for query with parameters."""
+ context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
+
+
+@when('we use named query with parameters')
+def step_use_named_query_with_parameters(context):
+ """Send \f command with parameters."""
+ context.cli.sendline('\\f foo_args 101 second "third value"')
+
+
+@then('we see the named query with parameters executed')
+def step_see_named_query_with_parameters_executed(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'SELECT 101, "second", "third value"', timeout=2)
+
+
+@when('we use named query with too few parameters')
+def step_use_named_query_with_too_few_parameters(context):
+ """Send \f command with missing parameters."""
+ context.cli.sendline('\\f foo_args 101')
+
+
+@then('we see the named query with parameters fail with missing parameters')
+def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'missing substitution for $2 in query:', timeout=2)
+
+
+@when('we use named query with too many parameters')
+def step_use_named_query_with_too_many_parameters(context):
+ """Send \f command with extra parameters."""
+ context.cli.sendline('\\f foo_args 101 102 103 104')
+
+
+@then('we see the named query with parameters fail with extra parameters')
+def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'query does not have substitution parameter $4:', timeout=2)
diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py
new file mode 100644
index 0000000..e8b99e3
--- /dev/null
+++ b/test/features/steps/specials.py
@@ -0,0 +1,27 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+
+
+@when('we refresh completions')
+def step_refresh_completions(context):
+ """Send refresh command."""
+ context.cli.sendline('rehash')
+
+
+@then('we see text "{text}"')
+def step_see_text(context, text):
+ """Wait to see given text message."""
+ wrappers.expect_exact(context, text, timeout=2)
+
+@then('we see completions refresh started')
+def step_see_refresh_started(context):
+ """Wait to see refresh output."""
+ wrappers.expect_exact(
+ context, 'Auto-completion refresh started in the background.', timeout=2)
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
new file mode 100644
index 0000000..1ae63d2
--- /dev/null
+++ b/test/features/steps/utils.py
@@ -0,0 +1,12 @@
+import shlex
+
+
+def parse_cli_args_to_dict(cli_args: str):
+ args_dict = {}
+ for arg in shlex.split(cli_args):
+ if '=' in arg:
+ key, value = arg.split('=')
+ args_dict[key] = value
+ else:
+ args_dict[arg] = None
+ return args_dict
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
new file mode 100644
index 0000000..6408f23
--- /dev/null
+++ b/test/features/steps/wrappers.py
@@ -0,0 +1,117 @@
+import re
+import pexpect
+import sys
+import textwrap
+
+
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
+
+
+def expect_exact(context, expected, timeout):
+ timedout = False
+ try:
+ context.cli.expect_exact(expected, timeout=timeout)
+ except pexpect.TIMEOUT:
+ timedout = True
+ if timedout:
+ # Strip color codes out of the output.
+ actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
+ '', context.cli.before)
+ raise Exception(
+ textwrap.dedent('''\
+ Expected:
+ ---
+ {0!r}
+ ---
+ Actual:
+ ---
+ {1!r}
+ ---
+ Full log:
+ ---
+ {2!r}
+ ---
+ ''').format(
+ expected,
+ actual,
+ context.logfile.getvalue()
+ )
+ )
+
+
+def expect_pager(context, expected, timeout):
+ expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
+ context.conf['pager_boundary'], expected), timeout=timeout)
+
+
+def run_cli(context, run_args=None, exclude_args=None):
+ """Run the process using pexpect."""
+ run_args = run_args or {}
+ rendered_args = []
+ exclude_args = set(exclude_args) if exclude_args else set()
+
+ conf = dict(**context.conf)
+ conf.update(run_args)
+
+ def add_arg(name, key, value):
+ if name not in exclude_args:
+ if value is not None:
+ rendered_args.extend((key, value))
+ else:
+ rendered_args.append(key)
+
+ if conf.get('host', None):
+ add_arg('host', '-h', conf['host'])
+ if conf.get('user', None):
+ add_arg('user', '-u', conf['user'])
+ if conf.get('pass', None):
+ add_arg('pass', '-p', conf['pass'])
+ if conf.get('port', None):
+ add_arg('port', '-P', str(conf['port']))
+ if conf.get('dbname', None):
+ add_arg('dbname', '-D', conf['dbname'])
+ if conf.get('defaults-file', None):
+ add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
+ if conf.get('myclirc', None):
+ add_arg('myclirc', '--myclirc', conf['myclirc'])
+ if conf.get('login_path'):
+ add_arg('login_path', '--login-path', conf['login_path'])
+
+ for arg_name, arg_value in conf.items():
+ if arg_name.startswith('-'):
+ add_arg(arg_name, arg_name, arg_value)
+
+ try:
+ cli_cmd = context.conf['cli_command']
+ except KeyError:
+ cli_cmd = (
+ '{0!s} -c "'
+ 'import coverage ; '
+ 'coverage.process_startup(); '
+ 'import mycli.main; '
+ 'mycli.main.cli()'
+ '"'
+ ).format(sys.executable)
+
+ cmd_parts = [cli_cmd] + rendered_args
+ cmd = ' '.join(cmd_parts)
+ context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
+ context.logfile = StringIO()
+ context.cli.logfile = context.logfile
+ context.exit_sent = False
+ context.currentdb = context.conf['dbname']
+
+
+def wait_prompt(context, prompt=None):
+ """Make sure prompt is displayed."""
+ if prompt is None:
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ prompt = '{0}@{1}:{2}>'.format(
+ user, host, dbname),
+ expect_exact(context, prompt, timeout=5)
+ context.atprompt = True
diff --git a/test/features/wrappager.py b/test/features/wrappager.py
new file mode 100755
index 0000000..51d4909
--- /dev/null
+++ b/test/features/wrappager.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+import sys
+
+
+def wrappager(boundary):
+ print(boundary)
+ while 1:
+ buf = sys.stdin.read(2048)
+ if not buf:
+ break
+ sys.stdout.write(buf)
+ print(boundary)
+
+
+if __name__ == "__main__":
+ wrappager(sys.argv[1])
diff --git a/test/myclirc b/test/myclirc
new file mode 100644
index 0000000..261bee6
--- /dev/null
+++ b/test/myclirc
@@ -0,0 +1,12 @@
+# vi: ft=dosini
+
+# This file is loaded after mycli/myclirc and should override only those
+# variables needed for testing.
+# To see what every variable does see mycli/myclirc
+
+[main]
+
+log_file = ~/.mycli.test.log
+log_level = DEBUG
+prompt = '\t \u@\h:\d> '
+less_chatty = True
diff --git a/test/mylogin.cnf b/test/mylogin.cnf
new file mode 100644
index 0000000..1363cc3
--- /dev/null
+++ b/test/mylogin.cnf
Binary files differ
diff --git a/test/test.txt b/test/test.txt
new file mode 100644
index 0000000..8d8b211
--- /dev/null
+++ b/test/test.txt
@@ -0,0 +1 @@
+mycli rocks!
diff --git a/test/test_clistyle.py b/test/test_clistyle.py
new file mode 100644
index 0000000..f82cdf0
--- /dev/null
+++ b/test/test_clistyle.py
@@ -0,0 +1,27 @@
+"""Test the mycli.clistyle module."""
+import pytest
+
+from pygments.style import Style
+from pygments.token import Token
+
+from mycli.clistyle import style_factory
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory():
+ """Test that a Pygments Style class is created."""
+ header = 'bold underline #ansired'
+ cli_style = {'Token.Output.Header': header}
+ style = style_factory('default', cli_style)
+
+ assert isinstance(style(), Style)
+ assert Token.Output.Header in style.styles
+ assert header == style.styles[Token.Output.Header]
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory_unknown_name():
+ """Test that an unrecognized name will not throw an error."""
+ style = style_factory('foobar', {})
+
+ assert isinstance(style(), Style)
diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py
new file mode 100644
index 0000000..318b632
--- /dev/null
+++ b/test/test_completion_engine.py
@@ -0,0 +1,555 @@
+from mycli.packages.completion_engine import suggest_type
+import pytest
+
+
+def sorted_dicts(dicts):
+ """input is a list of dicts."""
+ return sorted(tuple(x.items()) for x in dicts)
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [('sch', 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM tabl WHERE ',
+ 'SELECT * FROM tabl WHERE (',
+ 'SELECT * FROM tabl WHERE foo = ',
+ 'SELECT * FROM tabl WHERE bar OR ',
+ 'SELECT * FROM tabl WHERE foo = 1 AND ',
+ 'SELECT * FROM tabl WHERE (bar > 10 AND ',
+ 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
+ 'SELECT * FROM tabl WHERE 10 < ',
+ 'SELECT * FROM tabl WHERE foo BETWEEN ',
+ 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
+])
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM tabl WHERE foo IN (',
+ 'SELECT * FROM tabl WHERE foo IN (bar, ',
+])
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = 'SELECT * FROM tabl WHERE foo = ANY('
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'}])
+
+
+def test_lparen_suggests_cols():
+ suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_operand_inside_function_suggests_cols1():
+ suggestion = suggest_type(
+ 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_operand_inside_function_suggests_cols2():
+ suggestion = suggest_type(
+ 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type('SELECT ', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': []},
+ {'type': 'column', 'tables': []},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM ',
+ 'INSERT INTO ',
+ 'COPY ',
+ 'UPDATE ',
+ 'DESCRIBE ',
+ 'DESC ',
+ 'EXPLAIN ',
+ 'SELECT * FROM foo JOIN ',
+])
+def test_expression_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM sch.',
+ 'INSERT INTO sch.',
+ 'COPY sch.',
+ 'UPDATE sch.',
+ 'DESCRIBE sch.',
+ 'DESC sch.',
+ 'EXPLAIN sch.',
+ 'SELECT * FROM foo JOIN sch.',
+])
+def test_expression_suggests_qualified_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 'sch'},
+ {'type': 'view', 'schema': 'sch'}])
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 'sch'}])
+
+
+def test_distinct_suggests_cols():
+ suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
+ assert suggestions == [{'type': 'column', 'tables': []}]
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tbl']},
+ {'type': 'column', 'tables': [(None, 'tbl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type('SELECT a, b FROM tbl1, ',
+ 'SELECT a, b FROM tbl1, ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_insert_into_lparen_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
+ 'SELECT * FROM tabl WHERE col_n')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'table', 'schema': 'tabl'},
+ {'type': 'view', 'schema': 'tabl'},
+ {'type': 'function', 'schema': 'tabl'}])
+
+
+def test_dot_suggests_cols_of_an_alias():
+ suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
+ 'SELECT t1.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 't1'},
+ {'type': 'view', 'schema': 't1'},
+ {'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
+ {'type': 'function', 'schema': 't1'}])
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
+ 'SELECT t1.a, t2.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
+ {'type': 'table', 'schema': 't2'},
+ {'type': 'view', 'schema': 't2'},
+ {'type': 'function', 'schema': 't2'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (',
+ 'SELECT * FROM foo WHERE EXISTS (',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
+ 'SELECT 1 AS',
+])
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{'type': 'keyword'}]
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (S',
+ 'SELECT * FROM foo WHERE EXISTS (S',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
+])
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{'type': 'keyword'}]
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
+ suggestions = suggest_type(q, q)
+ assert suggestions == [
+ {'type': 'column', 'tables': [(None, 'foo', 'f')]},
+ {'type': 'table', 'schema': 'f'},
+ {'type': 'view', 'schema': 'f'},
+ {'type': 'function', 'schema': 'f'}]
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (SELECT * FROM ',
+ 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
+])
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
+ 'SELECT * FROM (SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['abc']},
+ {'type': 'column', 'tables': [(None, 'abc', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
+ 'SELECT * FROM (SELECT a, ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'abc', None)]},
+ {'type': 'function', 'schema': []}])
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
+ 'SELECT * FROM (SELECT t.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl', 't')]},
+ {'type': 'table', 'schema': 't'},
+ {'type': 'view', 'schema': 't'},
+ {'type': 'function', 'schema': 't'}])
+
+
+@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
+@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
+ suggestion = suggest_type(text, text)
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
+])
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'abc', 'a')]},
+ {'type': 'table', 'schema': 'a'},
+ {'type': 'view', 'schema': 'a'},
+ {'type': 'function', 'schema': 'a'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
+])
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'def', 'd')]},
+ {'type': 'table', 'schema': 'd'},
+ {'type': 'view', 'schema': 'd'},
+ {'type': 'function', 'schema': 'd'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'select a.x, b.y from abc a join bcd b on ',
+ 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
+])
+def test_on_suggests_aliases(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select abc.x, bcd.y from abc join bcd on ',
+ 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
+])
+def test_on_suggests_tables(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select a.x, b.y from abc a join bcd b on a.id = ',
+ 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
+])
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select abc.x, bcd.y from abc join bcd on ',
+ 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
+])
+def test_on_suggests_tables_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+
+
+@pytest.mark.parametrize('col_list', ['', 'col1, '])
+def test_join_using_suggests_common_columns(col_list):
+ text = 'select * from abc inner join def using (' + col_list
+ assert suggest_type(text, text) == [
+ {'type': 'column',
+ 'tables': [(None, 'abc', None), (None, 'def', None)],
+ 'drop_unique': True}]
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
+])
+def test_two_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'ghi', 'g')]},
+ {'type': 'table', 'schema': 'g'},
+ {'type': 'view', 'schema': 'g'},
+ {'type': 'function', 'schema': 'g'}])
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type('select * from a; select * from ',
+ 'select * from a; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select * from a; select from b',
+ 'select * from a; select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['b']},
+ {'type': 'column', 'tables': [(None, 'b', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type('select * from; select * from ',
+ 'select * from; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type('select * from ; select * from b',
+ 'select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select from a; select * from b',
+ 'select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['a']},
+ {'type': 'column', 'tables': [(None, 'a', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type('select * from a; select * from ; select * from c',
+ 'select * from a; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select * from a; select from b; select * from c',
+ 'select * from a; select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['b']},
+ {'type': 'column', 'tables': [(None, 'b', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type('create database foo with template ',
+ 'create database foo with template ')
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
+
+
+@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert sorted_dicts(suggestions) == \
+ sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
+
+
+def test_specials_not_included_after_initial_token():
+ suggestions = suggest_type('create table foo (dt d',
+ 'create table foo (dt d')
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = 'DROP TABLE schema_name.table_name'
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
+
+
+@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_cross_join():
+ text = 'select * from v1 cross join v2 JOIN v1.id, '
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT 1 AS ',
+ 'SELECT 1 FROM tabl AS ',
+])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+@pytest.mark.parametrize('expression', [
+ '\\. ',
+ 'select 1; \\. ',
+ 'select 1;\\. ',
+ 'select 1 ; \\. ',
+ 'source ',
+ 'truncate table test; source ',
+ 'truncate table test ; source ',
+ 'truncate table test;source ',
+])
+def test_source_is_file(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{'type': 'file_name'}]
+
+
+@pytest.mark.parametrize("expression", [
+ "\\f ",
+])
+def test_favorite_name_suggestion(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{'type': 'favoritequery'}]
+
+
+def test_order_by():
+ text = 'select * from foo order by '
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
+
+
+def test_quoted_where():
+ text = "'where i=';"
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'type': 'keyword'}]
diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py
new file mode 100644
index 0000000..cdc2fb5
--- /dev/null
+++ b/test/test_completion_refresher.py
@@ -0,0 +1,88 @@
+import time
+import pytest
+from unittest.mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from mycli.completion_refresher import CompletionRefresher
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """Refresher object should contain a few handlers.
+
+ :param refresher:
+ :return:
+
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
+ 'special_commands', 'show_commands']
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ sqlexecute = Mock()
+
+ with patch.object(refresher, '_bg_refresh') as bg_refresh:
+ actual = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == 'Auto-completion refresh started in the background.'
+ bg_refresh.assert_called_with(sqlexecute, callbacks, {})
+
+
+def test_refresh_called_twice(refresher):
+ """If refresh is called a second time, it should be restarted.
+
+ :param refresher:
+ :return:
+
+ """
+ callbacks = Mock()
+
+ sqlexecute = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == 'Auto-completion refresh started in the background.'
+
+ actual2 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == 'Auto-completion refresh restarted.'
+
+
+def test_refresh_with_callbacks(refresher):
+ """Callbacks must be called.
+
+ :param refresher:
+
+ """
+ callbacks = [Mock()]
+ sqlexecute_class = Mock()
+ sqlexecute = Mock()
+
+ with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert (callbacks[0].call_count == 1)
diff --git a/test/test_config.py b/test/test_config.py
new file mode 100644
index 0000000..7f2b244
--- /dev/null
+++ b/test/test_config.py
@@ -0,0 +1,196 @@
+"""Unit tests for the mycli.config module."""
+from io import BytesIO, StringIO, TextIOWrapper
+import os
+import struct
+import sys
+import tempfile
+import pytest
+
+from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
+ read_and_decrypt_mylogin_cnf, read_config_file,
+ str_to_bool, strip_matching_quotes)
+
+LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
+ 'mylogin.cnf'))
+
+
+def open_bmylogin_cnf(name):
+ """Open contents of *name* in a BytesIO buffer."""
+ with open(name, 'rb') as f:
+ buf = BytesIO()
+ buf.write(f.read())
+ return buf
+
+def test_read_mylogin_cnf():
+ """Tests that a login path file can be read and decrypted."""
+ mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
+
+ assert isinstance(mylogin_cnf, TextIOWrapper)
+
+ contents = mylogin_cnf.read()
+ for word in ('[test]', 'user', 'password', 'host', 'port'):
+ assert word in contents
+
+
+def test_decrypt_blank_mylogin_cnf():
+ """Test that a blank login path file is handled correctly."""
+ mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO())
+ assert mylogin_cnf is None
+
+
+def test_corrupted_login_key():
+ """Test that a corrupted login path key is handled correctly."""
+ buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
+
+ # Skip past the unused bytes
+ buf.seek(4)
+
+ # Write null bytes over half the login key
+ buf.write(b'\0\0\0\0\0\0\0\0\0\0')
+
+ buf.seek(0)
+ mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
+
+ assert mylogin_cnf is None
+
+
+def test_corrupted_pad():
+ """Tests that a login path file with a corrupted pad is partially read."""
+ buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
+
+ # Skip past the login key
+ buf.seek(24)
+
+ # Skip option group
+ len_buf = buf.read(4)
+ cipher_len, = struct.unpack("<i", len_buf)
+ buf.read(cipher_len)
+
+ # Corrupt the pad for the user line
+ len_buf = buf.read(4)
+ cipher_len, = struct.unpack("<i", len_buf)
+ buf.read(cipher_len - 1)
+ buf.write(b'\0')
+
+ buf.seek(0)
+ mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
+ contents = mylogin_cnf.read()
+ for word in ('[test]', 'password', 'host', 'port'):
+ assert word in contents
+ assert 'user' not in contents
+
+
+def test_get_mylogin_cnf_path():
+ """Tests that the path for .mylogin.cnf is detected."""
+ original_env = None
+ if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
+ original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
+ is_windows = sys.platform == 'win32'
+
+ login_cnf_path = get_mylogin_cnf_path()
+
+ if original_env is not None:
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+
+ if login_cnf_path is not None:
+ assert login_cnf_path.endswith('.mylogin.cnf')
+
+ if is_windows is True:
+ assert 'MySQL' in login_cnf_path
+ else:
+ home_dir = os.path.expanduser('~')
+ assert login_cnf_path.startswith(home_dir)
+
+
+def test_alternate_get_mylogin_cnf_path():
+ """Tests that the alternate path for .mylogin.cnf is detected."""
+ original_env = None
+ if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
+ original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
+
+ _, temp_path = tempfile.mkstemp()
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
+
+ login_cnf_path = get_mylogin_cnf_path()
+
+ if original_env is not None:
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+
+ assert temp_path == login_cnf_path
+
+
+def test_str_to_bool():
+ """Tests that str_to_bool function converts values correctly."""
+
+ assert str_to_bool(False) is False
+ assert str_to_bool(True) is True
+ assert str_to_bool('False') is False
+ assert str_to_bool('True') is True
+ assert str_to_bool('TRUE') is True
+ assert str_to_bool('1') is True
+ assert str_to_bool('0') is False
+ assert str_to_bool('on') is True
+ assert str_to_bool('off') is False
+ assert str_to_bool('off') is False
+
+ with pytest.raises(ValueError):
+ str_to_bool('foo')
+
+ with pytest.raises(TypeError):
+ str_to_bool(None)
+
+
+def test_read_config_file_list_values_default():
+ """Test that reading a config file uses list_values by default."""
+
+ f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ config = read_config_file(f)
+
+ assert config['main']['weather'] == u"cloudy with a chance of meatballs"
+
+
+def test_read_config_file_list_values_off():
+ """Test that you can disable list_values when reading a config file."""
+
+ f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ config = read_config_file(f, list_values=False)
+
+ assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
+
+
+def test_strip_quotes_with_matching_quotes():
+ """Test that a string with matching quotes is unquoted."""
+
+ s = "May the force be with you."
+ assert s == strip_matching_quotes('"{}"'.format(s))
+ assert s == strip_matching_quotes("'{}'".format(s))
+
+
+def test_strip_quotes_with_unmatching_quotes():
+ """Test that a string with unmatching quotes is not unquoted."""
+
+ s = "May the force be with you."
+ assert '"' + s == strip_matching_quotes('"{}'.format(s))
+ assert s + "'" == strip_matching_quotes("{}'".format(s))
+
+
+def test_strip_quotes_with_empty_string():
+ """Test that an empty string is handled during unquoting."""
+
+ assert '' == strip_matching_quotes('')
+
+
+def test_strip_quotes_with_none():
+ """Test that None is handled during unquoting."""
+
+ assert None is strip_matching_quotes(None)
+
+
+def test_strip_quotes_with_quotes():
+ """Test that strings with quotes in them are handled during unquoting."""
+
+ s1 = 'Darth Vader said, "Luke, I am your father."'
+ assert s1 == strip_matching_quotes(s1)
+
+ s2 = '"Darth Vader said, "Luke, I am your father.""'
+ assert s2[1:-1] == strip_matching_quotes(s2)
diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py
new file mode 100644
index 0000000..21e389c
--- /dev/null
+++ b/test/test_dbspecial.py
@@ -0,0 +1,42 @@
+from mycli.packages.completion_engine import suggest_type
+from .test_completion_engine import sorted_dicts
+from mycli.packages.special.utils import format_uptime
+
+
+def test_u_suggests_databases():
+ suggestions = suggest_type('\\u ', '\\u ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'database'}])
+
+
+def test_describe_table():
+ suggestions = suggest_type('\\dt', '\\dt ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_list_or_show_create_tables():
+ suggestions = suggest_type('\\dt+', '\\dt+ ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_format_uptime():
+ seconds = 59
+ assert '59 sec' == format_uptime(seconds)
+
+ seconds = 120
+ assert '2 min 0 sec' == format_uptime(seconds)
+
+ seconds = 54890
+ assert '15 hours 14 min 50 sec' == format_uptime(seconds)
+
+ seconds = 598244
+ assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
+
+ seconds = 522600
+ assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
diff --git a/test/test_main.py b/test/test_main.py
new file mode 100644
index 0000000..64cba0a
--- /dev/null
+++ b/test/test_main.py
@@ -0,0 +1,548 @@
+import os
+import shutil
+
+import click
+from click.testing import CliRunner
+
+from mycli.main import MyCli, cli, thanks_picker
+from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from mycli.sqlexecute import ServerInfo
+from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
+
+from textwrap import dedent
+from collections import namedtuple
+
+from tempfile import NamedTemporaryFile
+from textwrap import dedent
+
+
+test_dir = os.path.abspath(os.path.dirname(__file__))
+project_dir = os.path.dirname(test_dir)
+default_config_file = os.path.join(project_dir, 'test', 'myclirc')
+login_path_file = os.path.join(test_dir, 'mylogin.cnf')
+
+os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
+ '--password', PASSWORD, '--myclirc', default_config_file,
+ '--defaults-file', default_config_file,
+ 'mycli_test_db']
+
+
+@dbtest
+def test_execute_arg(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
+
+ assert result.exit_code == 0
+ assert 'abc' in result.output
+
+ result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
+
+ assert result.exit_code == 0
+ assert 'abc' in result.output
+
+ expected = 'a\nabc\n'
+
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_table(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
+ expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_csv(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
+ expected = '"a"\n"abc"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+
+ sql = (
+ 'select count(*) from test;\n'
+ 'select * from test limit 1;'
+ )
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS, input=sql)
+
+ assert result.exit_code == 0
+ assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode_table(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+
+ sql = (
+ 'select count(*) from test;\n'
+ 'select * from test limit 1;'
+ )
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
+
+ expected = (dedent("""\
+ +----------+
+ | count(*) |
+ +----------+
+ | 3 |
+ +----------+
+ +-----+
+ | a |
+ +-----+
+ | abc |
+ +-----+"""))
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_batch_mode_csv(executor):
+ run(executor, '''create table test(a text, b text)''')
+ run(executor,
+ '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
+
+ sql = 'select * from test;'
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
+
+ expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+def test_thanks_picker_utf8():
+ name = thanks_picker()
+ assert name and isinstance(name, str)
+
+
+def test_help_strings_end_with_periods():
+ """Make sure click options have help text that end with a period."""
+ for param in cli.params:
+ if isinstance(param, click.core.Option):
+ assert hasattr(param, 'help')
+ assert param.help.endswith('.')
+
+
+def test_command_descriptions_end_with_periods():
+ """Make sure that mycli commands' descriptions end with a period."""
+ MyCli()
+ for _, command in SPECIAL_COMMANDS.items():
+ assert command[3].endswith('.')
+
+
+def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
+ global clickoutput
+ clickoutput = ""
+ m = MyCli(myclirc=default_config_file)
+
+ class TestOutput():
+ def get_size(self):
+ size = namedtuple('Size', 'rows columns')
+ size.columns, size.rows = terminal_size
+ return size
+
+ class TestExecute():
+ host = 'test'
+ user = 'test'
+ dbname = 'test'
+ server_info = ServerInfo.from_version_string('unknown')
+ port = 0
+
+ def server_type(self):
+ return ['test']
+
+ class PromptBuffer():
+ output = TestOutput()
+
+ m.prompt_app = PromptBuffer()
+ m.sqlexecute = TestExecute()
+ m.explicit_pager = explicit_pager
+
+ def echo_via_pager(s):
+ assert expect_pager
+ global clickoutput
+ clickoutput += "".join(s)
+
+ def secho(s):
+ assert not expect_pager
+ global clickoutput
+ clickoutput += s + "\n"
+
+ monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
+ monkeypatch.setattr(click, 'secho', secho)
+ m.output(testdata)
+ if clickoutput.endswith("\n"):
+ clickoutput = clickoutput[:-1]
+ assert clickoutput == "\n".join(testdata)
+
+
+def test_conditional_pager(monkeypatch):
+ testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
+ " ")
+ # User didn't set pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=True
+ )
+ # User didn't set pager, output fits screen -> no pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False
+ )
+ # User manually configured pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True
+ )
+ # User manually configured pager, output fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True
+ )
+
+ SPECIAL_COMMANDS['nopager'].handler()
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False
+ )
+ SPECIAL_COMMANDS['pager'].handler('')
+
+
+def test_reserved_space_is_integer():
+ """Make sure that reserved space is returned as an integer."""
+ def stub_terminal_size():
+ return (5, 5)
+
+ old_func = shutil.get_terminal_size
+
+ shutil.get_terminal_size = stub_terminal_size
+ mycli = MyCli()
+ assert isinstance(mycli.get_reserved_space(), int)
+
+ shutil.get_terminal_size = old_func
+
+
+def test_list_dsn():
+ runner = CliRunner()
+ with NamedTemporaryFile(mode="w") as myclirc:
+ myclirc.write(dedent("""\
+ [alias_dsn]
+ test = mysql://test/test
+ """))
+ myclirc.flush()
+ args = ['--list-dsn', '--myclirc', myclirc.name]
+ result = runner.invoke(cli, args=args)
+ assert result.output == "test\n"
+ result = runner.invoke(cli, args=args + ['--verbose'])
+ assert result.output == "test : mysql://test/test\n"
+
+
+def test_prettify_statement():
+ statement = 'SELECT 1'
+ m = MyCli()
+ pretty_statement = m.handle_prettify_binding(statement)
+ assert pretty_statement == 'SELECT\n 1;'
+
+
+def test_unprettify_statement():
+ statement = 'SELECT\n 1'
+ m = MyCli()
+ unpretty_statement = m.handle_unprettify_binding(statement)
+ assert unpretty_statement == 'SELECT 1;'
+
+
+def test_list_ssh_config():
+ runner = CliRunner()
+ with NamedTemporaryFile(mode="w") as ssh_config:
+ ssh_config.write(dedent("""\
+ Host test
+ Hostname test.example.com
+ User joe
+ Port 22222
+ IdentityFile ~/.ssh/gateway
+ """))
+ ssh_config.flush()
+ args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
+ result = runner.invoke(cli, args=args)
+ assert "test\n" in result.output
+ result = runner.invoke(cli, args=args + ['--verbose'])
+ assert "test : test.example.com\n" in result.output
+
+
+def test_dsn(monkeypatch):
+ # Setup classes to mock mycli.main.MyCli
+ class Formatter:
+ format_name = None
+
+ class Logger:
+ def debug(self, *args, **args_dict):
+ pass
+
+ def warning(self, *args, **args_dict):
+ pass
+
+ class MockMyCli:
+ config = {'alias_dsn': {}}
+
+ def __init__(self, **args):
+ self.logger = Logger()
+ self.destructive_warning = False
+ self.formatter = Formatter()
+
+ def connect(self, **args):
+ MockMyCli.connect_args = args
+
+ def run_query(self, query, new_line=True):
+ pass
+
+ import mycli.main
+ monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+ runner = CliRunner()
+
+ # When a user supplies a DSN as database argument to mycli,
+ # use these values.
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
+ )
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "dsn_user" and \
+ MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
+ MockMyCli.connect_args["host"] == "dsn_host" and \
+ MockMyCli.connect_args["port"] == 1 and \
+ MockMyCli.connect_args["database"] == "dsn_database"
+
+ MockMyCli.connect_args = None
+
+ # When a use supplies a DSN as database argument to mycli,
+ # and used command line arguments, use the command line
+ # arguments.
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
+ "--user", "arg_user",
+ "--password", "arg_password",
+ "--host", "arg_host",
+ "--port", "3",
+ "--database", "arg_database",
+ ])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "arg_user" and \
+ MockMyCli.connect_args["passwd"] == "arg_password" and \
+ MockMyCli.connect_args["host"] == "arg_host" and \
+ MockMyCli.connect_args["port"] == 3 and \
+ MockMyCli.connect_args["database"] == "arg_database"
+
+ MockMyCli.config = {
+ 'alias_dsn': {
+ 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
+ }
+ }
+ MockMyCli.connect_args = None
+
+ # When a user uses a DSN from the configuration file (alias_dsn),
+ # use these values.
+ result = runner.invoke(cli, args=['--dsn', 'test'])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "alias_dsn_user" and \
+ MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
+ MockMyCli.connect_args["host"] == "alias_dsn_host" and \
+ MockMyCli.connect_args["port"] == 4 and \
+ MockMyCli.connect_args["database"] == "alias_dsn_database"
+
+ MockMyCli.config = {
+ 'alias_dsn': {
+ 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
+ }
+ }
+ MockMyCli.connect_args = None
+
+ # When a user uses a DSN from the configuration file (alias_dsn)
+ # and used command line arguments, use the command line arguments.
+ result = runner.invoke(cli, args=[
+ '--dsn', 'test', '',
+ "--user", "arg_user",
+ "--password", "arg_password",
+ "--host", "arg_host",
+ "--port", "5",
+ "--database", "arg_database",
+ ])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "arg_user" and \
+ MockMyCli.connect_args["passwd"] == "arg_password" and \
+ MockMyCli.connect_args["host"] == "arg_host" and \
+ MockMyCli.connect_args["port"] == 5 and \
+ MockMyCli.connect_args["database"] == "arg_database"
+
+ # Use a DSN without password
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user@dsn_host:6/dsn_database"]
+ )
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "dsn_user" and \
+ MockMyCli.connect_args["passwd"] is None and \
+ MockMyCli.connect_args["host"] == "dsn_host" and \
+ MockMyCli.connect_args["port"] == 6 and \
+ MockMyCli.connect_args["database"] == "dsn_database"
+
+
+def test_ssh_config(monkeypatch):
+ # Setup classes to mock mycli.main.MyCli
+ class Formatter:
+ format_name = None
+
+ class Logger:
+ def debug(self, *args, **args_dict):
+ pass
+
+ def warning(self, *args, **args_dict):
+ pass
+
+ class MockMyCli:
+ config = {'alias_dsn': {}}
+
+ def __init__(self, **args):
+ self.logger = Logger()
+ self.destructive_warning = False
+ self.formatter = Formatter()
+
+ def connect(self, **args):
+ MockMyCli.connect_args = args
+
+ def run_query(self, query, new_line=True):
+ pass
+
+ import mycli.main
+ monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+ runner = CliRunner()
+
+ # Setup temporary configuration
+ with NamedTemporaryFile(mode="w") as ssh_config:
+ ssh_config.write(dedent("""\
+ Host test
+ Hostname test.example.com
+ User joe
+ Port 22222
+ IdentityFile ~/.ssh/gateway
+ """))
+ ssh_config.flush()
+
+ # When a user supplies a ssh config.
+ result = runner.invoke(mycli.main.cli, args=[
+ "--ssh-config-path",
+ ssh_config.name,
+ "--ssh-config-host",
+ "test"
+ ])
+ assert result.exit_code == 0, result.output + \
+ " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["ssh_user"] == "joe" and \
+ MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
+ MockMyCli.connect_args["ssh_port"] == 22222 and \
+ MockMyCli.connect_args["ssh_key_filename"] == os.getenv(
+ "HOME") + "/.ssh/gateway"
+
+ # When a user supplies a ssh config host as argument to mycli,
+ # and used command line arguments, use the command line
+ # arguments.
+ result = runner.invoke(mycli.main.cli, args=[
+ "--ssh-config-path",
+ ssh_config.name,
+ "--ssh-config-host",
+ "test",
+ "--ssh-user", "arg_user",
+ "--ssh-host", "arg_host",
+ "--ssh-port", "3",
+ "--ssh-key-filename", "/path/to/key"
+ ])
+ assert result.exit_code == 0, result.output + \
+ " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["ssh_user"] == "arg_user" and \
+ MockMyCli.connect_args["ssh_host"] == "arg_host" and \
+ MockMyCli.connect_args["ssh_port"] == 3 and \
+ MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
+
+
+@dbtest
+def test_init_command_arg(executor):
+ init_command = "set sql_select_limit=1000"
+ sql = 'show variables like "sql_select_limit";'
+ runner = CliRunner()
+ result = runner.invoke(
+ cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
+ )
+
+ expected = "sql_select_limit\t1000\n"
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_init_command_multiple_arg(executor):
+ init_command = 'set sql_select_limit=2000; set max_join_size=20000'
+ sql = (
+ 'show variables like "sql_select_limit";\n'
+ 'show variables like "max_join_size"'
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
+ )
+
+ expected_sql_select_limit = 'sql_select_limit\t2000\n'
+ expected_max_join_size = 'max_join_size\t20000\n'
+
+ assert result.exit_code == 0
+ assert expected_sql_select_limit in result.output
+ assert expected_max_join_size in result.output
diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py
new file mode 100644
index 0000000..32b2abd
--- /dev/null
+++ b/test/test_naive_completion.py
@@ -0,0 +1,63 @@
+import pytest
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+
+
+@pytest.fixture
+def completer():
+ import mycli.sqlcompleter as sqlcompleter
+ return sqlcompleter.SQLCompleter(smart_completion=False)
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ''
+ position = 0
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list(map(Completion, sorted(completer.all_completions)))
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = 'SEL'
+ position = len('SEL')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([Completion(text='SELECT', start_position=-3)])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = 'SELECT MA'
+ position = len('SELECT MA')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='MASTER', start_position=-2),
+ Completion(text='MAX', start_position=-2)])
+
+
+def test_column_name_completion(completer, complete_event):
+ text = 'SELECT FROM users'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list(map(Completion, sorted(completer.all_completions)))
+
+
+def test_special_name_completion(completer, complete_event):
+ text = '\\'
+ position = len('\\')
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ # Special commands will NOT be suggested during naive completion mode.
+ assert result == set()
diff --git a/test/test_parseutils.py b/test/test_parseutils.py
new file mode 100644
index 0000000..920a08d
--- /dev/null
+++ b/test/test_parseutils.py
@@ -0,0 +1,190 @@
+import pytest
+from mycli.packages.parseutils import (
+ extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
+ is_dropping_database)
+
+
+def test_empty_string():
+ tables = extract_tables('')
+ assert tables == []
+
+
+def test_simple_select_single_table():
+ tables = extract_tables('select * from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_select_single_table_schema_qualified():
+ tables = extract_tables('select * from abc.def')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables('select * from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables('select * from abc.def, ghi.jkl')
+ assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables('select a,b from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables('select a,b from abc.def')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables('select a,b from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_simple_select_with_cols_multiple_tables_with_schema():
+ tables = extract_tables('select a,b from abc.def, def.ghi')
+ assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables('select a, from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables('select a, from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
+ assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # assert tables == [(None, 'abc', None)]
+ assert tables == [(None, 'abc', 'abc')]
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_update_table():
+ tables = extract_tables('update abc set id = 1')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables('update abc.def set id = 1')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_join_table():
+ tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
+ assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables(
+ 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
+ assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
+
+
+def test_join_as_table():
+ tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
+ assert tables == [(None, 'my_table', 'm')]
+
+
+def test_query_starts_with():
+ query = 'USE test;'
+ assert query_starts_with(query, ('use', )) is True
+
+ query = 'DROP DATABASE test;'
+ assert query_starts_with(query, ('use', )) is False
+
+
+def test_query_starts_with_comment():
+ query = '# comment\nUSE test;'
+ assert query_starts_with(query, ('use', )) is True
+
+
+def test_queries_start_with():
+ sql = (
+ '# comment\n'
+ 'show databases;'
+ 'use foo;'
+ )
+ assert queries_start_with(sql, ('show', 'select')) is True
+ assert queries_start_with(sql, ('use', 'drop')) is True
+ assert queries_start_with(sql, ('delete', 'update')) is False
+
+
+def test_is_destructive():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'drop database foo;'
+ )
+ assert is_destructive(sql) is True
+
+
+def test_is_destructive_update_with_where_clause():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'UPDATE test SET x = 1 WHERE id = 1;'
+ )
+ assert is_destructive(sql) is False
+
+
+def test_is_destructive_update_without_where_clause():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'UPDATE test SET x = 1;'
+ )
+ assert is_destructive(sql) is True
+
+
+@pytest.mark.parametrize(
+ ('sql', 'has_where_clause'),
+ [
+ ('update test set dummy = 1;', False),
+ ('update test set dummy = 1 where id = 1);', True),
+ ],
+)
+def test_query_has_where_clause(sql, has_where_clause):
+ assert query_has_where_clause(sql) is has_where_clause
+
+
+@pytest.mark.parametrize(
+ ('sql', 'dbname', 'is_dropping'),
+ [
+ ('select bar from foo', 'foo', False),
+ ('drop database "foo";', '`foo`', True),
+ ('drop schema foo', 'foo', True),
+ ('drop schema foo', 'bar', False),
+ ('drop database bar', 'foo', False),
+ ('drop database foo', None, False),
+ ('drop database foo; create database foo', 'foo', False),
+ ('drop database foo; create database bar', 'foo', True),
+ ('select bar from foo; drop database bazz', 'foo', False),
+ ('select bar from foo; drop database bazz', 'bazz', True),
+ ('-- dropping database \n '
+ 'drop -- really dropping \n '
+ 'schema abc -- now it is dropped',
+ 'abc',
+ True)
+ ]
+)
+def test_is_dropping_database(sql, dbname, is_dropping):
+ assert is_dropping_database(sql, dbname) == is_dropping
diff --git a/test/test_plan.wiki b/test/test_plan.wiki
new file mode 100644
index 0000000..43e9083
--- /dev/null
+++ b/test/test_plan.wiki
@@ -0,0 +1,38 @@
+= Gross Checks =
+ * [ ] Check connecting to a local database.
+ * [ ] Check connecting to a remote database.
+ * [ ] Check connecting to a database with a user/password.
+ * [ ] Check connecting to a non-existent database.
+ * [ ] Test changing the database.
+
+ == PGExecute ==
+ * [ ] Test successful execution given a cursor.
+ * [ ] Test unsuccessful execution with a syntax error.
+ * [ ] Test a series of executions with the same cursor without failure.
+ * [ ] Test a series of executions with the same cursor with failure.
+ * [ ] Test passing in a special command.
+
+ == Naive Autocompletion ==
+ * [ ] Input empty string, ask for completions - Everything.
+ * [ ] Input partial prefix, ask for completions - Stars with prefix.
+ * [ ] Input fully autocompleted string, ask for completions - Only full match
+ * [ ] Input non-existent prefix, ask for completions - nothing
+ * [ ] Input lowercase prefix - case insensitive completions
+
+ == Smart Autocompletion ==
+ * [ ] Input empty string and check if only keywords are returned.
+ * [ ] Input SELECT prefix and check if only columns and '*' are returned.
+ * [ ] Input SELECT blah - only keywords are returned.
+ * [ ] Input SELECT * FROM - Table names only
+
+ == PGSpecial ==
+ * [ ] Test \d
+ * [ ] Test \d tablename
+ * [ ] Test \d tablena*
+ * [ ] Test \d non-existent-tablename
+ * [ ] Test \d index
+ * [ ] Test \d sequence
+ * [ ] Test \d view
+
+ == Exceptionals ==
+ * [ ] Test the 'use' command to change db.
diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py
new file mode 100644
index 0000000..2373fac
--- /dev/null
+++ b/test/test_prompt_utils.py
@@ -0,0 +1,11 @@
+import click
+
+from mycli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream('stdin')
+ assert stdin.isatty() is False
+
+ sql = 'drop database foo;'
+ assert confirm_destructive_query(sql) is None
diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..e7d460a
--- /dev/null
+++ b/test/test_smart_completion_public_schema_only.py
@@ -0,0 +1,385 @@
+import pytest
+from unittest.mock import patch
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+import mycli.packages.special.main as special
+
+metadata = {
+ 'users': ['id', 'email', 'first_name', 'last_name'],
+ 'orders': ['id', 'ordered_date', 'status'],
+ 'select': ['id', 'insert', 'ABC'],
+ 'réveillé': ['id', 'insert', 'ABC']
+}
+
+
+@pytest.fixture
+def completer():
+
+ import mycli.sqlcompleter as sqlcompleter
+ comp = sqlcompleter.SQLCompleter(smart_completion=True)
+
+ tables, columns = [], []
+
+ for table, cols in metadata.items():
+ tables.append((table,))
+ columns.extend([(table, col) for col in cols])
+
+ comp.set_dbname('test')
+ comp.extend_schemata('test')
+ comp.extend_relations(tables, kind='tables')
+ comp.extend_columns(columns, kind='tables')
+ comp.extend_special_commands(special.COMMANDS)
+
+ return comp
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+ return Mock()
+
+
+def test_special_name_completion(completer, complete_event):
+ text = '\\d'
+ position = len('\\d')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert result == [Completion(text='\\dt', start_position=-2)]
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ''
+ position = 0
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert list(map(Completion, sorted(completer.keywords) +
+ sorted(completer.special_commands))) == result
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = 'SEL'
+ position = len('SEL')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert list(result) == list([Completion(text='SELECT', start_position=-3)])
+
+
+def test_table_completion(completer, complete_event):
+ text = 'SELECT * FROM '
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event)
+ assert list(result) == list([
+ Completion(text='`réveillé`', start_position=0),
+ Completion(text='`select`', start_position=0),
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0),
+ ])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = 'SELECT MA'
+ position = len('SELECT MA')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event)
+ assert list(result) == list([Completion(text='MAX', start_position=-2),
+ Completion(text='MASTER', start_position=-2),
+ ])
+
+
+def test_suggested_column_names(completer, complete_event):
+ """Suggest column and function names when selecting from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT from users'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0),
+ ] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='users', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def test_suggested_column_names_in_function(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT MAX( from users'
+ position = len('SELECT MAX(')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert list(result) == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_column_names_with_table_dot(completer, complete_event):
+ """Suggest column names on table name and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT users. from users'
+ position = len('SELECT users.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT u. from users u'
+ position = len('SELECT u.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_multiple_column_names(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT id, from users u'
+ position = len('SELECT id, ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='u', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def test_suggested_multiple_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT u.id, u. from users u'
+ position = len('SELECT u.id, u.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_multiple_column_names_with_dot(completer, complete_event):
+ """Suggest column names on table names and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT users.id, users. from users u'
+ position = len('SELECT users.id, users.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_aliases_after_on(completer, complete_event):
+ text = 'SELECT u.name, o.id FROM users u JOIN orders o ON '
+ position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='o', start_position=0),
+ Completion(text='u', start_position=0)])
+
+
+def test_suggested_aliases_after_on_right_side(completer, complete_event):
+ text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = '
+ position = len(
+ 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='o', start_position=0),
+ Completion(text='u', start_position=0)])
+
+
+def test_suggested_tables_after_on(completer, complete_event):
+ text = 'SELECT users.name, orders.id FROM users JOIN orders ON '
+ position = len('SELECT users.name, orders.id FROM users JOIN orders ON ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0)])
+
+
+def test_suggested_tables_after_on_right_side(completer, complete_event):
+ text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = '
+ position = len(
+ 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0)])
+
+
+def test_table_names_after_from(completer, complete_event):
+ text = 'SELECT * FROM '
+ position = len('SELECT * FROM ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='`réveillé`', start_position=0),
+ Completion(text='`select`', start_position=0),
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0),
+ ])
+
+
+def test_auto_escaped_col_names(completer, complete_event):
+ text = 'SELECT from `select`'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == [
+ Completion(text='*', start_position=0),
+ Completion(text='`ABC`', start_position=0),
+ Completion(text='`insert`', start_position=0),
+ Completion(text='id', start_position=0),
+ ] + \
+ list(map(Completion, completer.functions)) + \
+ [Completion(text='`select`', start_position=0)] + \
+ list(map(Completion, completer.keywords))
+
+
+def test_un_escaped_table_names(completer, complete_event):
+ text = 'SELECT from réveillé'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='`ABC`', start_position=0),
+ Completion(text='`insert`', start_position=0),
+ Completion(text='id', start_position=0),
+ ] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='réveillé', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def dummy_list_path(dir_name):
+ dirs = {
+ '/': [
+ 'dir1',
+ 'file1.sql',
+ 'file2.sql',
+ ],
+ '/dir1': [
+ 'subdir1',
+ 'subfile1.sql',
+ 'subfile2.sql',
+ ],
+ '/dir1/subdir1': [
+ 'lastfile.sql',
+ ],
+ }
+ return dirs.get(dir_name, [])
+
+
+@patch('mycli.packages.filepaths.list_path', new=dummy_list_path)
+@pytest.mark.parametrize('text,expected', [
+ # ('source ', [('~', 0),
+ # ('/', 0),
+ # ('.', 0),
+ # ('..', 0)]),
+ ('source /', [('dir1', 0),
+ ('file1.sql', 0),
+ ('file2.sql', 0)]),
+ ('source /dir1/', [('subdir1', 0),
+ ('subfile1.sql', 0),
+ ('subfile2.sql', 0)]),
+ ('source /dir1/subdir1/', [('lastfile.sql', 0)]),
+])
+def test_file_name_completion(completer, complete_event, text, expected):
+ position = len(text)
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ expected = list((Completion(txt, pos) for txt, pos in expected))
+ assert result == expected
diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py
new file mode 100644
index 0000000..8b6be33
--- /dev/null
+++ b/test/test_special_iocommands.py
@@ -0,0 +1,287 @@
+import os
+import stat
+import tempfile
+from time import time
+from unittest.mock import patch
+
+import pytest
+from pymysql import ProgrammingError
+
+import mycli.packages.special
+
+from .utils import dbtest, db_connection, send_ctrl_c
+
+
+def test_set_get_pager():
+ mycli.packages.special.set_pager_enabled(True)
+ assert mycli.packages.special.is_pager_enabled()
+ mycli.packages.special.set_pager_enabled(False)
+ assert not mycli.packages.special.is_pager_enabled()
+ mycli.packages.special.set_pager('less')
+ assert os.environ['PAGER'] == "less"
+ mycli.packages.special.set_pager(False)
+ assert os.environ['PAGER'] == "less"
+ del os.environ['PAGER']
+ mycli.packages.special.set_pager(False)
+ mycli.packages.special.disable_pager()
+ assert not mycli.packages.special.is_pager_enabled()
+
+
+def test_set_get_timing():
+ mycli.packages.special.set_timing_enabled(True)
+ assert mycli.packages.special.is_timing_enabled()
+ mycli.packages.special.set_timing_enabled(False)
+ assert not mycli.packages.special.is_timing_enabled()
+
+
+def test_set_get_expanded_output():
+ mycli.packages.special.set_expanded_output(True)
+ assert mycli.packages.special.is_expanded_output()
+ mycli.packages.special.set_expanded_output(False)
+ assert not mycli.packages.special.is_expanded_output()
+
+
+def test_editor_command():
+ assert mycli.packages.special.editor_command(r'hello\e')
+ assert mycli.packages.special.editor_command(r'\ehello')
+ assert not mycli.packages.special.editor_command(r'hello')
+
+ assert mycli.packages.special.get_filename(r'\e filename') == "filename"
+
+ os.environ['EDITOR'] = 'true'
+ os.environ['VISUAL'] = 'true'
+ mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
+
+
+def test_tee_command():
+ mycli.packages.special.write_tee(u"hello world") # write without file set
+ with tempfile.NamedTemporaryFile() as f:
+ mycli.packages.special.execute(None, u"tee " + f.name)
+ mycli.packages.special.write_tee(u"hello world")
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"tee -o " + f.name)
+ mycli.packages.special.write_tee(u"hello world")
+ f.seek(0)
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"notee")
+ mycli.packages.special.write_tee(u"hello world")
+ f.seek(0)
+ assert f.read() == b"hello world\n"
+
+
+def test_tee_command_error():
+ with pytest.raises(TypeError):
+ mycli.packages.special.execute(None, 'tee')
+
+ with pytest.raises(OSError):
+ with tempfile.NamedTemporaryFile() as f:
+ os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ mycli.packages.special.execute(None, 'tee {}'.format(f.name))
+
+
+@dbtest
+def test_favorite_query():
+ with db_connection().cursor() as cur:
+ query = u'select "✔"'
+ mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
+ assert next(mycli.packages.special.execute(
+ cur, u'\\f check'))[0] == "> " + query
+
+
+def test_once_command():
+ with pytest.raises(TypeError):
+ mycli.packages.special.execute(None, u"\\once")
+
+ with pytest.raises(OSError):
+ mycli.packages.special.execute(None, u"\\once /proc/access-denied")
+
+ mycli.packages.special.write_once(u"hello world") # write without file set
+ with tempfile.NamedTemporaryFile() as f:
+ mycli.packages.special.execute(None, u"\\once " + f.name)
+ mycli.packages.special.write_once(u"hello world")
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"\\once -o " + f.name)
+ mycli.packages.special.write_once(u"hello world line 1")
+ mycli.packages.special.write_once(u"hello world line 2")
+ f.seek(0)
+ assert f.read() == b"hello world line 1\nhello world line 2\n"
+
+
+def test_pipe_once_command():
+ with pytest.raises(IOError):
+ mycli.packages.special.execute(None, u"\\pipe_once")
+
+ with pytest.raises(OSError):
+ mycli.packages.special.execute(
+ None, u"\\pipe_once /proc/access-denied")
+
+ mycli.packages.special.execute(None, u"\\pipe_once wc")
+ mycli.packages.special.write_once(u"hello world")
+ mycli.packages.special.unset_pipe_once_if_written()
+ # how to assert on wc output?
+
+
+def test_parseargfile():
+ """Test that parseargfile expands the user directory."""
+ expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
+ 'mode': 'a'}
+ assert expected == mycli.packages.special.iocommands.parseargfile(
+ '~/filename')
+
+ expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
+ 'mode': 'w'}
+ assert expected == mycli.packages.special.iocommands.parseargfile(
+ '-o ~/filename')
+
+
+def test_parseargfile_no_file():
+ """Test that parseargfile raises a TypeError if there is no filename."""
+ with pytest.raises(TypeError):
+ mycli.packages.special.iocommands.parseargfile('')
+
+ with pytest.raises(TypeError):
+ mycli.packages.special.iocommands.parseargfile('-o ')
+
+
+@dbtest
+def test_watch_query_iteration():
+ """Test that a single iteration of the result of `watch_query` executes
+ the desired query and returns the given results."""
+ expected_value = "1"
+ query = "SELECT {0!s}".format(expected_value)
+ expected_title = '> {0!s}'.format(query)
+ with db_connection().cursor() as cur:
+ result = next(mycli.packages.special.iocommands.watch_query(
+ arg=query, cur=cur
+ ))
+ assert result[0] == expected_title
+ assert result[2][0] == expected_value
+
+
+@dbtest
+def test_watch_query_full():
+ """Test that `watch_query`:
+
+ * Returns the expected results.
+ * Executes the defined times inside the given interval, in this case with
+ a 0.3 seconds wait, it should execute 4 times inside a 1 seconds
+ interval.
+ * Stops at Ctrl-C
+
+ """
+ watch_seconds = 0.3
+ wait_interval = 1
+ expected_value = "1"
+ query = "SELECT {0!s}".format(expected_value)
+ expected_title = '> {0!s}'.format(query)
+ expected_results = 4
+ ctrl_c_process = send_ctrl_c(wait_interval)
+ with db_connection().cursor() as cur:
+ results = list(
+ result for result in mycli.packages.special.iocommands.watch_query(
+ arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
+ )
+ )
+ ctrl_c_process.join(1)
+ assert len(results) == expected_results
+ for result in results:
+ assert result[0] == expected_title
+ assert result[2][0] == expected_value
+
+
+@dbtest
+@patch('click.clear')
+def test_watch_query_clear(clear_mock):
+ """Test that the screen is cleared with the -c flag of `watch` command
+ before execute the query."""
+ with db_connection().cursor() as cur:
+ watch_gen = mycli.packages.special.iocommands.watch_query(
+ arg='0.1 -c select 1;', cur=cur
+ )
+ assert not clear_mock.called
+ next(watch_gen)
+ assert clear_mock.called
+ clear_mock.reset_mock()
+ next(watch_gen)
+ assert clear_mock.called
+ clear_mock.reset_mock()
+
+
+@dbtest
+def test_watch_query_bad_arguments():
+ """Test different incorrect combinations of arguments for `watch`
+ command."""
+ watch_query = mycli.packages.special.iocommands.watch_query
+ with db_connection().cursor() as cur:
+ with pytest.raises(ProgrammingError):
+ next(watch_query('a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('-a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('1 -a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('-c -a select 1;', cur=cur))
+
+
+@dbtest
+@patch('click.clear')
+def test_watch_query_interval_clear(clear_mock):
+ """Test `watch` command with interval and clear flag."""
+ def test_asserts(gen):
+ clear_mock.reset_mock()
+ start = time()
+ next(gen)
+ assert clear_mock.called
+ next(gen)
+ exec_time = time() - start
+ assert exec_time > seconds and exec_time < (seconds + seconds)
+
+ seconds = 1.0
+ watch_query = mycli.packages.special.iocommands.watch_query
+ with db_connection().cursor() as cur:
+ test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
+ cur=cur))
+ test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
+ cur=cur))
+
+
+def test_split_sql_by_delimiter():
+ for delimiter_str in (';', '$', '😀'):
+ mycli.packages.special.set_delimiter(delimiter_str)
+ sql_input = "select 1{} select \ufffc2".format(delimiter_str)
+ queries = (
+ "select 1",
+ "select \ufffc2"
+ )
+ for query, parsed_query in zip(
+ queries, mycli.packages.special.split_queries(sql_input)):
+ assert(query == parsed_query)
+
+
+def test_switch_delimiter_within_query():
+ mycli.packages.special.set_delimiter(';')
+ sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
+ queries = (
+ "select 1",
+ "delimiter $$ select 2 $$ select 3 $$",
+ "select 2",
+ "select 3"
+ )
+ for query, parsed_query in zip(
+ queries,
+ mycli.packages.special.split_queries(sql_input)):
+ assert(query == parsed_query)
+
+
+def test_set_delimiter():
+
+ for delim in ('foo', 'bar'):
+ mycli.packages.special.set_delimiter(delim)
+ assert mycli.packages.special.get_current_delimiter() == delim
+
+
+def teardown_function():
+ mycli.packages.special.set_delimiter(';')
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
new file mode 100644
index 0000000..38ca5ef
--- /dev/null
+++ b/test/test_sqlexecute.py
@@ -0,0 +1,295 @@
+import os
+
+import pytest
+import pymysql
+
+from mycli.sqlexecute import ServerInfo, ServerSpecies
+from .utils import run, dbtest, set_expanded_output, is_expanded_output
+
+
+def assert_result_equal(result, title=None, rows=None, headers=None,
+ status=None, auto_status=True, assert_contains=False):
+ """Assert that an sqlexecute.run() result matches the expected values."""
+ if status is None and auto_status and rows:
+ status = '{} row{} in set'.format(
+ len(rows), 's' if len(rows) > 1 else '')
+ fields = {'title': title, 'rows': rows, 'headers': headers,
+ 'status': status}
+
+ if assert_contains:
+ # Do a loose match on the results using the *in* operator.
+ for key, field in fields.items():
+ if field:
+ assert field in result[0][key]
+ else:
+ # Do an exact match on the fields.
+ assert result == [fields]
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc')''')
+ results = run(executor, '''select * from test''')
+
+ assert_result_equal(results, headers=['a'], rows=[('abc',)])
+
+
+@dbtest
+def test_bools(executor):
+ run(executor, '''create table test(a boolean)''')
+ run(executor, '''insert into test values(True)''')
+ results = run(executor, '''select * from test''')
+
+ assert_result_equal(results, headers=['a'], rows=[(1,)])
+
+
+@dbtest
+def test_binary(executor):
+ run(executor, '''create table bt(geom linestring NOT NULL)''')
+ run(executor, "INSERT INTO bt VALUES "
+ "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
+ results = run(executor, '''select * from bt''')
+
+ geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
+ b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
+ b'\xac\xdeC@')
+
+ assert_result_equal(results, headers=['geom'], rows=[(geom,)])
+
+
+@dbtest
+def test_table_and_columns_query(executor):
+ run(executor, "create table a(x text, y text)")
+ run(executor, "create table b(z text)")
+
+ assert set(executor.tables()) == set([('a',), ('b',)])
+ assert set(executor.table_columns()) == set(
+ [('a', 'x'), ('a', 'y'), ('b', 'z')])
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert 'mycli_test_db' in databases
+
+
+@dbtest
+def test_invalid_syntax(executor):
+ with pytest.raises(pymysql.ProgrammingError) as excinfo:
+ run(executor, 'invalid syntax!')
+ assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+
+
+@dbtest
+def test_invalid_column_name(executor):
+ with pytest.raises(pymysql.err.OperationalError) as excinfo:
+ run(executor, 'select invalid command')
+ assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
+
+
+@dbtest
+def test_unicode_support_in_output(executor):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, u"insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ results = run(executor, u"select * from unicodechars")
+ assert_result_equal(results, headers=['t'], rows=[(u'é',)])
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ results = run(executor, "select 'foo'; select 'bar'")
+
+ expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
+ 'status': '1 row in set'},
+ {'title': None, 'headers': ['bar'], 'rows': [('bar',)],
+ 'status': '1 row in set'}]
+ assert expected == results
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor):
+ with pytest.raises(pymysql.ProgrammingError) as excinfo:
+ run(executor, "select 'foo'; invalid syntax")
+ assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+
+
+@dbtest
+def test_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor, "\\fs test-a select * from test where a like 'a%'")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-a")
+ assert_result_equal(results,
+ title="> select * from test where a like 'a%'",
+ headers=['a'], rows=[('abc',)], auto_status=False)
+
+ results = run(executor, "\\fd test-a")
+ assert_result_equal(results, status='test-a: Deleted')
+
+
+@dbtest
+def test_favorite_query_multiple_statement(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor,
+ "\\fs test-ad select * from test where a like 'a%'; "
+ "select * from test where a like 'd%'")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-ad")
+ expected = [{'title': "> select * from test where a like 'a%'",
+ 'headers': ['a'], 'rows': [('abc',)], 'status': None},
+ {'title': "> select * from test where a like 'd%'",
+ 'headers': ['a'], 'rows': [('def',)], 'status': None}]
+ assert expected == results
+
+ results = run(executor, "\\fd test-ad")
+ assert_result_equal(results, status='test-ad: Deleted')
+
+
+@dbtest
+def test_favorite_query_expanded_output(executor):
+ set_expanded_output(False)
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc')''')
+
+ results = run(executor, "\\fs test-ae select * from test")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-ae \\G")
+ assert is_expanded_output() is True
+ assert_result_equal(results, title='> select * from test',
+ headers=['a'], rows=[('abc',)], auto_status=False)
+
+ set_expanded_output(False)
+
+ results = run(executor, "\\fd test-ae")
+ assert_result_equal(results, status='test-ae: Deleted')
+
+
+@dbtest
+def test_special_command(executor):
+ results = run(executor, '\\?')
+ assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
+ headers='Command', assert_contains=True,
+ auto_status=False)
+
+
+@dbtest
+def test_cd_command_without_a_folder_name(executor):
+ results = run(executor, 'system cd')
+ assert_result_equal(results, status='No folder name was provided.')
+
+
+@dbtest
+def test_system_command_not_found(executor):
+ results = run(executor, 'system xyz')
+ assert_result_equal(results, status='OSError: No such file or directory',
+ assert_contains=True)
+
+
+@dbtest
+def test_system_command_output(executor):
+ test_dir = os.path.abspath(os.path.dirname(__file__))
+ test_file_path = os.path.join(test_dir, 'test.txt')
+ results = run(executor, 'system cat {0}'.format(test_file_path))
+ assert_result_equal(results, status='mycli rocks!\n')
+
+
+@dbtest
+def test_cd_command_current_dir(executor):
+ test_path = os.path.abspath(os.path.dirname(__file__))
+ run(executor, 'system cd {0}'.format(test_path))
+ assert os.getcwd() == test_path
+
+
+@dbtest
+def test_unicode_support(executor):
+ results = run(executor, u"SELECT '日本語' AS japanese;")
+ assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
+
+
+@dbtest
+def test_timestamp_null(executor):
+ run(executor, '''create table ts_null(a timestamp null)''')
+ run(executor, '''insert into ts_null values(null)''')
+ results = run(executor, '''select * from ts_null''')
+ assert_result_equal(results, headers=['a'],
+ rows=[(None,)])
+
+
+@dbtest
+def test_datetime_null(executor):
+ run(executor, '''create table dt_null(a datetime null)''')
+ run(executor, '''insert into dt_null values(null)''')
+ results = run(executor, '''select * from dt_null''')
+ assert_result_equal(results, headers=['a'],
+ rows=[(None,)])
+
+
+@dbtest
+def test_date_null(executor):
+ run(executor, '''create table date_null(a date null)''')
+ run(executor, '''insert into date_null values(null)''')
+ results = run(executor, '''select * from date_null''')
+ assert_result_equal(results, headers=['a'], rows=[(None,)])
+
+
+@dbtest
+def test_time_null(executor):
+ run(executor, '''create table time_null(a time null)''')
+ run(executor, '''insert into time_null values(null)''')
+ results = run(executor, '''select * from time_null''')
+ assert_result_equal(results, headers=['a'], rows=[(None,)])
+
+
+@dbtest
+def test_multiple_results(executor):
+ query = '''CREATE PROCEDURE dmtest()
+ BEGIN
+ SELECT 1;
+ SELECT 2;
+ END'''
+ executor.conn.cursor().execute(query)
+
+ results = run(executor, 'call dmtest;')
+ expected = [
+ {'title': None, 'rows': [(1,)], 'headers': ['1'],
+ 'status': '1 row in set'},
+ {'title': None, 'rows': [(2,)], 'headers': ['2'],
+ 'status': '1 row in set'}
+ ]
+ assert results == expected
+
+
+@pytest.mark.parametrize(
+ 'version_string, species, parsed_version_string, version',
+ (
+ ('5.7.25-TiDB-v6.1.0','TiDB', '5.7.25', 50725),
+ ('5.7.32-35', 'Percona', '5.7.32', 50732),
+ ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
+ ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
+ ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
+ ('unexpected version string', None, '', 0),
+ ('', None, '', 0),
+ (None, None, '', 0),
+ )
+)
+def test_version_parsing(version_string, species, parsed_version_string, version):
+ server_info = ServerInfo.from_version_string(version_string)
+ assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
+ assert server_info.version_str == parsed_version_string
+ assert server_info.version == version
diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py
new file mode 100644
index 0000000..c20c7de
--- /dev/null
+++ b/test/test_tabular_output.py
@@ -0,0 +1,118 @@
+"""Test the sql output adapter."""
+
+from textwrap import dedent
+
+from mycli.packages.tabular_output import sql_format
+from cli_helpers.tabular_output import TabularOutputFormatter
+
+from .utils import USER, PASSWORD, HOST, PORT, dbtest
+
+import pytest
+from mycli.main import MyCli
+
+from pymysql.constants import FIELD_TYPE
+
+
+@pytest.fixture
+def mycli():
+ cli = MyCli()
+ cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None)
+ return cli
+
+
+@dbtest
+def test_sql_output(mycli):
+ """Test the sql output adapter."""
+ headers = ['letters', 'number', 'optional', 'float', 'binary']
+
+ class FakeCursor(object):
+ def __init__(self):
+ self.data = [
+ ('abc', 1, None, 10.0, b'\xAA'),
+ ('d', 456, '1', 0.5, b'\xAA\xBB')
+ ]
+ self.description = [
+ (None, FIELD_TYPE.VARCHAR),
+ (None, FIELD_TYPE.LONG),
+ (None, FIELD_TYPE.LONG),
+ (None, FIELD_TYPE.FLOAT),
+ (None, FIELD_TYPE.BLOB)
+ ]
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.data:
+ return self.data.pop(0)
+ else:
+ raise StopIteration()
+
+ def description(self):
+ return self.description
+
+ # Test sql-update output format
+ assert list(mycli.change_table_format("sql-update")) == \
+ [(None, None, None, 'Changed table format to sql-update')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ actual = "\n".join(output)
+ assert actual == dedent('''\
+ UPDATE `DUAL` SET
+ `number` = 1
+ , `optional` = NULL
+ , `float` = 10.0e0
+ , `binary` = X'aa'
+ WHERE `letters` = 'abc';
+ UPDATE `DUAL` SET
+ `number` = 456
+ , `optional` = '1'
+ , `float` = 0.5e0
+ , `binary` = X'aabb'
+ WHERE `letters` = 'd';''')
+ # Test sql-update-2 output format
+ assert list(mycli.change_table_format("sql-update-2")) == \
+ [(None, None, None, 'Changed table format to sql-update-2')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ UPDATE `DUAL` SET
+ `optional` = NULL
+ , `float` = 10.0e0
+ , `binary` = X'aa'
+ WHERE `letters` = 'abc' AND `number` = 1;
+ UPDATE `DUAL` SET
+ `optional` = '1'
+ , `float` = 0.5e0
+ , `binary` = X'aabb'
+ WHERE `letters` = 'd' AND `number` = 456;''')
+ # Test sql-insert output format (without table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
+ # Test sql-insert output format (with table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = "SELECT * FROM `table`"
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
+ # Test sql-insert output format (with database + table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = "SELECT * FROM `database`.`table`"
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
diff --git a/test/utils.py b/test/utils.py
new file mode 100644
index 0000000..ab12248
--- /dev/null
+++ b/test/utils.py
@@ -0,0 +1,94 @@
+import os
+import time
+import signal
+import platform
+import multiprocessing
+
+import pymysql
+import pytest
+
+from mycli.main import special
+
+PASSWORD = os.getenv('PYTEST_PASSWORD')
+USER = os.getenv('PYTEST_USER', 'root')
+HOST = os.getenv('PYTEST_HOST', 'localhost')
+PORT = int(os.getenv('PYTEST_PORT', 3306))
+CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
+SSH_USER = os.getenv('PYTEST_SSH_USER', None)
+SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
+SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
+
+
+def db_connection(dbname=None):
+ conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
+ password=PASSWORD, charset=CHARSET,
+ local_infile=False)
+ conn.autocommit = True
+ return conn
+
+
+try:
+ db_connection()
+ CAN_CONNECT_TO_DB = True
+except:
+ CAN_CONNECT_TO_DB = False
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a mysql instance at localhost accessible by user 'root'")
+
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
+ cur.execute('''CREATE DATABASE mycli_test_db''')
+ except:
+ pass
+
+
+def run(executor, sql, rows_as_list=True):
+ """Return string output for the sql to be run."""
+ result = []
+
+ for title, rows, headers, status in executor.run(sql):
+ rows = list(rows) if (rows_as_list and rows) else rows
+ result.append({'title': title, 'rows': rows, 'headers': headers,
+ 'status': status})
+
+ return result
+
+
+def set_expanded_output(is_expanded):
+ """Pass-through for the tests."""
+ return special.set_expanded_output(is_expanded)
+
+
+def is_expanded_output():
+ """Pass-through for the tests."""
+ return special.is_expanded_output()
+
+
+def send_ctrl_c_to_pid(pid, wait_seconds):
+ """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds`
+ seconds."""
+ time.sleep(wait_seconds)
+ system_name = platform.system()
+ if system_name == "Windows":
+ os.kill(pid, signal.CTRL_C_EVENT)
+ else:
+ os.kill(pid, signal.SIGINT)
+
+
+def send_ctrl_c(wait_seconds):
+ """Create a process that sends a Ctrl-C like signal to the current process
+ after `wait_seconds` seconds.
+
+ Returns the `multiprocessing.Process` created.
+
+ """
+ ctrl_c_process = multiprocessing.Process(
+ target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
+ )
+ ctrl_c_process.start()
+ return ctrl_c_process