summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py40
-rw-r--r--tests/data/import_data.csv2
-rw-r--r--tests/liteclirc128
-rw-r--r--tests/test.txt1
-rw-r--r--tests/test_clistyle.py28
-rw-r--r--tests/test_completion_engine.py655
-rw-r--r--tests/test_completion_refresher.py94
-rw-r--r--tests/test_dbspecial.py65
-rw-r--r--tests/test_main.py261
-rw-r--r--tests/test_parseutils.py131
-rw-r--r--tests/test_prompt_utils.py14
-rw-r--r--tests/test_smart_completion_public_schema_only.py430
-rw-r--r--tests/test_sqlexecute.py392
-rw-r--r--tests/utils.py96
14 files changed, 2337 insertions, 0 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..dce0d7e
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,40 @@
+from __future__ import print_function
+
+import os
+import pytest
+from utils import create_db, db_connection, drop_tables
+import litecli.sqlexecute
+
+
+@pytest.yield_fixture(scope="function")
+def connection():
+ create_db("_test_db")
+ connection = db_connection("_test_db")
+ yield connection
+
+ drop_tables(connection)
+ connection.close()
+ os.remove("_test_db")
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return litecli.sqlexecute.SQLExecute(database="_test_db")
+
+
+@pytest.fixture
+def exception_formatter():
+ return lambda e: str(e)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def temp_config(tmpdir_factory):
+ # this function runs on start of test session.
+ # use temporary directory for config home so user config will not be used
+ os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))
diff --git a/tests/data/import_data.csv b/tests/data/import_data.csv
new file mode 100644
index 0000000..d68d655
--- /dev/null
+++ b/tests/data/import_data.csv
@@ -0,0 +1,2 @@
+t1,11
+t2,22
diff --git a/tests/liteclirc b/tests/liteclirc
new file mode 100644
index 0000000..e31942f
--- /dev/null
+++ b/tests/liteclirc
@@ -0,0 +1,128 @@
+[main]
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+# In Unix/Linux: ~/.config/litecli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.litecli-audit.log
+
+# Default pager.
+# By default '$PAGER' environment variable is used
+# pager = less -SRXF
+
+# Table format. Possible values:
+# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
+# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
+# vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# litecli prompt
+# \D - The full current date
+# \d - Database name
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \R - The current time, in 24-hour military time (0-23)
+# \r - The current time, standard 12-hour time (1-12)
+# \s - Seconds of the current time
+prompt = "\t :\d> "
+prompt_continuation = "-> "
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
+# disabled pager on startup
+enable_pager = True
+[colors]
+completion-menu.completion.current = "bg:#ffffff #000000"
+completion-menu.completion = "bg:#008888 #ffffff"
+completion-menu.meta.completion.current = "bg:#44aaaa #000000"
+completion-menu.meta.completion = "bg:#448888 #ffffff"
+completion-menu.multi-column-meta = "bg:#aaffff #000000"
+scrollbar.arrow = "bg:#003333"
+scrollbar = "bg:#00aaaa"
+selected = "#ffffff bg:#6666aa"
+search = "#ffffff bg:#4444aa"
+search.current = "#ffffff bg:#44aa44"
+bottom-toolbar = "bg:#222222 #aaaaaa"
+bottom-toolbar.off = "bg:#222222 #888888"
+bottom-toolbar.on = "bg:#222222 #ffffff"
+search-toolbar = noinherit bold
+search-toolbar.text = nobold
+system-toolbar = noinherit bold
+arg-toolbar = noinherit bold
+arg-toolbar.text = nobold
+bottom-toolbar.transaction.valid = "bg:#222222 #00ff5f bold"
+bottom-toolbar.transaction.failed = "bg:#222222 #ff005f bold"
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+Token.Menu.Completions.Completion.Current = "bg:#00aaaa #000000"
+Token.Menu.Completions.Completion = "bg:#008888 #ffffff"
+Token.Menu.Completions.MultiColumnMeta = "bg:#aaffff #000000"
+Token.Menu.Completions.ProgressButton = "bg:#003333"
+Token.Menu.Completions.ProgressBar = "bg:#00aaaa"
+Token.Output.Header = bold
+Token.Output.OddRow = ""
+Token.Output.EvenRow = ""
+Token.SelectedText = "#ffffff bg:#6666aa"
+Token.SearchMatch = "#ffffff bg:#4444aa"
+Token.SearchMatch.Current = "#ffffff bg:#44aa44"
+Token.Toolbar = "bg:#222222 #aaaaaa"
+Token.Toolbar.Off = "bg:#222222 #888888"
+Token.Toolbar.On = "bg:#222222 #ffffff"
+Token.Toolbar.Search = noinherit bold
+Token.Toolbar.Search.Text = nobold
+Token.Toolbar.System = noinherit bold
+Token.Toolbar.Arg = noinherit bold
+Token.Toolbar.Arg.Text = nobold
+[favorite_queries]
+q_param = select * from test where name=?
+sh_param = select * from test where id=$1
diff --git a/tests/test.txt b/tests/test.txt
new file mode 100644
index 0000000..1fa4cf0
--- /dev/null
+++ b/tests/test.txt
@@ -0,0 +1 @@
+litecli is awesome!
diff --git a/tests/test_clistyle.py b/tests/test_clistyle.py
new file mode 100644
index 0000000..c1177de
--- /dev/null
+++ b/tests/test_clistyle.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+"""Test the litecli.clistyle module."""
+import pytest
+
+from pygments.style import Style
+from pygments.token import Token
+
+from litecli.clistyle import style_factory
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory():
+ """Test that a Pygments Style class is created."""
+ header = "bold underline #ansired"
+ cli_style = {"Token.Output.Header": header}
+ style = style_factory("default", cli_style)
+
+ assert isinstance(style(), Style)
+ assert Token.Output.Header in style.styles
+ assert header == style.styles[Token.Output.Header]
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory_unknown_name():
+ """Test that an unrecognized name will not throw an error."""
+ style = style_factory("foobar", {})
+
+ assert isinstance(style(), Style)
diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py
new file mode 100644
index 0000000..84d5536
--- /dev/null
+++ b/tests/test_completion_engine.py
@@ -0,0 +1,655 @@
+from litecli.packages.completion_engine import suggest_type
+import pytest
+
+
+def sorted_dicts(dicts):
+ """input is a list of dicts."""
+ return sorted(tuple(x.items()) for x in dicts)
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [("sch", "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_order_by_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type(
+ "SELECT * FROM sch.tabl ORDER BY ", "SELECT * FROM sch.tabl ORDER BY "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "column", "tables": [("sch", "tabl", None)]},]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE (",
+ "SELECT * FROM tabl WHERE foo = ",
+ "SELECT * FROM tabl WHERE bar OR ",
+ "SELECT * FROM tabl WHERE foo = 1 AND ",
+ "SELECT * FROM tabl WHERE (bar > 10 AND ",
+ "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
+ "SELECT * FROM tabl WHERE 10 < ",
+ "SELECT * FROM tabl WHERE foo BETWEEN ",
+ "SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
+ ],
+)
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "],
+)
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = "SELECT * FROM tabl WHERE foo = ANY("
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_lparen_suggests_cols():
+ suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_operand_inside_function_suggests_cols1():
+ suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_operand_inside_function_suggests_cols2():
+ suggestion = suggest_type(
+ "SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + "
+ )
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type("SELECT ", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": []},
+ {"type": "column", "tables": []},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM ",
+ "INSERT INTO ",
+ "COPY ",
+ "UPDATE ",
+ "DESCRIBE ",
+ "DESC ",
+ "EXPLAIN ",
+ "SELECT * FROM foo JOIN ",
+ ],
+)
+def test_expression_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM sch.",
+ "INSERT INTO sch.",
+ "COPY sch.",
+ "UPDATE sch.",
+ "DESCRIBE sch.",
+ "DESC sch.",
+ "EXPLAIN sch.",
+ "SELECT * FROM foo JOIN sch.",
+ ],
+)
+def test_expression_suggests_qualified_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}]
+ )
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": []}, {"type": "schema"}]
+ )
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": "sch"}]
+ )
+
+
+def test_distinct_suggests_cols():
+ suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
+ assert suggestions == [{"type": "column", "tables": []}]
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tbl"]},
+ {"type": "column", "tables": [(None, "tbl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_insert_into_lparen_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type(
+ "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n"
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "table", "schema": "tabl"},
+ {"type": "view", "schema": "tabl"},
+ {"type": "function", "schema": "tabl"},
+ ]
+ )
+
+
+def test_dot_suggests_cols_of_an_alias():
+ suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": "t1"},
+ {"type": "view", "schema": "t1"},
+ {"type": "column", "tables": [(None, "tabl1", "t1")]},
+ {"type": "function", "schema": "t1"},
+ ]
+ )
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type(
+ "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl2", "t2")]},
+ {"type": "table", "schema": "t2"},
+ {"type": "view", "schema": "t2"},
+ {"type": "function", "schema": "t2"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (",
+ "SELECT * FROM foo WHERE EXISTS (",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (",
+ "SELECT 1 AS",
+ ],
+)
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{"type": "keyword"}]
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (S",
+ "SELECT * FROM foo WHERE EXISTS (S",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
+ ],
+)
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{"type": "keyword"}]
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
+ suggestions = suggest_type(q, q)
+ assert suggestions == [
+ {"type": "column", "tables": [(None, "foo", "f")]},
+ {"type": "table", "schema": "f"},
+ {"type": "view", "schema": "f"},
+ {"type": "function", "schema": "f"},
+ ]
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (SELECT * FROM ",
+ "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
+ ],
+)
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["abc"]},
+ {"type": "column", "tables": [(None, "abc", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "abc", None)]},
+ {"type": "function", "schema": []},
+ ]
+ )
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", "t")]},
+ {"type": "table", "schema": "t"},
+ {"type": "view", "schema": "t"},
+ {"type": "function", "schema": "t"},
+ ]
+ )
+
+
+@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
+@pytest.mark.parametrize("tbl_alias", ["", "foo"])
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
+ suggestion = suggest_type(text, text)
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
+ ],
+)
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "abc", "a")]},
+ {"type": "table", "schema": "a"},
+ {"type": "view", "schema": "a"},
+ {"type": "function", "schema": "a"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
+ ],
+)
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "def", "d")]},
+ {"type": "table", "schema": "d"},
+ {"type": "view", "schema": "d"},
+ {"type": "function", "schema": "d"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
+ ],
+)
+def test_on_suggests_aliases(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
+ ],
+)
+def test_on_suggests_tables(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on a.id = ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
+ ],
+)
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
+ ],
+)
+def test_on_suggests_tables_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
+
+
+@pytest.mark.parametrize("col_list", ["", "col1, "])
+def test_join_using_suggests_common_columns(col_list):
+ text = "select * from abc inner join def using (" + col_list
+ assert suggest_type(text, text) == [
+ {
+ "type": "column",
+ "tables": [(None, "abc", None), (None, "def", None)],
+ "drop_unique": True,
+ }
+ ]
+
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ", "select * from a; select * from "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type(
+ "select * from a; select from b", "select * from a; select "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type(
+ "select * from; select * from ", "select * from; select * from "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type("select * from ; select * from b", "select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type("select from a; select * from b", "select ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["a"]},
+ {"type": "column", "tables": [(None, "a", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ; select * from c",
+ "select * from a; select * from ",
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type(
+ "select * from a; select from b; select * from c", "select * from a; select "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type(
+ "create database foo with template ", "create database foo with template "
+ )
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
+
+
+@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "keyword"}, {"type": "special"}]
+ )
+
+
+def test_specials_not_included_after_initial_token():
+ suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = "DROP TABLE schema_name.table_name"
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{"type": "table", "schema": "schema_name"}]
+
+
+@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_cross_join():
+ text = "select * from v1 cross join v2 JOIN v1.id, "
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "\\. ",
+ "select 1; \\. ",
+ "select 1;\\. ",
+ "select 1 ; \\. ",
+ "source ",
+ "truncate table test; source ",
+ "truncate table test ; source ",
+ "truncate table test;source ",
+ ],
+)
+def test_source_is_file(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{"type": "file_name"}]
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
new file mode 100644
index 0000000..620a364
--- /dev/null
+++ b/tests/test_completion_refresher.py
@@ -0,0 +1,94 @@
+import time
+import pytest
+from mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from litecli.completion_refresher import CompletionRefresher
+
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """Refresher object should contain a few handlers.
+
+ :param refresher:
+ :return:
+
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = [
+ "databases",
+ "schemata",
+ "tables",
+ "functions",
+ "special_commands",
+ ]
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ sqlexecute = Mock()
+
+ with patch.object(refresher, "_bg_refresh") as bg_refresh:
+ actual = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == "Auto-completion refresh started in the background."
+ bg_refresh.assert_called_with(sqlexecute, callbacks, {})
+
+
+def test_refresh_called_twice(refresher):
+ """If refresh is called a second time, it should be restarted.
+
+ :param refresher:
+ :return:
+
+ """
+ callbacks = Mock()
+
+ sqlexecute = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == "Auto-completion refresh started in the background."
+
+ actual2 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == "Auto-completion refresh restarted."
+
+
+def test_refresh_with_callbacks(refresher):
+ """Callbacks must be called.
+
+ :param refresher:
+
+ """
+ callbacks = [Mock()]
+ sqlexecute_class = Mock()
+ sqlexecute = Mock()
+
+ with patch("litecli.completion_refresher.SQLExecute", sqlexecute_class):
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
diff --git a/tests/test_dbspecial.py b/tests/test_dbspecial.py
new file mode 100644
index 0000000..c7065a9
--- /dev/null
+++ b/tests/test_dbspecial.py
@@ -0,0 +1,65 @@
+from litecli.packages.completion_engine import suggest_type
+from test_completion_engine import sorted_dicts
+from litecli.packages.special.utils import format_uptime
+
+
+def test_import_first_argument():
+ test_cases = [
+ # text, expecting_arg_idx
+ [".import ", 1],
+ [".import ./da", 1],
+ [".import ./data.csv ", 2],
+ [".import ./data.csv t", 2],
+ [".import ./data.csv `t", 2],
+ ['.import ./data.csv "t', 2],
+ ]
+ for text, expecting_arg_idx in test_cases:
+ suggestions = suggest_type(text, text)
+ if expecting_arg_idx == 1:
+ assert suggestions == [{"type": "file_name"}]
+ else:
+ assert suggestions == [{"type": "table", "schema": []}]
+
+
+def test_u_suggests_databases():
+ suggestions = suggest_type("\\u ", "\\u ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
+
+
+def test_describe_table():
+ suggestions = suggest_type("\\dt", "\\dt ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_list_or_show_create_tables():
+ suggestions = suggest_type("\\dt+", "\\dt+ ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_format_uptime():
+ seconds = 59
+ assert "59 sec" == format_uptime(seconds)
+
+ seconds = 120
+ assert "2 min 0 sec" == format_uptime(seconds)
+
+ seconds = 54890
+ assert "15 hours 14 min 50 sec" == format_uptime(seconds)
+
+ seconds = 598244
+ assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
+
+ seconds = 522600
+ assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..90132f1
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,261 @@
+import os
+from collections import namedtuple
+from textwrap import dedent
+from tempfile import NamedTemporaryFile
+
+import click
+from click.testing import CliRunner
+
+from litecli.main import cli, LiteCli
+from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from utils import dbtest, run
+
+test_dir = os.path.abspath(os.path.dirname(__file__))
+project_dir = os.path.dirname(test_dir)
+default_config_file = os.path.join(project_dir, "tests", "liteclirc")
+
+CLI_ARGS = ["--liteclirc", default_config_file, "_test_db"]
+
+
+@dbtest
+def test_execute_arg(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
+
+ assert result.exit_code == 0
+ assert "abc" in result.output
+
+ result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
+
+ assert result.exit_code == 0
+ assert "abc" in result.output
+
+ expected = "a\nabc\n"
+
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_table(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
+ expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_csv(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
+ expected = '"a"\n"abc"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
+
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS, input=sql)
+
+ assert result.exit_code == 0
+ assert "count(*)\n3\na\nabc\n" in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode_table(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
+
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
+
+ expected = dedent(
+ """\
+ +----------+
+ | count(*) |
+ +----------+
+ | 3 |
+ +----------+
+ +-----+
+ | a |
+ +-----+
+ | abc |
+ +-----+"""
+ )
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_batch_mode_csv(executor):
+ run(executor, """create table test(a text, b text)""")
+ run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
+
+ sql = "select * from test;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
+
+ expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+def test_help_strings_end_with_periods():
+ """Make sure click options have help text that end with a period."""
+ for param in cli.params:
+ if isinstance(param, click.core.Option):
+ assert hasattr(param, "help")
+ assert param.help.endswith(".")
+
+
+def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
+ global clickoutput
+ clickoutput = ""
+ m = LiteCli(liteclirc=default_config_file)
+
+ class TestOutput:
+ def get_size(self):
+ size = namedtuple("Size", "rows columns")
+ size.columns, size.rows = terminal_size
+ return size
+
+ class TestExecute:
+ host = "test"
+ user = "test"
+ dbname = "test"
+ port = 0
+
+ def server_type(self):
+ return ["test"]
+
+ class PromptBuffer:
+ output = TestOutput()
+
+ m.prompt_app = PromptBuffer()
+ m.sqlexecute = TestExecute()
+ m.explicit_pager = explicit_pager
+
+ def echo_via_pager(s):
+ assert expect_pager
+ global clickoutput
+ clickoutput += s
+
+ def secho(s):
+ assert not expect_pager
+ global clickoutput
+ clickoutput += s + "\n"
+
+ monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
+ monkeypatch.setattr(click, "secho", secho)
+ m.output(testdata)
+ if clickoutput.endswith("\n"):
+ clickoutput = clickoutput[:-1]
+ assert clickoutput == "\n".join(testdata)
+
+
+def test_conditional_pager(monkeypatch):
+ testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
+ " "
+ )
+ # User didn't set pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=True,
+ )
+ # User didn't set pager, output fits screen -> no pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False,
+ )
+ # User manually configured pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True,
+ )
+ # User manually configured pager, output fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True,
+ )
+
+ SPECIAL_COMMANDS["nopager"].handler()
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False,
+ )
+ SPECIAL_COMMANDS["pager"].handler("")
+
+
+def test_reserved_space_is_integer():
+ """Make sure that reserved space is returned as an integer."""
+
+ def stub_terminal_size():
+ return (5, 5)
+
+ old_func = click.get_terminal_size
+
+ click.get_terminal_size = stub_terminal_size
+ lc = LiteCli()
+ assert isinstance(lc.get_reserved_space(), int)
+ click.get_terminal_size = old_func
+
+
+@dbtest
+def test_import_command(executor):
+ data_file = os.path.join(project_dir, "tests", "data", "import_data.csv")
+ run(executor, """create table tbl1(one varchar(10), two smallint)""")
+
+ # execute
+ run(executor, """.import %s tbl1""" % data_file)
+
+ # verify
+ sql = "select * from tbl1;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
+
+ expected = """one","two"
+"t1","11"
+"t2","22"
+"""
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py
new file mode 100644
index 0000000..cad7a8c
--- /dev/null
+++ b/tests/test_parseutils.py
@@ -0,0 +1,131 @@
+import pytest
+from litecli.packages.parseutils import (
+ extract_tables,
+ query_starts_with,
+ queries_start_with,
+ is_destructive,
+)
+
+
+def test_empty_string():
+ tables = extract_tables("")
+ assert tables == []
+
+
+def test_simple_select_single_table():
+ tables = extract_tables("select * from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_select_single_table_schema_qualified():
+ tables = extract_tables("select * from abc.def")
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables("select * from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables("select * from abc.def, ghi.jkl")
+ assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables("select a,b from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables("select a,b from abc.def")
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables("select a,b from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_simple_select_with_cols_multiple_tables_with_schema():
+ tables = extract_tables("select a,b from abc.def, def.ghi")
+ assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables("select a, from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables("select a, from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
+ assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # assert tables == [(None, 'abc', None)]
+ assert tables == [(None, "abc", "abc")]
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_update_table():
+ tables = extract_tables("update abc set id = 1")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables("update abc.def set id = 1")
+ assert tables == [("abc", "def", None)]
+
+
+def test_join_table():
+ tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
+ assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
+ assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
+
+
+def test_join_as_table():
+ tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
+ assert tables == [(None, "my_table", "m")]
+
+
+def test_query_starts_with():
+ query = "USE test;"
+ assert query_starts_with(query, ("use",)) is True
+
+ query = "DROP DATABASE test;"
+ assert query_starts_with(query, ("use",)) is False
+
+
+def test_query_starts_with_comment():
+ query = "# comment\nUSE test;"
+ assert query_starts_with(query, ("use",)) is True
+
+
+def test_queries_start_with():
+ sql = "# comment\n" "show databases;" "use foo;"
+ assert queries_start_with(sql, ("show", "select")) is True
+ assert queries_start_with(sql, ("use", "drop")) is True
+ assert queries_start_with(sql, ("delete", "update")) is False
+
+
+def test_is_destructive():
+ sql = "use test;\n" "show databases;\n" "drop database foo;"
+ assert is_destructive(sql) is True
diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py
new file mode 100644
index 0000000..2de74ce
--- /dev/null
+++ b/tests/test_prompt_utils.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+
+import click
+
+from litecli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream("stdin")
+ assert stdin.isatty() is False
+
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql) is None
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..ea5c580
--- /dev/null
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -0,0 +1,430 @@
+# coding: utf-8
+from __future__ import unicode_literals
+import pytest
+from mock import patch
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+
+metadata = {
+ "users": ["id", "email", "first_name", "last_name"],
+ "orders": ["id", "ordered_date", "status"],
+ "select": ["id", "insert", "ABC"],
+ "réveillé": ["id", "insert", "ABC"],
+}
+
+
+@pytest.fixture
+def completer():
+
+ import litecli.sqlcompleter as sqlcompleter
+
+ comp = sqlcompleter.SQLCompleter()
+
+ tables, columns = [], []
+
+ for table, cols in metadata.items():
+ tables.append((table,))
+ columns.extend([(table, col) for col in cols])
+
+ comp.set_dbname("test")
+ comp.extend_schemata("test")
+ comp.extend_relations(tables, kind="tables")
+ comp.extend_columns(columns, kind="tables")
+
+ return comp
+
+
+@pytest.fixture
+def complete_event():
+ from mock import Mock
+
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ""
+ position = 0
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(map(Completion, sorted(completer.keywords))) == result
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = "SEL"
+ position = len("SEL")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list([Completion(text="SELECT", start_position=-3)])
+
+
+def test_table_completion(completer, complete_event):
+ text = "SELECT * FROM "
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="`réveillé`", start_position=0),
+ Completion(text="`select`", start_position=0),
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_function_name_completion(completer, complete_event):
+ text = "SELECT MA"
+ position = len("SELECT MA")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="MAX", start_position=-2),
+ Completion(text="MATCH", start_position=-2),
+ ]
+ )
+
+
+def test_suggested_column_names(completer, complete_event):
+ """Suggest column and function names when selecting from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT from users"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="users", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def test_suggested_column_names_in_function(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT MAX( from users"
+ position = len("SELECT MAX(")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_column_names_with_table_dot(completer, complete_event):
+ """Suggest column names on table name and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT users. from users"
+ position = len("SELECT users.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT u. from users u"
+ position = len("SELECT u.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_multiple_column_names(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT id, from users u"
+ position = len("SELECT id, ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="u", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def test_suggested_multiple_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT u.id, u. from users u"
+ position = len("SELECT u.id, u.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_multiple_column_names_with_dot(completer, complete_event):
+ """Suggest column names on table names and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT users.id, users. from users u"
+ position = len("SELECT users.id, users.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_aliases_after_on(completer, complete_event):
+ text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [Completion(text="o", start_position=0), Completion(text="u", start_position=0)]
+ )
+
+
+def test_suggested_aliases_after_on_right_side(completer, complete_event):
+ text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [Completion(text="o", start_position=0), Completion(text="u", start_position=0)]
+ )
+
+
+def test_suggested_tables_after_on(completer, complete_event):
+ text = "SELECT users.name, orders.id FROM users JOIN orders ON "
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_suggested_tables_after_on_right_side(completer, complete_event):
+ text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ position = len(
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ )
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(result) == list(
+ [
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_table_names_after_from(completer, complete_event):
+ text = "SELECT * FROM "
+ position = len("SELECT * FROM ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(result) == list(
+ [
+ Completion(text="`réveillé`", start_position=0),
+ Completion(text="`select`", start_position=0),
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_auto_escaped_col_names(completer, complete_event):
+ text = "SELECT from `select`"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == [
+ Completion(text="*", start_position=0),
+ Completion(text="`ABC`", start_position=0),
+ Completion(text="`insert`", start_position=0),
+ Completion(text="id", start_position=0),
+ ] + list(map(Completion, completer.functions)) + [
+ Completion(text="`select`", start_position=0)
+ ] + list(
+ map(Completion, sorted(completer.keywords))
+ )
+
+
+def test_un_escaped_table_names(completer, complete_event):
+ text = "SELECT from réveillé"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="`ABC`", start_position=0),
+ Completion(text="`insert`", start_position=0),
+ Completion(text="id", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="réveillé", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def dummy_list_path(dir_name):
+ dirs = {
+ "/": ["dir1", "file1.sql", "file2.sql"],
+ "/dir1": ["subdir1", "subfile1.sql", "subfile2.sql"],
+ "/dir1/subdir1": ["lastfile.sql"],
+ }
+ return dirs.get(dir_name, [])
+
+
+@patch("litecli.packages.filepaths.list_path", new=dummy_list_path)
+@pytest.mark.parametrize(
+ "text,expected",
+ [
+ ("source ", [(".", 0), ("..", 0), ("/", 0), ("~", 0)]),
+ ("source /", [("dir1", 0), ("file1.sql", 0), ("file2.sql", 0)]),
+ ("source /dir1/", [("subdir1", 0), ("subfile1.sql", 0), ("subfile2.sql", 0)]),
+ ("source /dir1/subdir1/", [("lastfile.sql", 0)]),
+ ],
+)
+def test_file_name_completion(completer, complete_event, text, expected):
+ position = len(text)
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ expected = list([Completion(txt, pos) for txt, pos in expected])
+ assert result == expected
diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py
new file mode 100644
index 0000000..2dde4d4
--- /dev/null
+++ b/tests/test_sqlexecute.py
@@ -0,0 +1,392 @@
+# coding=UTF-8
+
+import os
+
+import pytest
+
+from utils import run, dbtest, set_expanded_output, is_expanded_output
+from sqlite3 import OperationalError, ProgrammingError
+
+
+def assert_result_equal(
+ result,
+ title=None,
+ rows=None,
+ headers=None,
+ status=None,
+ auto_status=True,
+ assert_contains=False,
+):
+ """Assert that an sqlexecute.run() result matches the expected values."""
+ if status is None and auto_status and rows:
+ status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
+ fields = {"title": title, "rows": rows, "headers": headers, "status": status}
+
+ if assert_contains:
+ # Do a loose match on the results using the *in* operator.
+ for key, field in fields.items():
+ if field:
+ assert field in result[0][key]
+ else:
+ # Do an exact match on the fields.
+ assert result == [fields]
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+ results = run(executor, """select * from test""")
+
+ assert_result_equal(results, headers=["a"], rows=[("abc",)])
+
+
+@dbtest
+def test_bools(executor):
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(1)""")
+ results = run(executor, """select * from test""")
+
+ assert_result_equal(results, headers=["a"], rows=[(1,)])
+
+
+@dbtest
+def test_binary(executor):
+ run(executor, """create table foo(blb BLOB NOT NULL)""")
+ run(executor, """INSERT INTO foo VALUES ('\x01\x01\x01\n')""")
+ results = run(executor, """select * from foo""")
+
+ expected = "\x01\x01\x01\n"
+
+ assert_result_equal(results, headers=["blb"], rows=[(expected,)])
+
+
+## Failing in Travis for some unknown reason.
+# @dbtest
+# def test_table_and_columns_query(executor):
+# run(executor, "create table a(x text, y text)")
+# run(executor, "create table b(z text)")
+
+# assert set(executor.tables()) == set([("a",), ("b",)])
+# assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert "main" in list(databases)
+
+
+@dbtest
+def test_invalid_syntax(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "invalid syntax!")
+ assert "syntax error" in str(excinfo.value)
+
+
+@dbtest
+def test_invalid_column_name(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "select invalid command")
+ assert "no such column: invalid" in str(excinfo.value)
+
+
+@dbtest
+def test_unicode_support_in_output(executor):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, u"insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ results = run(executor, u"select * from unicodechars")
+ assert_result_equal(results, headers=["t"], rows=[(u"é",)])
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ results = run(executor, "select 'foo'; select 'bar'")
+
+ expected = [
+ {
+ "title": None,
+ "headers": ["'foo'"],
+ "rows": [(u"foo",)],
+ "status": "1 row in set",
+ },
+ {
+ "title": None,
+ "headers": ["'bar'"],
+ "rows": [(u"bar",)],
+ "status": "1 row in set",
+ },
+ ]
+ assert expected == results
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "select 'foo'; invalid syntax")
+ assert "syntax error" in str(excinfo.value)
+
+
+@dbtest
+def test_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor, "\\fs test-a select * from test where a like 'a%'")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-a")
+ assert_result_equal(
+ results,
+ title="> select * from test where a like 'a%'",
+ headers=["a"],
+ rows=[("abc",)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\fd test-a")
+ assert_result_equal(results, status="test-a: Deleted")
+
+
+@dbtest
+def test_bind_parameterized_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(name text, id integer)")
+ run(executor, "insert into test values('def', 2)")
+ run(executor, "insert into test values('two words', 3)")
+
+ results = run(executor, "\\fs q_param select * from test where name=?")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ q_param def")
+ assert_result_equal(
+ results,
+ title="> select * from test where name=?",
+ headers=["name", "id"],
+ rows=[("def", 2)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ q_param 'two words'")
+ assert_result_equal(
+ results,
+ title="> select * from test where name=?",
+ headers=["name", "id"],
+ rows=[("two words", 3)],
+ auto_status=False,
+ )
+
+ with pytest.raises(ProgrammingError):
+ results = run(executor, "\\f+ q_param")
+
+ with pytest.raises(ProgrammingError):
+ results = run(executor, "\\f+ q_param 1 2")
+
+@dbtest
+def test_verbose_feature_of_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text, id integer)")
+ run(executor, "insert into test values('abc', 1)")
+ run(executor, "insert into test values('def', 2)")
+
+ results = run(executor, "\\fs sh_param select * from test where id=$1")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f sh_param 1")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ sh_param 1")
+ assert_result_equal(
+ results,
+ title="> select * from test where id=1",
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+@dbtest
+def test_shell_parameterized_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text, id integer)")
+ run(executor, "insert into test values('abc', 1)")
+ run(executor, "insert into test values('def', 2)")
+
+ results = run(executor, "\\fs sh_param select * from test where id=$1")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ sh_param 1")
+ assert_result_equal(
+ results,
+ title="> select * from test where id=1",
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ sh_param")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=None,
+ rows=None,
+ status="missing substitution for $1 in query:\n select * from test where id=$1",
+ )
+
+ results = run(executor, "\\f+ sh_param 1 2")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=None,
+ rows=None,
+ status="Too many arguments.\nQuery does not have enough place holders to substitute.\nselect * from test where id=1",
+ )
+
+
+@dbtest
+def test_favorite_query_multiple_statement(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(
+ executor,
+ "\\fs test-ad select * from test where a like 'a%'; "
+ "select * from test where a like 'd%'",
+ )
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-ad")
+ expected = [
+ {
+ "title": "> select * from test where a like 'a%'",
+ "headers": ["a"],
+ "rows": [("abc",)],
+ "status": None,
+ },
+ {
+ "title": "> select * from test where a like 'd%'",
+ "headers": ["a"],
+ "rows": [("def",)],
+ "status": None,
+ },
+ ]
+ assert expected == results
+
+ results = run(executor, "\\fd test-ad")
+ assert_result_equal(results, status="test-ad: Deleted")
+
+
+@dbtest
+def test_favorite_query_expanded_output(executor):
+ set_expanded_output(False)
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+
+ results = run(executor, "\\fs test-ae select * from test")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-ae \G")
+ assert is_expanded_output() is True
+ assert_result_equal(
+ results,
+ title="> select * from test",
+ headers=["a"],
+ rows=[("abc",)],
+ auto_status=False,
+ )
+
+ set_expanded_output(False)
+
+ results = run(executor, "\\fd test-ae")
+ assert_result_equal(results, status="test-ae: Deleted")
+
+
+@dbtest
+def test_special_command(executor):
+ results = run(executor, "\\?")
+ assert_result_equal(
+ results,
+ rows=("quit", "\\q", "Quit."),
+ headers="Command",
+ assert_contains=True,
+ auto_status=False,
+ )
+
+
+@dbtest
+def test_cd_command_without_a_folder_name(executor):
+ results = run(executor, "system cd")
+ assert_result_equal(results, status="No folder name was provided.")
+
+
+@dbtest
+def test_system_command_not_found(executor):
+ results = run(executor, "system xyz")
+ assert_result_equal(
+ results, status="OSError: No such file or directory", assert_contains=True
+ )
+
+
+@dbtest
+def test_system_command_output(executor):
+ test_dir = os.path.abspath(os.path.dirname(__file__))
+ test_file_path = os.path.join(test_dir, "test.txt")
+ results = run(executor, "system cat {0}".format(test_file_path))
+ assert_result_equal(results, status="litecli is awesome!\n")
+
+
+@dbtest
+def test_cd_command_current_dir(executor):
+ test_path = os.path.abspath(os.path.dirname(__file__))
+ run(executor, "system cd {0}".format(test_path))
+ assert os.getcwd() == test_path
+ run(executor, "system cd ..")
+
+
+@dbtest
+def test_unicode_support(executor):
+ results = run(executor, u"SELECT '日本語' AS japanese;")
+ assert_result_equal(results, headers=["japanese"], rows=[(u"日本語",)])
+
+
+@dbtest
+def test_timestamp_null(executor):
+ run(executor, """create table ts_null(a timestamp null)""")
+ run(executor, """insert into ts_null values(null)""")
+ results = run(executor, """select * from ts_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_datetime_null(executor):
+ run(executor, """create table dt_null(a datetime null)""")
+ run(executor, """insert into dt_null values(null)""")
+ results = run(executor, """select * from dt_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_date_null(executor):
+ run(executor, """create table date_null(a date null)""")
+ run(executor, """insert into date_null values(null)""")
+ results = run(executor, """select * from date_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_time_null(executor):
+ run(executor, """create table time_null(a time null)""")
+ run(executor, """insert into time_null values(null)""")
+ results = run(executor, """select * from time_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..41bac9b
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+import os
+import time
+import signal
+import platform
+import multiprocessing
+from contextlib import closing
+
+import sqlite3
+import pytest
+
+from litecli.main import special
+
+DATABASE = os.getenv("PYTEST_DATABASE", "test.sqlite3")
+
+
+def db_connection(dbname=":memory:"):
+ conn = sqlite3.connect(database=dbname, isolation_level=None)
+ return conn
+
+
+try:
+ db_connection()
+ CAN_CONNECT_TO_DB = True
+except Exception as ex:
+ CAN_CONNECT_TO_DB = False
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB, reason="Error creating sqlite connection"
+)
+
+
+def create_db(dbname):
+ with closing(db_connection().cursor()) as cur:
+ try:
+ cur.execute("""DROP DATABASE IF EXISTS _test_db""")
+ cur.execute("""CREATE DATABASE _test_db""")
+ except:
+ pass
+
+
+def drop_tables(dbname):
+ with closing(db_connection().cursor()) as cur:
+ try:
+ cur.execute("""DROP DATABASE IF EXISTS _test_db""")
+ except:
+ pass
+
+
+def run(executor, sql, rows_as_list=True):
+ """Return string output for the sql to be run."""
+ result = []
+
+ for title, rows, headers, status in executor.run(sql):
+ rows = list(rows) if (rows_as_list and rows) else rows
+ result.append(
+ {"title": title, "rows": rows, "headers": headers, "status": status}
+ )
+
+ return result
+
+
+def set_expanded_output(is_expanded):
+ """Pass-through for the tests."""
+ return special.set_expanded_output(is_expanded)
+
+
+def is_expanded_output():
+ """Pass-through for the tests."""
+ return special.is_expanded_output()
+
+
+def send_ctrl_c_to_pid(pid, wait_seconds):
+ """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds`
+ seconds."""
+ time.sleep(wait_seconds)
+ system_name = platform.system()
+ if system_name == "Windows":
+ os.kill(pid, signal.CTRL_C_EVENT)
+ else:
+ os.kill(pid, signal.SIGINT)
+
+
+def send_ctrl_c(wait_seconds):
+ """Create a process that sends a Ctrl-C like signal to the current process
+ after `wait_seconds` seconds.
+
+ Returns the `multiprocessing.Process` created.
+
+ """
+ ctrl_c_process = multiprocessing.Process(
+ target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
+ )
+ ctrl_c_process.start()
+ return ctrl_c_process