summaryrefslogtreecommitdiffstats
path: root/litecli/packages
diff options
context:
space:
mode:
Diffstat (limited to 'litecli/packages')
-rw-r--r--litecli/packages/__init__.py0
-rw-r--r--litecli/packages/completion_engine.py331
-rw-r--r--litecli/packages/filepaths.py88
-rw-r--r--litecli/packages/parseutils.py227
-rw-r--r--litecli/packages/prompt_utils.py39
-rw-r--r--litecli/packages/special/__init__.py12
-rw-r--r--litecli/packages/special/dbcommands.py273
-rw-r--r--litecli/packages/special/favoritequeries.py59
-rw-r--r--litecli/packages/special/iocommands.py479
-rw-r--r--litecli/packages/special/main.py160
-rw-r--r--litecli/packages/special/utils.py48
11 files changed, 1716 insertions, 0 deletions
diff --git a/litecli/packages/__init__.py b/litecli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/litecli/packages/__init__.py
diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py
new file mode 100644
index 0000000..0397857
--- /dev/null
+++ b/litecli/packages/completion_engine.py
@@ -0,0 +1,331 @@
+from __future__ import print_function
+import sys
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from litecli.encodingutils import string_types, text_type
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
+
+ identifier = None
+
+ # here should be removed once sqlparse has been fixed
+ try:
+ # If we've partially typed a word then word_before_cursor won't be an empty
+ # string. In that case we want to remove the partially typed string before
+ # sending it to the sqlparser. Otherwise the last token will always be the
+ # partially typed string which renders the smart completion useless because
+ # it will always return the list of keywords as completion.
+ if word_before_cursor:
+ if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])
+
+ # word_before_cursor may include a schema qualification, like
+ # "schema_name.partial_name" or "schema_name.", so parse it
+ # separately
+ p = sqlparse.parse(word_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[0], Identifier):
+ identifier = p.tokens[0]
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+ except (TypeError, AttributeError):
+ return [{"type": "keyword"}]
+
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds the
+ # current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(text_type(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ statement = None
+
+ # Check for special commands and handle those separately
+ if statement:
+ # Be careful here because trivial whitespace is parsed as a statement,
+ # but the statement won't have a first token
+ tok1 = statement.token_first()
+ if tok1 and tok1.value.startswith("."):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("\\"):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("source"):
+ return suggest_special(text_before_cursor)
+ elif text_before_cursor and text_before_cursor.startswith(".open "):
+ return suggest_special(text_before_cursor)
+
+ last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""
+
+ return suggest_based_on_last_token(
+ last_token, text_before_cursor, full_text, identifier
+ )
+
+
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return [{"type": "special"}]
+
+ if cmd in ("\\u", "\\r"):
+ return [{"type": "database"}]
+
+ if cmd in ("\\T"):
+ return [{"type": "table_format"}]
+
+ if cmd in ["\\f", "\\fs", "\\fd"]:
+ return [{"type": "favoritequery"}]
+
+ if cmd in ["\\d", "\\dt", "\\dt+", ".schema"]:
+ return [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+
+ if cmd in ["\\.", "source", ".open"]:
+ return [{"type": "file_name"}]
+
+ if cmd in [".import"]:
+ # Usage: .import filename table
+ if _expecting_arg_idx(arg, text) == 1:
+ return [{"type": "file_name"}]
+ else:
+ return [{"type": "table", "schema": []}]
+
+ return [{"type": "keyword"}, {"type": "special"}]
+
+
+def _expecting_arg_idx(arg, text):
+ """Return the index of expecting argument.
+
+ >>> _expecting_arg_idx("./da", ".import ./da")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
+ 2
+ >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
+ 2
+ """
+ args = arg.split()
+ return len(args) + int(text[-1].isspace())
+
+
+def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
+ if isinstance(token, string_types):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ token_v = token.value.lower()
+
+ is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]])
+
+ if not token:
+ return [{"type": "keyword"}, {"type": "special"}]
+ elif token_v.endswith("("):
+ p = sqlparse.parse(text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token(
+ "where", text_before_cursor, full_text, identifier
+ )
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ idx, prev_tok = where.token_prev(len(where.tokens) - 1)
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == "exists":
+ return [{"type": "keyword"}]
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ idx, prev_tok = p.token_prev(len(p.tokens) - 1)
+ if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = extract_tables(full_text)
+
+ # suggest columns that are present in more than one table
+ return [{"type": "column", "tables": tables, "drop_unique": True}]
+ elif p.token_first().value.lower() == "select":
+ # If the lparen is preceeded by a space chances are we're about to
+ # do a sub-select.
+ if last_word(text_before_cursor, "all_punctuations").startswith("("):
+ return [{"type": "keyword"}]
+ elif p.token_first().value.lower() == "show":
+ return [{"type": "show"}]
+
+ # We're probably in a function argument list
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v in ("set", "order by", "distinct"):
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v == "as":
+ # Don't suggest anything for an alias
+ return []
+ elif token_v in ("show"):
+ return [{"type": "show"}]
+ elif token_v in ("to",):
+ p = sqlparse.parse(text_before_cursor)[0]
+ if p.token_first().value.lower() == "change":
+ return [{"type": "change"}]
+ else:
+ return [{"type": "user"}]
+ elif token_v in ("user", "for"):
+ return [{"type": "user"}]
+ elif token_v in ("select", "where", "having"):
+ # Check for a table alias or schema qualification
+ parent = (identifier and identifier.get_parent_name()) or []
+
+ tables = extract_tables(full_text)
+ if parent:
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ aliases = [alias or table for (schema, table, alias) in tables]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "function", "schema": []},
+ {"type": "alias", "aliases": aliases},
+ {"type": "keyword"},
+ ]
+ elif (token_v.endswith("join") and token.is_keyword) or (
+ token_v
+ in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
+ ):
+ schema = (identifier and identifier.get_parent_name()) or []
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = [{"type": "table", "schema": schema}]
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, {"type": "schema"})
+
+ # Only tables can be TRUNCATED, otherwise suggest views
+ if token_v != "truncate":
+ suggest.append({"type": "view", "schema": schema})
+
+ return suggest
+
+ elif token_v in ("table", "view", "function"):
+ # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
+ rel_type = token_v
+ schema = (identifier and identifier.get_parent_name()) or []
+ if schema:
+ return [{"type": rel_type, "schema": schema}]
+ else:
+ return [{"type": "schema"}, {"type": rel_type, "schema": []}]
+ elif token_v == "on":
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+ parent = (identifier and identifier.get_parent_name()) or []
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = [alias or table for (schema, table, alias) in tables]
+ suggest = [{"type": "alias", "aliases": aliases}]
+
+ # The lists of 'aliases' could be empty if we're trying to complete
+ # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
+ # In that case we just suggest all tables.
+ if not aliases:
+ suggest.append({"type": "table", "schema": parent})
+ return suggest
+
+ elif token_v in ("use", "database", "template", "connect"):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return [{"type": "database"}]
+ elif token_v == "tableformat":
+ return [{"type": "table_format"}]
+ elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ if prev_keyword:
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ return []
+ else:
+ return [{"type": "keyword"}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (schema and (id == schema + "." + table))
diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py
new file mode 100644
index 0000000..2f01046
--- /dev/null
+++ b/litecli/packages/filepaths.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8
+
+from __future__ import unicode_literals
+
+from litecli.encodingutils import text_type
+import os
+
+
+def list_path(root_dir):
+ """List directory if exists.
+
+ :param dir: str
+ :return: list
+
+ """
+ res = []
+ if os.path.isdir(root_dir):
+ for name in os.listdir(root_dir):
+ res.append(name)
+ return res
+
+
+def complete_path(curr_dir, last_dir):
+ """Return the path to complete that matches the last entered component.
+
+ If the last entered component is ~, expanded path would not
+ match, so return all of the available paths.
+
+ :param curr_dir: str
+ :param last_dir: str
+ :return: str
+
+ """
+ if not last_dir or curr_dir.startswith(last_dir):
+ return curr_dir
+ elif last_dir == "~":
+ return os.path.join(last_dir, curr_dir)
+
+
+def parse_path(root_dir):
+ """Split path into head and last component for the completer.
+
+ Also return position where last component starts.
+
+ :param root_dir: str path
+ :return: tuple of (string, string, int)
+
+ """
+ base_dir, last_dir, position = "", "", 0
+ if root_dir:
+ base_dir, last_dir = os.path.split(root_dir)
+ position = -len(last_dir) if last_dir else 0
+ return base_dir, last_dir, position
+
+
+def suggest_path(root_dir):
+ """List all files and subdirectories in a directory.
+
+ If the directory is not specified, suggest root directory,
+ user directory, current and parent directory.
+
+ :param root_dir: string: directory to list
+ :return: list
+
+ """
+ if not root_dir:
+ return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir])
+
+ if "~" in root_dir:
+ root_dir = text_type(os.path.expanduser(root_dir))
+
+ if not os.path.exists(root_dir):
+ root_dir, _ = os.path.split(root_dir)
+
+ return list_path(root_dir)
+
+
+def dir_path_exists(path):
+ """Check if the directory path exists for a given file.
+
+ For example, for a file /home/user/.cache/litecli/log, check if
+ /home/user/.cache/litecli exists.
+
+ :param str path: The file path.
+ :return: Whether or not the directory path exists.
+
+ """
+ return os.path.exists(os.path.dirname(path))
diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py
new file mode 100644
index 0000000..92fe365
--- /dev/null
+++ b/litecli/packages/parseutils.py
@@ -0,0 +1,227 @@
+from __future__ import print_function
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ "alphanum_underscore": re.compile(r"(\w+)$"),
+ # This matches everything except spaces, parens, colon, and comma
+ "many_punctuations": re.compile(r"([^():,\s]+)$"),
+ # This matches everything except spaces, parens, colon, comma, and period
+ "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+ # This matches everything except a space.
+ "all_punctuations": re.compile("([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+ """
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ """
+
+ if not text: # Empty string
+ return ""
+
+ if text[-1].isspace():
+ return ""
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ""
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in (
+ "SELECT",
+ "INSERT",
+ "UPDATE",
+ "CREATE",
+ "DELETE",
+ ):
+ return True
+ return False
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item, stop_at_punctuation):
+ yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # `return`. So we need to ignore the keyword if the keyword
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif (
+ item.ttype is Keyword
+ and (not item.value.upper() == "FROM")
+ and (not item.value.upper().endswith("JOIN"))
+ ):
+ return
+ else:
+ yield item
+ elif (
+ item.ttype is Keyword or item.ttype is Keyword.DML
+ ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+ tbl_prefix_seen = True
+ break
+
+
+def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM ) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ except AttributeError:
+ continue
+ if real_name:
+ yield (schema_name, real_name, identifier.get_alias())
+ elif isinstance(item, Identifier):
+ real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
+
+ if real_name:
+ yield (schema_name, real_name, item.get_alias())
+ else:
+ name = item.get_name()
+ yield (None, name, item.get_alias() or name)
+ elif isinstance(item, Function):
+ yield (None, item.get_name(), item.get_name())
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statment.
+
+ Returns a list of (schema, table, alias) tuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return []
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == "insert"
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+ return list(extract_table_identifiers(stream))
+
+
+def find_prev_keyword(sql):
+ """ Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ""
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ("AND", "OR", "NOT", "BETWEEN")
+
+ for t in reversed(flattened):
+ if t.value == "(" or (
+ t.is_keyword and (t.value.upper() not in logical_operators)
+ ):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index thows an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = "".join(tok.value for tok in flattened[: idx + 1])
+ return t, text
+
+ return None, ""
+
+
+def query_starts_with(query, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+ """Check if any queries start with any item from *prefixes*."""
+ for query in sqlparse.split(queries):
+ if query and query_starts_with(query, prefixes) is True:
+ return True
+ return False
+
+
+def is_destructive(queries):
+ """Returns if any of the queries in *queries* is destructive."""
+ keywords = ("drop", "shutdown", "delete", "truncate", "alter")
+ return queries_start_with(queries, keywords)
+
+
+if __name__ == "__main__":
+ sql = "select * from (select t. from tabl t"
+ print(extract_tables(sql))
diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py
new file mode 100644
index 0000000..d9ad2b6
--- /dev/null
+++ b/litecli/packages/prompt_utils.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+import sys
+import click
+from .parseutils import is_destructive
+
+
+def confirm_destructive_query(queries):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ prompt_text = (
+ "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
+ )
+ if is_destructive(queries) and sys.stdin.isatty():
+ return prompt(prompt_text, type=bool)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py
new file mode 100644
index 0000000..fd2b18c
--- /dev/null
+++ b/litecli/packages/special/__init__.py
@@ -0,0 +1,12 @@
+__all__ = []
+
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+
+from . import dbcommands
+from . import iocommands
diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py
new file mode 100644
index 0000000..a7eaa0c
--- /dev/null
+++ b/litecli/packages/special/dbcommands.py
@@ -0,0 +1,273 @@
+from __future__ import unicode_literals, print_function
+import csv
+import logging
+import os
+import sys
+import platform
+import shlex
+from sqlite3 import ProgrammingError
+
+from litecli import __version__
+from litecli.packages.special import iocommands
+from litecli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing
+
+log = logging.getLogger(__name__)
+
+
+@special_command(
+ ".tables",
+ "\\dt",
+ "List tables.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\dt",),
+)
+def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ args = ("{0}%".format(arg),)
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ # if verbose and arg:
+ # query = "SELECT sql FROM sqlite_master WHERE name LIKE ?"
+ # log.debug(query)
+ # cur.execute(query)
+ # status = cur.fetchone()[1]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".schema",
+ ".schema[+] [table]",
+ "The complete schema for the database or a single table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def show_schema(cur, arg=None, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ SELECT sql FROM sqlite_master
+ WHERE name==?
+ ORDER BY tbl_name, type DESC, name
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT sql FROM sqlite_master
+ ORDER BY tbl_name, type DESC, name
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".databases",
+ ".databases",
+ "List databases.",
+ arg_type=RAW_QUERY,
+ case_sensitive=True,
+ aliases=("\\l",),
+)
+def list_databases(cur, **_):
+ query = "PRAGMA database_list"
+ log.debug(query)
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, "")]
+ else:
+ return [(None, None, None, "")]
+
+
+@special_command(
+ ".status",
+ "\\s",
+ "Show current settings.",
+ arg_type=RAW_QUERY,
+ aliases=("\\s",),
+ case_sensitive=True,
+)
+def status(cur, **_):
+ # Create output buffers.
+ footer = []
+ footer.append("--------------")
+
+ # Output the litecli client information.
+ implementation = platform.python_implementation()
+ version = platform.python_version()
+ client_info = []
+ client_info.append("litecli {0},".format(__version__))
+ client_info.append("running on {0} {1}".format(implementation, version))
+ footer.append(" ".join(client_info))
+
+ # Build the output that will be displayed as a table.
+ query = "SELECT file from pragma_database_list() where name = 'main';"
+ log.debug(query)
+ cur.execute(query)
+ db = cur.fetchone()[0]
+ if db is None:
+ db = ""
+
+ footer.append("Current database: " + db)
+ if iocommands.is_pager_enabled():
+ if "PAGER" in os.environ:
+ pager = os.environ["PAGER"]
+ else:
+ pager = "System default"
+ else:
+ pager = "stdout"
+ footer.append("Current pager:" + pager)
+
+ footer.append("--------------")
+ return [(None, None, "", "\n".join(footer))]
+
+
+@special_command(
+ ".load",
+ ".load path",
+ "Load an extension library.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def load_extension(cur, arg, **_):
+ args = shlex.split(arg)
+ if len(args) != 1:
+ raise TypeError(".load accepts exactly one path")
+ path = args[0]
+ conn = cur.connection
+ conn.enable_load_extension(True)
+ conn.load_extension(path)
+ return [(None, None, None, "")]
+
+
+@special_command(
+ "describe",
+ "\\d [table]",
+ "Description of a table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\d", "describe", "desc"),
+)
+def describe(cur, arg, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ PRAGMA table_info({})
+ """.format(
+ arg
+ )
+ else:
+ raise ArgumentMissing("Table name required.")
+
+ log.debug(query)
+ cur.execute(query)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".read",
+ ".read path",
+ "Read input from path",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def read_script(cur, arg, **_):
+ args = shlex.split(arg)
+ if len(args) != 1:
+ raise TypeError(".read accepts exactly one path")
+ path = args[0]
+ with open(path, "r") as f:
+ script = f.read()
+ cur.executescript(script)
+ return [(None, None, None, "")]
+
+
+@special_command(
+ ".import",
+ ".import filename table",
+ "Import data from filename into an existing table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def import_file(cur, arg=None, **_):
+ def split(s):
+ # this is a modification of shlex.split function, just to make it support '`',
+ # because table name might contain '`' character.
+ lex = shlex.shlex(s, posix=True)
+ lex.whitespace_split = True
+ lex.commenters = ""
+ lex.quotes += "`"
+ return list(lex)
+
+ args = split(arg)
+ log.debug("[arg = %r], [args = %r]", arg, args)
+ if len(args) != 2:
+ raise TypeError("Usage: .import filename table")
+
+ filename, table = args
+ cur.execute('PRAGMA table_info("%s")' % table)
+ ncols = len(cur.fetchall())
+ insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1))
+
+ with open(filename, "r") as csvfile:
+ dialect = csv.Sniffer().sniff(csvfile.read(1024))
+ csvfile.seek(0)
+ reader = csv.reader(csvfile, dialect)
+
+ cur.execute("BEGIN")
+ ninserted, nignored = 0, 0
+ for i, row in enumerate(reader):
+ if len(row) != ncols:
+ print(
+ "%s:%d expected %d columns but found %d - ignored"
+ % (filename, i, ncols, len(row)),
+ file=sys.stderr,
+ )
+ nignored += 1
+ continue
+ cur.execute(insert_tmpl, row)
+ ninserted += 1
+ cur.execute("COMMIT")
+
+ status = "Inserted %d rows into %s" % (ninserted, table)
+ if nignored > 0:
+ status += " (%d rows are ignored)" % nignored
+ return [(None, None, None, status)]
diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..7da6fbf
--- /dev/null
+++ b/litecli/packages/special/favoritequeries.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+class FavoriteQueries(object):
+
+ section_name = "favorite_queries"
+
+ usage = """
+Favorite Queries are a way to save frequently used queries
+with a short name.
+Examples:
+
+ # Save a new favorite query.
+ > \\fs simple select * from abc where a is not Null;
+
+ # List all favorite queries.
+ > \\f
+ ╒════════╤═══════════════════════════════════════╕
+ │ Name │ Query │
+ ╞════════╪═══════════════════════════════════════╡
+ │ simple │ SELECT * FROM abc where a is not NULL │
+ ╘════════╧═══════════════════════════════════════╛
+
+ # Run a favorite query.
+ > \\f simple
+ ╒════════╤════════╕
+ │ a │ b │
+ ╞════════╪════════╡
+ │ 日本語 │ 日本語 │
+ ╘════════╧════════╛
+
+ # Delete a favorite query.
+ > \\fd simple
+ simple: Deleted
+"""
+
+ def __init__(self, config):
+ self.config = config
+
+ def list(self):
+ return self.config.get(self.section_name, [])
+
+ def get(self, name):
+ return self.config.get(self.section_name, {}).get(name, None)
+
+ def save(self, name, query):
+ if self.section_name not in self.config:
+ self.config[self.section_name] = {}
+ self.config[self.section_name][name] = query
+ self.config.write()
+
+ def delete(self, name):
+ try:
+ del self.config[self.section_name][name]
+ except KeyError:
+ return "%s: Not Found." % name
+ self.config.write()
+ return "%s: Deleted" % name
diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py
new file mode 100644
index 0000000..8940057
--- /dev/null
+++ b/litecli/packages/special/iocommands.py
@@ -0,0 +1,479 @@
+from __future__ import unicode_literals
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+from configobj import ConfigObj
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .utils import handle_cd_command
+from litecli.packages.prompt_utils import confirm_destructive_query
+
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = written_to_once_file = None
+favoritequeries = FavoriteQueries(ConfigObj())
+
+
+@export
+def set_favorite_queries(config):
+ global favoritequeries
+ favoritequeries = FavoriteQueries(config)
+
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+
+@export
+@special_command(
+ "pager",
+ "\\P [command]",
+ "Set PAGER. Print the query results via PAGER.",
+ arg_type=PARSED_QUERY,
+ aliases=("\\P",),
+ case_sensitive=True,
+)
+def set_pager(arg, **_):
+ if arg:
+ os.environ["PAGER"] = arg
+ msg = "PAGER set to %s." % arg
+ set_pager_enabled(True)
+ else:
+ if "PAGER" in os.environ:
+ msg = "PAGER set to %s." % os.environ["PAGER"]
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = "Pager enabled."
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+
+@export
+@special_command(
+ "nopager",
+ "\\n",
+ "Disable pager, print to stdout.",
+ arg_type=NO_QUERY,
+ aliases=("\\n",),
+ case_sensitive=True,
+)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, "Pager disabled.")]
+
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+
+_logger = logging.getLogger(__name__)
+
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith("\\e") or command.strip().startswith("\\e")
+
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith("\\e"):
+ command, _, filename = sql.partition(" ")
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile("(^\\\e|\\\e$)")
+ while pattern.search(sql):
+ sql = pattern.sub("", sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(" ", 1)[0] if filename else None
+
+ sql = sql or ""
+ MARKER = "# Type your query above this line.\n"
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(
+ "{sql}\n\n{marker}".format(sql=sql, marker=MARKER),
+ filename=filename,
+ extension=".sql",
+ )
+
+ if filename:
+ try:
+ with open(filename, encoding="utf-8") as f:
+ query = f.read()
+ except IOError:
+ message = "Error reading file: %s." % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip("\n")
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@special_command(
+ "\\f",
+ "\\f [name [args..]]",
+ "List or execute favorite queries.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def execute_favorite_query(cur, arg, verbose=False, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == "":
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(" ")
+ args = shlex.split(arg_str)
+
+ query = favoritequeries.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ elif "?" in query:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql, args)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]
+
+ if not rows:
+ status = "\nNo favorite queries found." + favoritequeries.usage
+ else:
+ status = ""
+ return [("", rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ shell_subst_var = "$" + str(idx + 1)
+ question_subst_var = "?"
+ if shell_subst_var in query:
+ query = query.replace(shell_subst_var, val)
+ elif question_subst_var in query:
+ query = query.replace(question_subst_var, val, 1)
+ else:
+ return [
+ None,
+ "Too many arguments.\nQuery does not have enough place holders to substitute.\n"
+ + query,
+ ]
+
+ match = re.search("\\?|\\$\d+", query)
+ if match:
+ return [
+ None,
+ "missing substitution for " + match.group(0) + " in query:\n " + query,
+ ]
+
+ return [query, None]
+
+
+@special_command("\\fs", "\\fs name query", "Save a favorite query.")
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(" ")
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None, usage + "Err: Both name and query are required.")]
+
+ favoritequeries.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query.
+ """
+ usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = favoritequeries.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command("system", "system [command]", "Execute a system shell commmand.")
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith("cd"):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, "")]
+
+ args = arg.split(" ")
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, "OSError: %s" % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith("-o "):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = "a"
+ filename = arg
+
+ if not filename:
+ raise TypeError("You must provide a filename.")
+
+ return {"file": os.path.expanduser(filename), "mode": mode}
+
+
+@special_command(
+ "tee",
+ "tee [-o] filename",
+ "Append all results to an output file (overwrite using -o).",
+)
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command("notee", "notee", "Stop writing results to an output file.")
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo("\n", file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command(
+ ".once",
+ "\\o [-o] filename",
+ "Append next result to an output file (overwrite using -o).",
+ aliases=("\\o", "\\once"),
+)
+def set_once(arg, **_):
+ global once_file
+
+ once_file = parseargfile(arg)
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ try:
+ f = open(**once_file)
+ except (IOError, OSError) as e:
+ once_file = None
+ raise OSError(
+ "Cannot write to file '{}': {}".format(e.filename, e.strerror)
+ )
+
+ with f:
+ click.echo(output, file=f, nl=False)
+ click.echo("\n", file=f, nl=False)
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file
+ if written_to_once_file:
+ once_file = None
+
+
+@special_command(
+ "watch",
+ "watch [seconds] [-c] query",
+ "Executes the query every [seconds] seconds (by default 5).",
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ raise StopIteration
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ raise StopIteration
+ (current_arg, _, arg) = arg.partition(" ")
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == "-c":
+ clear_screen = True
+ continue
+ statement = "{0!s} {1!s}".format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ raise StopIteration
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs["cur"]
+ sql_list = [
+ (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager its activated after every yield,
+ # so we disable it in every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character in its own line, which prevents
+ # to print a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ raise StopIteration
+ finally:
+ set_pager_enabled(old_pager_enabled)
diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py
new file mode 100644
index 0000000..3dd0e77
--- /dev/null
+++ b/litecli/packages/special/main.py
@@ -0,0 +1,160 @@
+from __future__ import unicode_literals
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple(
+ "SpecialCommand",
+ [
+ "handler",
+ "command",
+ "shortcut",
+ "description",
+ "arg_type",
+ "hidden",
+ "case_sensitive",
+ ],
+)
+
+COMMANDS = {}
+
+
+@export
+class ArgumentMissing(Exception):
+ pass
+
+
+@export
+class CommandNotFound(Exception):
+ pass
+
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(" ")
+ verbose = "+" in command
+ command = command.strip().replace("+", "")
+ return (command, verbose, arg.strip())
+
+
+@export
+def special_command(
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ def wrapper(wrapped):
+ register_special_command(
+ wrapped,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ hidden,
+ case_sensitive,
+ aliases,
+ )
+ return wrapped
+
+ return wrapper
+
+
+@export
+def register_special_command(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(
+ handler, command, shortcut, description, arg_type, hidden, case_sensitive
+ )
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ case_sensitive=case_sensitive,
+ hidden=True,
+ )
+
+
+@export
+def execute(cur, sql):
+ """Execute a special command and return the results. If the special command
+ is not supported a KeyError will be raised.
+ """
+ command, verbose, arg = parse_special_command(sql)
+
+ if (command not in COMMANDS) and (command.lower() not in COMMANDS):
+ raise CommandNotFound
+
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise CommandNotFound("Command not found: %s" % command)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+
+@special_command(
+ "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?")
+)
+def show_help(): # All the parameters are ignored.
+ headers = ["Command", "Shortcut", "Description"]
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.command, value.shortcut, value.description))
+ return [(None, result, headers, None)]
+
+
+@special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit"))
+@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY)
+def quit(*_args):
+ raise EOFError
+
+
+@special_command(
+ "\\e",
+ "\\e",
+ "Edit command with editor (uses $EDITOR).",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+@special_command(
+ "\\G",
+ "\\G",
+ "Display current query results vertically.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+def stub():
+ raise NotImplementedError
diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py
new file mode 100644
index 0000000..eed9306
--- /dev/null
+++ b/litecli/packages/special/utils.py
@@ -0,0 +1,48 @@
+import os
+import subprocess
+
+
+def handle_cd_command(arg):
+ """Handles a `cd` shell command by calling python's os.chdir."""
+ CD_CMD = "cd"
+ tokens = arg.split(CD_CMD + " ")
+ directory = tokens[-1] if len(tokens) > 1 else None
+ if not directory:
+ return False, "No folder name was provided."
+ try:
+ os.chdir(directory)
+ subprocess.call(["pwd"])
+ return True, None
+ except OSError as e:
+ return False, e.strerror
+
+
+def format_uptime(uptime_in_seconds):
+ """Format number of seconds into human-readable string.
+
+ :param uptime_in_seconds: The server uptime in seconds.
+ :returns: A human-readable string representing the uptime.
+
+ >>> uptime = format_uptime('56892')
+ >>> print(uptime)
+ 15 hours 48 min 12 sec
+ """
+
+ m, s = divmod(int(uptime_in_seconds), 60)
+ h, m = divmod(m, 60)
+ d, h = divmod(h, 24)
+
+ uptime_values = []
+
+ for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")):
+ if value == 0 and not uptime_values:
+ # Don't include a value/unit if the unit isn't applicable to
+ # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
+ continue
+ elif value == 1 and unit.endswith("s"):
+ # Remove the "s" if the unit is singular.
+ unit = unit[:-1]
+ uptime_values.append("{0} {1}".format(value, unit))
+
+ uptime = " ".join(uptime_values)
+ return uptime