summaryrefslogtreecommitdiffstats
path: root/litecli
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2020-08-14 16:58:23 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-07 10:11:48 +0000
commit4edd467b28c895483cd5468d51d1c6824a21715a (patch)
tree04a4f32d617905acfc23653025b6e8d3899f51c6 /litecli
parentInitial commit. (diff)
downloadlitecli-4edd467b28c895483cd5468d51d1c6824a21715a.tar.xz
litecli-4edd467b28c895483cd5468d51d1c6824a21715a.zip
Adding upstream version 1.5.0.upstream/1.5.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'litecli')
-rw-r--r--litecli/AUTHORS20
-rw-r--r--litecli/__init__.py1
-rw-r--r--litecli/clibuffer.py40
-rw-r--r--litecli/clistyle.py114
-rw-r--r--litecli/clitoolbar.py51
-rw-r--r--litecli/compat.py9
-rw-r--r--litecli/completion_refresher.py131
-rw-r--r--litecli/config.py62
-rw-r--r--litecli/encodingutils.py38
-rw-r--r--litecli/key_bindings.py84
-rw-r--r--litecli/lexer.py9
-rw-r--r--litecli/liteclirc113
-rw-r--r--litecli/main.py1008
-rw-r--r--litecli/packages/__init__.py0
-rw-r--r--litecli/packages/completion_engine.py331
-rw-r--r--litecli/packages/filepaths.py88
-rw-r--r--litecli/packages/parseutils.py227
-rw-r--r--litecli/packages/prompt_utils.py39
-rw-r--r--litecli/packages/special/__init__.py12
-rw-r--r--litecli/packages/special/dbcommands.py273
-rw-r--r--litecli/packages/special/favoritequeries.py59
-rw-r--r--litecli/packages/special/iocommands.py479
-rw-r--r--litecli/packages/special/main.py160
-rw-r--r--litecli/packages/special/utils.py48
-rw-r--r--litecli/sqlcompleter.py612
-rw-r--r--litecli/sqlexecute.py212
26 files changed, 4220 insertions, 0 deletions
diff --git a/litecli/AUTHORS b/litecli/AUTHORS
new file mode 100644
index 0000000..d5265de
--- /dev/null
+++ b/litecli/AUTHORS
@@ -0,0 +1,20 @@
+Project Lead:
+-------------
+
+ * Delgermurun Purevkhu
+
+
+Core Developers:
+----------------
+
+ * Amjith Ramanujam
+ * Irina Truong
+ * Dick Marinus
+
+Contributors:
+-------------
+
+ * Thomas Roten
+ * Zhaolong Zhu
+ * Zhiming Wang
+ * Shawn M. Chapla
diff --git a/litecli/__init__.py b/litecli/__init__.py
new file mode 100644
index 0000000..5b60188
--- /dev/null
+++ b/litecli/__init__.py
@@ -0,0 +1 @@
+__version__ = "1.5.0"
diff --git a/litecli/clibuffer.py b/litecli/clibuffer.py
new file mode 100644
index 0000000..a57192a
--- /dev/null
+++ b/litecli/clibuffer.py
@@ -0,0 +1,40 @@
+from __future__ import unicode_literals
+
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+
+
+def cli_is_multiline(cli):
+ @Condition
+ def cond():
+ doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
+
+ if not cli.multi_line:
+ return False
+ else:
+ return not _multiline_exception(doc.text)
+
+ return cond
+
+
+def _multiline_exception(text):
+ orig = text
+ text = text.strip()
+
+ # Multi-statement favorite query is a special case. Because there will
+ # be a semicolon separating statements, we can't consider semicolon an
+ # EOL. Let's consider an empty line an EOL instead.
+ if text.startswith("\\fs"):
+ return orig.endswith("\n")
+
+ return (
+ text.startswith("\\") # Special Command
+ or text.endswith(";") # Ended with a semi-colon
+ or text.endswith("\\g") # Ended with \g
+ or text.endswith("\\G") # Ended with \G
+ or (text == "exit") # Exit doesn't need semi-colon
+ or (text == "quit") # Quit doesn't need semi-colon
+ or (text == ":q") # To all the vim fans out there
+ or (text == "") # Just a plain enter without any text
+ )
diff --git a/litecli/clistyle.py b/litecli/clistyle.py
new file mode 100644
index 0000000..7527315
--- /dev/null
+++ b/litecli/clistyle.py
@@ -0,0 +1,114 @@
+from __future__ import unicode_literals
+
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+ Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
+ Token.Menu.Completions.Completion: "completion-menu.completion",
+ Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
+ Token.Menu.Completions.Meta: "completion-menu.meta.completion",
+ Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
+ Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
+ Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
+ Token.SelectedText: "selected",
+ Token.SearchMatch: "search",
+ Token.SearchMatch.Current: "search.current",
+ Token.Toolbar: "bottom-toolbar",
+ Token.Toolbar.Off: "bottom-toolbar.off",
+ Token.Toolbar.On: "bottom-toolbar.on",
+ Token.Toolbar.Search: "search-toolbar",
+ Token.Toolbar.Search.Text: "search-toolbar.text",
+ Token.Toolbar.System: "system-toolbar",
+ Token.Toolbar.Arg: "arg-toolbar",
+ Token.Toolbar.Arg.Text: "arg-toolbar.text",
+ Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
+ Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
+ Token.Output.Header: "output.header",
+ Token.Output.OddRow: "output.odd-row",
+ Token.Output.EvenRow: "output.even-row",
+ Token.Prompt: "prompt",
+ Token.Continuation: "continuation",
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+
+
+def parse_pygments_style(token_name, style_object, style_dict):
+ """Parse token type and style string.
+
+ :param token_name: str name of Pygments token. Example: "Token.String"
+ :param style_object: pygments.style.Style instance to use as base
+ :param style_dict: dict of token names and their styles, customized to this cli
+
+ """
+ token_type = string_to_tokentype(token_name)
+ try:
+ other_token_type = string_to_tokentype(style_dict[token_name])
+ return token_type, style_object.styles[other_token_type]
+ except AttributeError as err:
+ return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name)
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native")
+
+ prompt_styles = []
+ # prompt-toolkit used pygments tokens for styling before, switched to style
+ # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+ for token in cli_style:
+ if token.startswith("Token."):
+ # treat as pygments token (1.0)
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ if token_type in TOKEN_TO_PROMPT_STYLE:
+ prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+ prompt_styles.append((prompt_style, style_value))
+ else:
+ # we don't want to support tokens anymore
+ logger.error("Unhandled style / class name: %s", token)
+ else:
+ # treat as prompt style name (2.0). See default style names here:
+ # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+ prompt_styles.append((token, cli_style[token]))
+
+ override_style = Style([("bottom-toolbar", "noreverse")])
+ return merge_styles(
+ [style_from_pygments_cls(style), override_style, Style(prompt_styles)]
+ )
+
+
+def style_factory_output(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name).styles
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native").styles
+
+ for token in cli_style:
+ if token.startswith("Token."):
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ style.update({token_type: style_value})
+ elif token in PROMPT_STYLE_TO_TOKEN:
+ token_type = PROMPT_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ else:
+ # TODO: cli helpers will have to switch to ptk.Style
+ logger.error("Unhandled style / class name: %s", token)
+
+ class OutputStyle(PygmentsStyle):
+ default_style = ""
+ styles = style
+
+ return OutputStyle
diff --git a/litecli/clitoolbar.py b/litecli/clitoolbar.py
new file mode 100644
index 0000000..05d0bfd
--- /dev/null
+++ b/litecli/clitoolbar.py
@@ -0,0 +1,51 @@
+from __future__ import unicode_literals
+
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.application import get_app
+
+
+def create_toolbar_tokens_func(cli, show_fish_help):
+ """
+ Return a function that generates the toolbar tokens.
+ """
+
+ def get_toolbar_tokens():
+ result = []
+ result.append(("class:bottom-toolbar", " "))
+
+ if cli.multi_line:
+ result.append(
+ ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
+ )
+
+ if cli.multi_line:
+ result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
+ else:
+ result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
+ if cli.prompt_app.editing_mode == EditingMode.VI:
+ result.append(
+ ("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode()))
+ )
+
+ if show_fish_help():
+ result.append(
+ ("class:bottom-toolbar", " Right-arrow to complete suggestion")
+ )
+
+ if cli.completion_refresher.is_refreshing():
+ result.append(("class:bottom-toolbar", " Refreshing completions..."))
+
+ return result
+
+ return get_toolbar_tokens
+
+
+def _get_vi_mode():
+ """Get the current vi mode for display."""
+ return {
+ InputMode.INSERT: "I",
+ InputMode.NAVIGATION: "N",
+ InputMode.REPLACE: "R",
+ InputMode.INSERT_MULTIPLE: "M",
+ }[get_app().vi_state.input_mode]
diff --git a/litecli/compat.py b/litecli/compat.py
new file mode 100644
index 0000000..7316261
--- /dev/null
+++ b/litecli/compat.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""Platform and Python version compatibility support."""
+
+import sys
+
+
+PY2 = sys.version_info[0] == 2
+PY3 = sys.version_info[0] == 3
+WIN = sys.platform in ("win32", "cygwin")
diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py
new file mode 100644
index 0000000..9602070
--- /dev/null
+++ b/litecli/completion_refresher.py
@@ -0,0 +1,131 @@
+import threading
+from .packages.special.main import COMMANDS
+from collections import OrderedDict
+
+from .sqlcompleter import SQLCompleter
+from .sqlexecute import SQLExecute
+
+
+class CompletionRefresher(object):
+
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, callbacks, completer_options=None):
+ """Creates a SQLCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - SQLExecute object, used to extract the credentials to connect
+ to the database.
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ completer_options - dict of options to pass to SQLCompleter.
+
+ """
+ if completer_options is None:
+ completer_options = {}
+
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, "Auto-completion refresh restarted.")]
+ else:
+ if executor.dbname == ":memory:":
+ # if DB is memory, needed to use same connection
+ # So can't use same connection with different thread
+ self._bg_refresh(executor, callbacks, completer_options)
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, callbacks, completer_options),
+ name="completion_refresh",
+ )
+ self._completer_thread.setDaemon(True)
+ self._completer_thread.start()
+ return [
+ (
+ None,
+ None,
+ None,
+ "Auto-completion refresh started in the background.",
+ )
+ ]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
+ completer = SQLCompleter(**completer_options)
+
+ e = sqlexecute
+ if e.dbname == ":memory:":
+ # if DB is memory, needed to use same connection
+ executor = sqlexecute
+ else:
+ # Create a new sqlexecute method to popoulate the completions.
+ executor = SQLExecute(e.dbname)
+
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes natually
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ for callback in callbacks:
+ callback(completer)
+
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+ """Decorator to add the decorated function to the dictionary of
+ refreshers. Any function decorated with a @refresher will be executed as
+ part of the completion refresh routine."""
+
+ def wrapper(wrapped):
+ refreshers[name] = wrapped
+ return wrapped
+
+ return wrapper
+
+
+@refresher("databases")
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+
+@refresher("schemata")
+def refresh_schemata(completer, executor):
+ # name of the current database.
+ completer.extend_schemata(executor.dbname)
+ completer.set_dbname(executor.dbname)
+
+
+@refresher("tables")
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind="tables")
+ completer.extend_columns(executor.table_columns(), kind="tables")
+
+
+@refresher("functions")
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
+
+
+@refresher("special_commands")
+def refresh_special(completer, executor):
+ completer.extend_special_commands(COMMANDS.keys())
diff --git a/litecli/config.py b/litecli/config.py
new file mode 100644
index 0000000..1c7fb25
--- /dev/null
+++ b/litecli/config.py
@@ -0,0 +1,62 @@
+import errno
+import shutil
+import os
+import platform
+from os.path import expanduser, exists, dirname
+from configobj import ConfigObj
+
+
+def config_location():
+ if "XDG_CONFIG_HOME" in os.environ:
+ return "%s/litecli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
+ elif platform.system() == "Windows":
+ return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\litecli\\"
+ else:
+ return expanduser("~/.config/litecli/")
+
+
+def load_config(usr_cfg, def_cfg=None):
+ cfg = ConfigObj()
+ cfg.merge(ConfigObj(def_cfg, interpolation=False))
+ cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ cfg.filename = expanduser(usr_cfg)
+
+ return cfg
+
+
+def ensure_dir_exists(path):
+ parent_dir = expanduser(dirname(path))
+ try:
+ os.makedirs(parent_dir)
+ except OSError as exc:
+ # ignore existing destination (py2 has no exist_ok arg to makedirs)
+ if exc.errno != errno.EEXIST:
+ raise
+
+
+def write_default_config(source, destination, overwrite=False):
+ destination = expanduser(destination)
+ if not overwrite and exists(destination):
+ return
+
+ ensure_dir_exists(destination)
+
+ shutil.copyfile(source, destination)
+
+
+def upgrade_config(config, def_config):
+ cfg = load_config(config, def_config)
+ cfg.write()
+
+
+def get_config(liteclirc_file=None):
+ from litecli import __file__ as package_root
+
+ package_root = os.path.dirname(package_root)
+
+ liteclirc_file = liteclirc_file or "%sconfig" % config_location()
+
+ default_config = os.path.join(package_root, "liteclirc")
+ write_default_config(default_config, liteclirc_file)
+
+ return load_config(liteclirc_file, default_config)
diff --git a/litecli/encodingutils.py b/litecli/encodingutils.py
new file mode 100644
index 0000000..6caf14d
--- /dev/null
+++ b/litecli/encodingutils.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from litecli.compat import PY2
+
+
+if PY2:
+ binary_type = str
+ string_types = basestring
+ text_type = unicode
+else:
+ binary_type = bytes
+ string_types = str
+ text_type = str
+
+
+def unicode2utf8(arg):
+ """Convert strings to UTF8-encoded bytes.
+
+ Only in Python 2. In Python 3 the args are expected as unicode.
+
+ """
+
+ if PY2 and isinstance(arg, text_type):
+ return arg.encode("utf-8")
+ return arg
+
+
+def utf8tounicode(arg):
+ """Convert UTF8-encoded bytes to strings.
+
+ Only in Python 2. In Python 3 the errors are returned as strings.
+
+ """
+
+ if PY2 and isinstance(arg, binary_type):
+ return arg.decode("utf-8")
+ return arg
diff --git a/litecli/key_bindings.py b/litecli/key_bindings.py
new file mode 100644
index 0000000..44d59d2
--- /dev/null
+++ b/litecli/key_bindings.py
@@ -0,0 +1,84 @@
+from __future__ import unicode_literals
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.filters import completion_is_selected
+from prompt_toolkit.key_binding import KeyBindings
+
+_logger = logging.getLogger(__name__)
+
+
+def cli_bindings(cli):
+ """Custom key bindings for cli."""
+ kb = KeyBindings()
+
+ @kb.add("f3")
+ def _(event):
+ """Enable/Disable Multiline Mode."""
+ _logger.debug("Detected F3 key.")
+ cli.multi_line = not cli.multi_line
+
+ @kb.add("f4")
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug("Detected F4 key.")
+ if cli.key_bindings == "vi":
+ event.app.editing_mode = EditingMode.EMACS
+ cli.key_bindings = "emacs"
+ else:
+ event.app.editing_mode = EditingMode.VI
+ cli.key_bindings = "vi"
+
+ @kb.add("tab")
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug("Detected <Tab> key.")
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=True)
+
+ @kb.add("s-tab")
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug("Detected <Tab> key.")
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_previous()
+ else:
+ b.start_completion(select_last=True)
+
+ @kb.add("c-space")
+ def _(event):
+ """
+ Initialize autocompletion at cursor.
+
+ If the autocompletion menu is not showing, display it with the
+ appropriate completions for the context.
+
+ If the menu is showing, select the next completion.
+ """
+ _logger.debug("Detected <C-Space> key.")
+
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=False)
+
+ @kb.add("enter", filter=completion_is_selected)
+ def _(event):
+ """Makes the enter key work as the tab key only when showing the menu.
+
+ In other words, don't execute query when enter is pressed in
+ the completion dropdown menu, instead close the dropdown menu
+ (accept current selection).
+
+ """
+ _logger.debug("Detected enter key.")
+
+ event.current_buffer.complete_state = None
+ b = event.app.current_buffer
+ b.complete_state = None
+
+ return kb
diff --git a/litecli/lexer.py b/litecli/lexer.py
new file mode 100644
index 0000000..678eb3f
--- /dev/null
+++ b/litecli/lexer.py
@@ -0,0 +1,9 @@
+from pygments.lexer import inherit
+from pygments.lexers.sql import MySqlLexer
+from pygments.token import Keyword
+
+
+class LiteCliLexer(MySqlLexer):
+ """Extends SQLite lexer to add keywords."""
+
+ tokens = {"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit]}
diff --git a/litecli/liteclirc b/litecli/liteclirc
new file mode 100644
index 0000000..e3331d1
--- /dev/null
+++ b/litecli/liteclirc
@@ -0,0 +1,113 @@
+# vi: ft=dosini
+[main]
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+# In Unix/Linux: ~/.config/litecli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.litecli-audit.log
+
+# Default pager.
+# By default '$PAGER' environment variable is used
+# pager = less -SRXF
+
+# Table format. Possible values:
+# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
+# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
+# vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# litecli prompt
+# \D - The full current date
+# \d - Database name
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \R - The current time, in 24-hour military time (0-23)
+# \r - The current time, standard 12-hour time (1-12)
+# \s - Seconds of the current time
+prompt = '\d> '
+prompt_continuation = '-> '
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
+# disabled pager on startup
+enable_pager = True
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+
+
+# Favorite queries.
+[favorite_queries]
diff --git a/litecli/main.py b/litecli/main.py
new file mode 100644
index 0000000..5768851
--- /dev/null
+++ b/litecli/main.py
@@ -0,0 +1,1008 @@
+from __future__ import unicode_literals
+from __future__ import print_function
+
+import os
+import sys
+import traceback
+import logging
+import threading
+from time import time
+from datetime import datetime
+from io import open
+from collections import namedtuple
+from sqlite3 import OperationalError
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output import preprocessors
+import click
+import sqlparse
+from prompt_toolkit.completion import DynamicCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.layout.processors import (
+ HighlightMatchingBracketProcessor,
+ ConditionalProcessor,
+)
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+
+from .packages.special.main import NO_QUERY
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages import special
+from .sqlcompleter import SQLCompleter
+from .clitoolbar import create_toolbar_tokens_func
+from .clistyle import style_factory, style_factory_output
+from .sqlexecute import SQLExecute
+from .clibuffer import cli_is_multiline
+from .completion_refresher import CompletionRefresher
+from .config import config_location, ensure_dir_exists, get_config
+from .key_bindings import cli_bindings
+from .encodingutils import utf8tounicode, text_type
+from .lexer import LiteCliLexer
+from .__init__ import __version__
+from .packages.filepaths import dir_path_exists
+
+import itertools
+
+click.disable_unicode_literals_warning = True
+
+# Query tuples are used for maintaining history
+Query = namedtuple("Query", ["query", "successful", "mutating"])
+
+PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
+
+
+class LiteCli(object):
+
+ default_prompt = "\\d> "
+ max_len_prompt = 45
+
+ def __init__(
+ self,
+ sqlexecute=None,
+ prompt=None,
+ logfile=None,
+ auto_vertical_output=False,
+ warn=None,
+ liteclirc=None,
+ ):
+ self.sqlexecute = sqlexecute
+ self.logfile = logfile
+
+ # Load config.
+ c = self.config = get_config(liteclirc)
+
+ self.multi_line = c["main"].as_bool("multi_line")
+ self.key_bindings = c["main"]["key_bindings"]
+ special.set_favorite_queries(self.config)
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ self.formatter.litecli = self
+ self.syntax_style = c["main"]["syntax_style"]
+ self.less_chatty = c["main"].as_bool("less_chatty")
+ self.cli_style = c["colors"]
+ self.output_style = style_factory_output(self.syntax_style, self.cli_style)
+ self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
+ c_dest_warning = c["main"].as_bool("destructive_warning")
+ self.destructive_warning = c_dest_warning if warn is None else warn
+ self.login_path_as_host = c["main"].as_bool("login_path_as_host")
+
+ # read from cli argument or user config file
+ self.auto_vertical_output = auto_vertical_output or c["main"].as_bool(
+ "auto_vertical_output"
+ )
+
+ # audit log
+ if self.logfile is None and "audit_log" in c["main"]:
+ try:
+ self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a")
+ except (IOError, OSError):
+ self.echo(
+ "Error: Unable to open the audit log file. Your queries will not be logged.",
+ err=True,
+ fg="red",
+ )
+ self.logfile = False
+
+ self.completion_refresher = CompletionRefresher()
+
+ self.logger = logging.getLogger(__name__)
+ self.initialize_logging()
+
+ prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"]
+ self.prompt_format = (
+ prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
+ )
+ self.prompt_continuation_format = c["main"]["prompt_continuation"]
+ keyword_casing = c["main"].get("keyword_casing", "auto")
+
+ self.query_history = []
+
+ # Initialize completer.
+ self.completer = SQLCompleter(
+ supported_formats=self.formatter.supported_formats,
+ keyword_casing=keyword_casing,
+ )
+ self._completer_lock = threading.Lock()
+
+ # Register custom special commands.
+ self.register_special_commands()
+
+ self.prompt_app = None
+
+ def register_special_commands(self):
+ special.register_special_command(
+ self.change_db,
+ ".open",
+ ".open",
+ "Change to a new database.",
+ aliases=("use", "\\u"),
+ )
+ special.register_special_command(
+ self.refresh_completions,
+ "rehash",
+ "\\#",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ aliases=("\\#",),
+ )
+ special.register_special_command(
+ self.change_table_format,
+ ".mode",
+ "\\T",
+ "Change the table format used to output results.",
+ aliases=("tableformat", "\\T"),
+ case_sensitive=True,
+ )
+ special.register_special_command(
+ self.execute_from_file,
+ "source",
+ "\\. filename",
+ "Execute commands from file.",
+ aliases=("\\.",),
+ )
+ special.register_special_command(
+ self.change_prompt_format,
+ "prompt",
+ "\\R",
+ "Change prompt format.",
+ aliases=("\\R",),
+ case_sensitive=True,
+ )
+
+ def change_table_format(self, arg, **_):
+ try:
+ self.formatter.format_name = arg
+ yield (None, None, None, "Changed table format to {}".format(arg))
+ except ValueError:
+ msg = "Table format {} not recognized. Allowed formats:".format(arg)
+ for table_type in self.formatter.supported_formats:
+ msg += "\n\t{}".format(table_type)
+ yield (None, None, None, msg)
+
+ def change_db(self, arg, **_):
+ if arg is None:
+ self.sqlexecute.connect()
+ else:
+ self.sqlexecute.connect(database=arg)
+
+ self.refresh_completions()
+ yield (
+ None,
+ None,
+ None,
+ 'You are now connected to database "%s"' % (self.sqlexecute.dbname),
+ )
+
+ def execute_from_file(self, arg, **_):
+ if not arg:
+ message = "Missing required argument, filename."
+ return [(None, None, None, message)]
+ try:
+ with open(os.path.expanduser(arg), encoding="utf-8") as f:
+ query = f.read()
+ except IOError as e:
+ return [(None, None, None, str(e))]
+
+ if self.destructive_warning and confirm_destructive_query(query) is False:
+ message = "Wise choice. Command execution stopped."
+ return [(None, None, None, message)]
+
+ return self.sqlexecute.run(query)
+
+ def change_prompt_format(self, arg, **_):
+ """
+ Change the prompt format.
+ """
+ if not arg:
+ message = "Missing required argument, format."
+ return [(None, None, None, message)]
+
+ self.prompt_format = self.get_prompt(arg)
+ return [(None, None, None, "Changed prompt format to %s" % arg)]
+
+ def initialize_logging(self):
+
+ log_file = self.config["main"]["log_file"]
+ if log_file == "default":
+ log_file = config_location() + "log"
+ ensure_dir_exists(log_file)
+
+ log_level = self.config["main"]["log_level"]
+
+ level_map = {
+ "CRITICAL": logging.CRITICAL,
+ "ERROR": logging.ERROR,
+ "WARNING": logging.WARNING,
+ "INFO": logging.INFO,
+ "DEBUG": logging.DEBUG,
+ }
+
+ # Disable logging if value is NONE by switching to a no-op handler
+ # Set log level to a high value so it doesn't even waste cycles getting called.
+ if log_level.upper() == "NONE":
+ handler = logging.NullHandler()
+ log_level = "CRITICAL"
+ elif dir_path_exists(log_file):
+ handler = logging.FileHandler(log_file)
+ else:
+ self.echo(
+ 'Error: Unable to open the log file "{}".'.format(log_file),
+ err=True,
+ fg="red",
+ )
+ return
+
+ formatter = logging.Formatter(
+ "%(asctime)s (%(process)d/%(threadName)s) "
+ "%(name)s %(levelname)s - %(message)s"
+ )
+
+ handler.setFormatter(formatter)
+
+ root_logger = logging.getLogger("litecli")
+ root_logger.addHandler(handler)
+ root_logger.setLevel(level_map[log_level.upper()])
+
+ logging.captureWarnings(True)
+
+ root_logger.debug("Initializing litecli logging.")
+ root_logger.debug("Log file %r.", log_file)
+
+ def read_my_cnf_files(self, keys):
+ """
+ Reads a list of config files and merges them. The last one will win.
+ :param files: list of files to read
+ :param keys: list of keys to retrieve
+ :returns: tuple, with None for missing keys.
+ """
+ cnf = self.config
+
+ sections = ["main"]
+
+ def get(key):
+ result = None
+ for sect in cnf:
+ if sect in sections and key in cnf[sect]:
+ result = cnf[sect][key]
+ return result
+
+ return {x: get(x) for x in keys}
+
+ def connect(self, database=""):
+
+ cnf = {"database": None}
+
+ cnf = self.read_my_cnf_files(cnf.keys())
+
+ # Fall back to config values only if user did not specify a value.
+
+ database = database or cnf["database"]
+
+ # Connect to the database.
+
+ def _connect():
+ self.sqlexecute = SQLExecute(database)
+
+ try:
+ _connect()
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug("Database connection failed: %r.", e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ exit(1)
+
+ def handle_editor_command(self, text):
+ """Editor command is any query that is prefixed or suffixed by a '\e'.
+ The reason for a while loop is because a user might edit a query
+ multiple times. For eg:
+
+ "select * from \e"<enter> to edit it in vim, then come
+ back to the prompt with the edited query "select * from
+ blah where q = 'abc'\e" to edit it again.
+ :param text: Document
+ :return: Document
+
+ """
+
+ while special.editor_command(text):
+ filename = special.get_filename(text)
+ query = special.get_editor_query(text) or self.get_last_query()
+ sql, message = special.open_external_editor(filename, sql=query)
+ if message:
+ # Something went wrong. Raise an exception and bail.
+ raise RuntimeError(message)
+ while True:
+ try:
+ text = self.prompt_app.prompt(default=sql)
+ break
+ except KeyboardInterrupt:
+ sql = ""
+
+ continue
+ return text
+
+ def run_cli(self):
+ iterations = 0
+ sqlexecute = self.sqlexecute
+ logger = self.logger
+ self.configure_pager()
+ self.refresh_completions()
+
+ history_file = config_location() + "history"
+ if dir_path_exists(history_file):
+ history = FileHistory(history_file)
+ else:
+ history = None
+ self.echo(
+ 'Error: Unable to open the history file "{}". '
+ "Your query history will not be saved.".format(history_file),
+ err=True,
+ fg="red",
+ )
+
+ key_bindings = cli_bindings(self)
+
+ if not self.less_chatty:
+ print("Version:", __version__)
+ print("Mail: https://groups.google.com/forum/#!forum/litecli-users")
+ print("GitHub: https://github.com/dbcli/litecli")
+ # print("Home: https://litecli.com")
+
+ def get_message():
+ prompt = self.get_prompt(self.prompt_format)
+ if (
+ self.prompt_format == self.default_prompt
+ and len(prompt) > self.max_len_prompt
+ ):
+ prompt = self.get_prompt("\\d> ")
+ return [("class:prompt", prompt)]
+
+ def get_continuation(width, line_number, is_soft_wrap):
+ continuation = " " * (width - 1) + " "
+ return [("class:continuation", continuation)]
+
+ def show_suggestion_tip():
+ return iterations < 2
+
+ def one_iteration(text=None):
+ if text is None:
+ try:
+ text = self.prompt_app.prompt()
+ except KeyboardInterrupt:
+ return
+
+ special.set_expanded_output(False)
+
+ try:
+ text = self.handle_editor_command(text)
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ return
+
+ if not text.strip():
+ return
+
+ if self.destructive_warning:
+ destroy = confirm_destructive_query(text)
+ if destroy is None:
+ pass # Query was not destructive. Nothing to do here.
+ elif destroy is True:
+ self.echo("Your call!")
+ else:
+ self.echo("Wise choice!")
+ return
+
+ # Keep track of whether or not the query is mutating. In case
+ # of a multi-statement query, the overall query is considered
+ # mutating if any one of the component statements is mutating
+ mutating = False
+
+ try:
+ logger.debug("sql: %r", text)
+
+ special.write_tee(self.get_prompt(self.prompt_format) + text)
+ if self.logfile:
+ self.logfile.write("\n# %s\n" % datetime.now())
+ self.logfile.write(text)
+ self.logfile.write("\n")
+
+ successful = False
+ start = time()
+ res = sqlexecute.run(text)
+ self.formatter.query = text
+ successful = True
+ result_count = 0
+ for title, cur, headers, status in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+ threshold = 1000
+ if is_select(status) and cur and cur.rowcount > threshold:
+ self.echo(
+ "The result set has more than {} rows.".format(threshold),
+ fg="red",
+ )
+ if not confirm("Do you want to continue?"):
+ self.echo("Aborted!", err=True, fg="red")
+ break
+
+ if self.auto_vertical_output:
+ max_width = self.prompt_app.output.get_size().columns
+ else:
+ max_width = None
+
+ formatted = self.format_output(
+ title, cur, headers, special.is_expanded_output(), max_width
+ )
+
+ t = time() - start
+ try:
+ if result_count > 0:
+ self.echo("")
+ try:
+ self.output(formatted, status)
+ except KeyboardInterrupt:
+ pass
+ self.echo("Time: %0.03fs" % t)
+ except KeyboardInterrupt:
+ pass
+
+ start = time()
+ result_count += 1
+ mutating = mutating or is_mutating(status)
+ special.unset_once_if_written()
+ except EOFError as e:
+ raise e
+ except KeyboardInterrupt:
+ # get last connection id
+ connection_id_to_kill = sqlexecute.connection_id
+ logger.debug("connection id to kill: %r", connection_id_to_kill)
+ # Restart connection to the database
+ sqlexecute.connect()
+ try:
+ for title, cur, headers, status in sqlexecute.run(
+ "kill %s" % connection_id_to_kill
+ ):
+ status_str = str(status).lower()
+ if status_str.find("ok") > -1:
+ logger.debug(
+ "cancelled query, connection id: %r, sql: %r",
+ connection_id_to_kill,
+ text,
+ )
+ self.echo("cancelled query", err=True, fg="red")
+ except Exception as e:
+ self.echo(
+ "Encountered error while cancelling query: {}".format(e),
+ err=True,
+ fg="red",
+ )
+ except NotImplementedError:
+ self.echo("Not Yet Implemented.", fg="yellow")
+ except OperationalError as e:
+ logger.debug("Exception: %r", e)
+ if e.args[0] in (2003, 2006, 2013):
+ logger.debug("Attempting to reconnect.")
+ self.echo("Reconnecting...", fg="yellow")
+ try:
+ sqlexecute.connect()
+ logger.debug("Reconnected successfully.")
+ one_iteration(text)
+ return # OK to just return, cuz the recursion call runs to the end.
+ except OperationalError as e:
+ logger.debug("Reconnect failed. e: %r", e)
+ self.echo(str(e), err=True, fg="red")
+ # If reconnection failed, don't proceed further.
+ return
+ else:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ except Exception as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ else:
+ # Refresh the table names and column names if necessary.
+ if need_completion_refresh(text):
+ self.refresh_completions(reset=need_completion_reset(text))
+ finally:
+ if self.logfile is False:
+ self.echo("Warning: This query was not logged.", err=True, fg="red")
+ query = Query(text, successful, mutating)
+ self.query_history.append(query)
+
+ get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip)
+
+ if self.wider_completion_menu:
+ complete_style = CompleteStyle.MULTI_COLUMN
+ else:
+ complete_style = CompleteStyle.COLUMN
+
+ with self._completer_lock:
+
+ if self.key_bindings == "vi":
+ editing_mode = EditingMode.VI
+ else:
+ editing_mode = EditingMode.EMACS
+
+ self.prompt_app = PromptSession(
+ lexer=PygmentsLexer(LiteCliLexer),
+ reserve_space_for_menu=self.get_reserved_space(),
+ message=get_message,
+ prompt_continuation=get_continuation,
+ bottom_toolbar=get_toolbar_tokens,
+ complete_style=complete_style,
+ input_processors=[
+ ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(chars="[](){}"),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
+ )
+ ],
+ tempfile_suffix=".sql",
+ completer=DynamicCompleter(lambda: self.completer),
+ history=history,
+ auto_suggest=AutoSuggestFromHistory(),
+ complete_while_typing=True,
+ multiline=cli_is_multiline(self),
+ style=style_factory(self.syntax_style, self.cli_style),
+ include_default_pygments_style=False,
+ key_bindings=key_bindings,
+ enable_open_in_editor=True,
+ enable_system_prompt=True,
+ enable_suspend=True,
+ editing_mode=editing_mode,
+ search_ignore_case=True,
+ )
+
+ try:
+ while True:
+ one_iteration()
+ iterations += 1
+ except EOFError:
+ special.close_tee()
+ if not self.less_chatty:
+ self.echo("Goodbye!")
+
+ def log_output(self, output):
+ """Log the output in the audit log, if it's enabled."""
+ if self.logfile:
+ click.echo(utf8tounicode(output), file=self.logfile)
+
+ def echo(self, s, **kwargs):
+ """Print a message to stdout.
+
+ The message will be logged in the audit log, if enabled.
+
+ All keyword arguments are passed to click.echo().
+
+ """
+ self.log_output(s)
+ click.secho(s, **kwargs)
+
+ def get_output_margin(self, status=None):
+ """Get the output margin (number of rows for the prompt, footer and
+ timing message."""
+ margin = (
+ self.get_reserved_space()
+ + self.get_prompt(self.prompt_format).count("\n")
+ + 2
+ )
+ if status:
+ margin += 1 + status.count("\n")
+
+ return margin
+
+ def output(self, output, status=None):
+ """Output text to stdout or a pager command.
+
+ The status text is not outputted to pager or files.
+
+ The message will be logged in the audit log, if enabled. The
+ message will be written to the tee file, if enabled. The
+ message will be written to the output file, if enabled.
+
+ """
+ if output:
+ size = self.prompt_app.output.get_size()
+
+ margin = self.get_output_margin(status)
+
+ fits = True
+ buf = []
+ output_via_pager = self.explicit_pager and special.is_pager_enabled()
+ for i, line in enumerate(output, 1):
+ self.log_output(line)
+ special.write_tee(line)
+ special.write_once(line)
+
+ if fits or output_via_pager:
+ # buffering
+ buf.append(line)
+ if len(line) > size.columns or i > (size.rows - margin):
+ fits = False
+ if not self.explicit_pager and special.is_pager_enabled():
+ # doesn't fit, use pager
+ output_via_pager = True
+
+ if not output_via_pager:
+ # doesn't fit, flush buffer
+ for line in buf:
+ click.secho(line)
+ buf = []
+ else:
+ click.secho(line)
+
+ if buf:
+ if output_via_pager:
+ # sadly click.echo_via_pager doesn't accept generators
+ click.echo_via_pager("\n".join(buf))
+ else:
+ for line in buf:
+ click.secho(line)
+
+ if status:
+ self.log_output(status)
+ click.secho(status)
+
+ def configure_pager(self):
+ # Provide sane defaults for less if they are empty.
+ if not os.environ.get("LESS"):
+ os.environ["LESS"] = "-RXF"
+
+ cnf = self.read_my_cnf_files(["pager", "skip-pager"])
+ if cnf["pager"]:
+ special.set_pager(cnf["pager"])
+ self.explicit_pager = True
+ else:
+ self.explicit_pager = False
+
+ if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
+ special.disable_pager()
+
+ def refresh_completions(self, reset=False):
+ if reset:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.completion_refresher.refresh(
+ self.sqlexecute,
+ self._on_completions_refreshed,
+ {
+ "supported_formats": self.formatter.supported_formats,
+ "keyword_casing": self.completer.keyword_casing,
+ },
+ )
+
+ return [
+ (None, None, None, "Auto-completion refresh started in the background.")
+ ]
+
+ def _on_completions_refreshed(self, new_completer):
+ """Swap the completer object in cli with the newly created completer.
+ """
+ with self._completer_lock:
+ self.completer = new_completer
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None
+ )
+
+ def get_prompt(self, string):
+ self.logger.debug("Getting prompt")
+ sqlexecute = self.sqlexecute
+ now = datetime.now()
+ string = string.replace("\\d", sqlexecute.dbname or "(none)")
+ string = string.replace("\\n", "\n")
+ string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
+ string = string.replace("\\m", now.strftime("%M"))
+ string = string.replace("\\P", now.strftime("%p"))
+ string = string.replace("\\R", now.strftime("%H"))
+ string = string.replace("\\r", now.strftime("%I"))
+ string = string.replace("\\s", now.strftime("%S"))
+ string = string.replace("\\_", " ")
+ return string
+
+ def run_query(self, query, new_line=True):
+ """Runs *query*."""
+ results = self.sqlexecute.run(query)
+ for result in results:
+ title, cur, headers, status = result
+ self.formatter.query = query
+ output = self.format_output(title, cur, headers)
+ for line in output:
+ click.echo(line, nl=new_line)
+
+ def format_output(self, title, cur, headers, expanded=False, max_width=None):
+ expanded = expanded or self.formatter.format_name == "vertical"
+ output = []
+
+ output_kwargs = {
+ "dialect": "unix",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "preprocessors": (preprocessors.align_decimals,),
+ "style": self.output_style,
+ }
+
+ if title: # Only print the title if it's not None.
+ output = itertools.chain(output, [title])
+
+ if cur:
+ column_types = None
+ if hasattr(cur, "description"):
+
+ def get_col_type(col):
+ # col_type = FIELD_TYPES.get(col[1], text_type)
+ # return col_type if type(col_type) is type else text_type
+ return text_type
+
+ column_types = [get_col_type(col) for col in cur.description]
+
+ if max_width is not None:
+ cur = list(cur)
+
+ formatted = self.formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical" if expanded else None,
+ column_types=column_types,
+ **output_kwargs
+ )
+
+ if isinstance(formatted, (text_type)):
+ formatted = formatted.splitlines()
+ formatted = iter(formatted)
+
+ first_line = next(formatted)
+ formatted = itertools.chain([first_line], formatted)
+
+ if (
+ not expanded
+ and max_width
+ and headers
+ and cur
+ and len(first_line) > max_width
+ ):
+ formatted = self.formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs
+ )
+ if isinstance(formatted, (text_type)):
+ formatted = iter(formatted.splitlines())
+
+ output = itertools.chain(output, formatted)
+
+ return output
+
+ def get_reserved_space(self):
+ """Get the number of lines to reserve for the completion menu."""
+ reserved_space_ratio = 0.45
+ max_reserved_space = 8
+ _, height = click.get_terminal_size()
+ return min(int(round(height * reserved_space_ratio)), max_reserved_space)
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+
+@click.command()
+@click.option("-V", "--version", is_flag=True, help="Output litecli's version.")
+@click.option("-D", "--database", "dbname", help="Database to use.")
+@click.option(
+ "-R",
+ "--prompt",
+ "prompt",
+ help='Prompt format (Default: "{0}").'.format(LiteCli.default_prompt),
+)
+@click.option(
+ "-l",
+ "--logfile",
+ type=click.File(mode="a", encoding="utf-8"),
+ help="Log every query and its results to a file.",
+)
+@click.option(
+ "--liteclirc",
+ default=config_location() + "config",
+ help="Location of liteclirc file.",
+ type=click.Path(dir_okay=False),
+)
+@click.option(
+ "--auto-vertical-output",
+ is_flag=True,
+ help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
+)
+@click.option(
+ "-t", "--table", is_flag=True, help="Display batch output in table format."
+)
+@click.option("--csv", is_flag=True, help="Display batch output in CSV format.")
+@click.option(
+ "--warn/--no-warn", default=None, help="Warn before running a destructive query."
+)
+@click.option("-e", "--execute", type=str, help="Execute command and quit.")
+@click.argument("database", default="", nargs=1)
+def cli(
+ database,
+ dbname,
+ version,
+ prompt,
+ logfile,
+ auto_vertical_output,
+ table,
+ csv,
+ warn,
+ execute,
+ liteclirc,
+):
+ """A SQLite terminal client with auto-completion and syntax highlighting.
+
+ \b
+ Examples:
+ - litecli lite_database
+
+ """
+
+ if version:
+ print("Version:", __version__)
+ sys.exit(0)
+
+ litecli = LiteCli(
+ prompt=prompt,
+ logfile=logfile,
+ auto_vertical_output=auto_vertical_output,
+ warn=warn,
+ liteclirc=liteclirc,
+ )
+
+ # Choose which ever one has a valid value.
+ database = database or dbname
+
+ litecli.connect(database)
+
+ litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database)
+
+ # --execute argument
+ if execute:
+ try:
+ if csv:
+ litecli.formatter.format_name = "csv"
+ elif not table:
+ litecli.formatter.format_name = "tsv"
+
+ litecli.run_query(execute)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+ if sys.stdin.isatty():
+ litecli.run_cli()
+ else:
+ stdin = click.get_text_stream("stdin")
+ stdin_text = stdin.read()
+
+ try:
+ sys.stdin = open("/dev/tty")
+ except (FileNotFoundError, OSError):
+ litecli.logger.warning("Unable to open TTY as stdin.")
+
+ if (
+ litecli.destructive_warning
+ and confirm_destructive_query(stdin_text) is False
+ ):
+ exit(0)
+ try:
+ new_line = True
+
+ if csv:
+ litecli.formatter.format_name = "csv"
+ elif not table:
+ litecli.formatter.format_name = "tsv"
+
+ litecli.run_query(stdin_text, new_line=new_line)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+
+def need_completion_refresh(queries):
+ """Determines if the completion needs a refresh by checking if the sql
+ statement is an alter, create, drop or change db."""
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in (
+ "alter",
+ "create",
+ "use",
+ "\\r",
+ "\\u",
+ "connect",
+ "drop",
+ ):
+ return True
+ except Exception:
+ return False
+
+
+def need_completion_reset(queries):
+ """Determines if the statement is a database switch such as 'use' or '\\u'.
+ When a database is changed the existing completions must be reset before we
+ start the completion refresh for the new database.
+ """
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("use", "\\u"):
+ return True
+ except Exception:
+ return False
+
+
+def is_mutating(status):
+ """Determines if the statement is mutating based on the status."""
+ if not status:
+ return False
+
+ mutating = set(
+ [
+ "insert",
+ "update",
+ "delete",
+ "alter",
+ "create",
+ "drop",
+ "replace",
+ "truncate",
+ "load",
+ ]
+ )
+ return status.split(None, 1)[0].lower() in mutating
+
+
+def is_select(status):
+ """Returns true if the first word in status is 'select'."""
+ if not status:
+ return False
+ return status.split(None, 1)[0].lower() == "select"
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/litecli/packages/__init__.py b/litecli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/litecli/packages/__init__.py
diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py
new file mode 100644
index 0000000..0397857
--- /dev/null
+++ b/litecli/packages/completion_engine.py
@@ -0,0 +1,331 @@
+from __future__ import print_function
+import sys
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from litecli.encodingutils import string_types, text_type
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
+
+ identifier = None
+
+ # here should be removed once sqlparse has been fixed
+ try:
+ # If we've partially typed a word then word_before_cursor won't be an empty
+ # string. In that case we want to remove the partially typed string before
+ # sending it to the sqlparser. Otherwise the last token will always be the
+ # partially typed string which renders the smart completion useless because
+ # it will always return the list of keywords as completion.
+ if word_before_cursor:
+ if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])
+
+ # word_before_cursor may include a schema qualification, like
+ # "schema_name.partial_name" or "schema_name.", so parse it
+ # separately
+ p = sqlparse.parse(word_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[0], Identifier):
+ identifier = p.tokens[0]
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+ except (TypeError, AttributeError):
+ return [{"type": "keyword"}]
+
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds the
+ # current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(text_type(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ statement = None
+
+ # Check for special commands and handle those separately
+ if statement:
+ # Be careful here because trivial whitespace is parsed as a statement,
+ # but the statement won't have a first token
+ tok1 = statement.token_first()
+ if tok1 and tok1.value.startswith("."):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("\\"):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("source"):
+ return suggest_special(text_before_cursor)
+ elif text_before_cursor and text_before_cursor.startswith(".open "):
+ return suggest_special(text_before_cursor)
+
+ last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""
+
+ return suggest_based_on_last_token(
+ last_token, text_before_cursor, full_text, identifier
+ )
+
+
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return [{"type": "special"}]
+
+ if cmd in ("\\u", "\\r"):
+ return [{"type": "database"}]
+
+ if cmd in ("\\T"):
+ return [{"type": "table_format"}]
+
+ if cmd in ["\\f", "\\fs", "\\fd"]:
+ return [{"type": "favoritequery"}]
+
+ if cmd in ["\\d", "\\dt", "\\dt+", ".schema"]:
+ return [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+
+ if cmd in ["\\.", "source", ".open"]:
+ return [{"type": "file_name"}]
+
+ if cmd in [".import"]:
+ # Usage: .import filename table
+ if _expecting_arg_idx(arg, text) == 1:
+ return [{"type": "file_name"}]
+ else:
+ return [{"type": "table", "schema": []}]
+
+ return [{"type": "keyword"}, {"type": "special"}]
+
+
+def _expecting_arg_idx(arg, text):
+ """Return the index of expecting argument.
+
+ >>> _expecting_arg_idx("./da", ".import ./da")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
+ 2
+ >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
+ 2
+ """
+ args = arg.split()
+ return len(args) + int(text[-1].isspace())
+
+
+def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
+ if isinstance(token, string_types):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ token_v = token.value.lower()
+
+ is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]])
+
+ if not token:
+ return [{"type": "keyword"}, {"type": "special"}]
+ elif token_v.endswith("("):
+ p = sqlparse.parse(text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token(
+ "where", text_before_cursor, full_text, identifier
+ )
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ idx, prev_tok = where.token_prev(len(where.tokens) - 1)
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == "exists":
+ return [{"type": "keyword"}]
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ idx, prev_tok = p.token_prev(len(p.tokens) - 1)
+ if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = extract_tables(full_text)
+
+ # suggest columns that are present in more than one table
+ return [{"type": "column", "tables": tables, "drop_unique": True}]
+ elif p.token_first().value.lower() == "select":
+ # If the lparen is preceeded by a space chances are we're about to
+ # do a sub-select.
+ if last_word(text_before_cursor, "all_punctuations").startswith("("):
+ return [{"type": "keyword"}]
+ elif p.token_first().value.lower() == "show":
+ return [{"type": "show"}]
+
+ # We're probably in a function argument list
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v in ("set", "order by", "distinct"):
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v == "as":
+ # Don't suggest anything for an alias
+ return []
+ elif token_v in ("show"):
+ return [{"type": "show"}]
+ elif token_v in ("to",):
+ p = sqlparse.parse(text_before_cursor)[0]
+ if p.token_first().value.lower() == "change":
+ return [{"type": "change"}]
+ else:
+ return [{"type": "user"}]
+ elif token_v in ("user", "for"):
+ return [{"type": "user"}]
+ elif token_v in ("select", "where", "having"):
+ # Check for a table alias or schema qualification
+ parent = (identifier and identifier.get_parent_name()) or []
+
+ tables = extract_tables(full_text)
+ if parent:
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ aliases = [alias or table for (schema, table, alias) in tables]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "function", "schema": []},
+ {"type": "alias", "aliases": aliases},
+ {"type": "keyword"},
+ ]
+ elif (token_v.endswith("join") and token.is_keyword) or (
+ token_v
+ in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
+ ):
+ schema = (identifier and identifier.get_parent_name()) or []
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = [{"type": "table", "schema": schema}]
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, {"type": "schema"})
+
+ # Only tables can be TRUNCATED, otherwise suggest views
+ if token_v != "truncate":
+ suggest.append({"type": "view", "schema": schema})
+
+ return suggest
+
+ elif token_v in ("table", "view", "function"):
+ # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
+ rel_type = token_v
+ schema = (identifier and identifier.get_parent_name()) or []
+ if schema:
+ return [{"type": rel_type, "schema": schema}]
+ else:
+ return [{"type": "schema"}, {"type": rel_type, "schema": []}]
+ elif token_v == "on":
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+ parent = (identifier and identifier.get_parent_name()) or []
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = [alias or table for (schema, table, alias) in tables]
+ suggest = [{"type": "alias", "aliases": aliases}]
+
+ # The lists of 'aliases' could be empty if we're trying to complete
+ # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
+ # In that case we just suggest all tables.
+ if not aliases:
+ suggest.append({"type": "table", "schema": parent})
+ return suggest
+
+ elif token_v in ("use", "database", "template", "connect"):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return [{"type": "database"}]
+ elif token_v == "tableformat":
+ return [{"type": "table_format"}]
+ elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ if prev_keyword:
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ return []
+ else:
+ return [{"type": "keyword"}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (schema and (id == schema + "." + table))
diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py
new file mode 100644
index 0000000..2f01046
--- /dev/null
+++ b/litecli/packages/filepaths.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8
+
+from __future__ import unicode_literals
+
+from litecli.encodingutils import text_type
+import os
+
+
+def list_path(root_dir):
+ """List directory if exists.
+
+ :param dir: str
+ :return: list
+
+ """
+ res = []
+ if os.path.isdir(root_dir):
+ for name in os.listdir(root_dir):
+ res.append(name)
+ return res
+
+
+def complete_path(curr_dir, last_dir):
+ """Return the path to complete that matches the last entered component.
+
+ If the last entered component is ~, expanded path would not
+ match, so return all of the available paths.
+
+ :param curr_dir: str
+ :param last_dir: str
+ :return: str
+
+ """
+ if not last_dir or curr_dir.startswith(last_dir):
+ return curr_dir
+ elif last_dir == "~":
+ return os.path.join(last_dir, curr_dir)
+
+
+def parse_path(root_dir):
+ """Split path into head and last component for the completer.
+
+ Also return position where last component starts.
+
+ :param root_dir: str path
+ :return: tuple of (string, string, int)
+
+ """
+ base_dir, last_dir, position = "", "", 0
+ if root_dir:
+ base_dir, last_dir = os.path.split(root_dir)
+ position = -len(last_dir) if last_dir else 0
+ return base_dir, last_dir, position
+
+
+def suggest_path(root_dir):
+ """List all files and subdirectories in a directory.
+
+ If the directory is not specified, suggest root directory,
+ user directory, current and parent directory.
+
+ :param root_dir: string: directory to list
+ :return: list
+
+ """
+ if not root_dir:
+ return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir])
+
+ if "~" in root_dir:
+ root_dir = text_type(os.path.expanduser(root_dir))
+
+ if not os.path.exists(root_dir):
+ root_dir, _ = os.path.split(root_dir)
+
+ return list_path(root_dir)
+
+
+def dir_path_exists(path):
+ """Check if the directory path exists for a given file.
+
+ For example, for a file /home/user/.cache/litecli/log, check if
+ /home/user/.cache/litecli exists.
+
+ :param str path: The file path.
+ :return: Whether or not the directory path exists.
+
+ """
+ return os.path.exists(os.path.dirname(path))
diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py
new file mode 100644
index 0000000..92fe365
--- /dev/null
+++ b/litecli/packages/parseutils.py
@@ -0,0 +1,227 @@
+from __future__ import print_function
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ "alphanum_underscore": re.compile(r"(\w+)$"),
+ # This matches everything except spaces, parens, colon, and comma
+ "many_punctuations": re.compile(r"([^():,\s]+)$"),
+ # This matches everything except spaces, parens, colon, comma, and period
+ "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+ # This matches everything except a space.
+ "all_punctuations": re.compile("([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+ """
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ """
+
+ if not text: # Empty string
+ return ""
+
+ if text[-1].isspace():
+ return ""
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ""
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in (
+ "SELECT",
+ "INSERT",
+ "UPDATE",
+ "CREATE",
+ "DELETE",
+ ):
+ return True
+ return False
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item, stop_at_punctuation):
+ yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # `return`. So we need to ignore the keyword if the keyword
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif (
+ item.ttype is Keyword
+ and (not item.value.upper() == "FROM")
+ and (not item.value.upper().endswith("JOIN"))
+ ):
+ return
+ else:
+ yield item
+ elif (
+ item.ttype is Keyword or item.ttype is Keyword.DML
+ ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+ tbl_prefix_seen = True
+ break
+
+
+def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM ) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ except AttributeError:
+ continue
+ if real_name:
+ yield (schema_name, real_name, identifier.get_alias())
+ elif isinstance(item, Identifier):
+ real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
+
+ if real_name:
+ yield (schema_name, real_name, item.get_alias())
+ else:
+ name = item.get_name()
+ yield (None, name, item.get_alias() or name)
+ elif isinstance(item, Function):
+ yield (None, item.get_name(), item.get_name())
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statment.
+
+ Returns a list of (schema, table, alias) tuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return []
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == "insert"
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+ return list(extract_table_identifiers(stream))
+
+
+def find_prev_keyword(sql):
+ """ Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ""
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ("AND", "OR", "NOT", "BETWEEN")
+
+ for t in reversed(flattened):
+ if t.value == "(" or (
+ t.is_keyword and (t.value.upper() not in logical_operators)
+ ):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index thows an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = "".join(tok.value for tok in flattened[: idx + 1])
+ return t, text
+
+ return None, ""
+
+
+def query_starts_with(query, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+ """Check if any queries start with any item from *prefixes*."""
+ for query in sqlparse.split(queries):
+ if query and query_starts_with(query, prefixes) is True:
+ return True
+ return False
+
+
+def is_destructive(queries):
+ """Returns if any of the queries in *queries* is destructive."""
+ keywords = ("drop", "shutdown", "delete", "truncate", "alter")
+ return queries_start_with(queries, keywords)
+
+
+if __name__ == "__main__":
+ sql = "select * from (select t. from tabl t"
+ print(extract_tables(sql))
diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py
new file mode 100644
index 0000000..d9ad2b6
--- /dev/null
+++ b/litecli/packages/prompt_utils.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+import sys
+import click
+from .parseutils import is_destructive
+
+
+def confirm_destructive_query(queries):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ prompt_text = (
+ "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
+ )
+ if is_destructive(queries) and sys.stdin.isatty():
+ return prompt(prompt_text, type=bool)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py
new file mode 100644
index 0000000..fd2b18c
--- /dev/null
+++ b/litecli/packages/special/__init__.py
@@ -0,0 +1,12 @@
+__all__ = []
+
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+
+from . import dbcommands
+from . import iocommands
diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py
new file mode 100644
index 0000000..a7eaa0c
--- /dev/null
+++ b/litecli/packages/special/dbcommands.py
@@ -0,0 +1,273 @@
+from __future__ import unicode_literals, print_function
+import csv
+import logging
+import os
+import sys
+import platform
+import shlex
+from sqlite3 import ProgrammingError
+
+from litecli import __version__
+from litecli.packages.special import iocommands
+from litecli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing
+
+log = logging.getLogger(__name__)
+
+
+@special_command(
+ ".tables",
+ "\\dt",
+ "List tables.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\dt",),
+)
+def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ args = ("{0}%".format(arg),)
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ # if verbose and arg:
+ # query = "SELECT sql FROM sqlite_master WHERE name LIKE ?"
+ # log.debug(query)
+ # cur.execute(query)
+ # status = cur.fetchone()[1]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".schema",
+ ".schema[+] [table]",
+ "The complete schema for the database or a single table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def show_schema(cur, arg=None, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ SELECT sql FROM sqlite_master
+ WHERE name==?
+ ORDER BY tbl_name, type DESC, name
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT sql FROM sqlite_master
+ ORDER BY tbl_name, type DESC, name
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".databases",
+ ".databases",
+ "List databases.",
+ arg_type=RAW_QUERY,
+ case_sensitive=True,
+ aliases=("\\l",),
+)
+def list_databases(cur, **_):
+ query = "PRAGMA database_list"
+ log.debug(query)
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, "")]
+ else:
+ return [(None, None, None, "")]
+
+
+@special_command(
+ ".status",
+ "\\s",
+ "Show current settings.",
+ arg_type=RAW_QUERY,
+ aliases=("\\s",),
+ case_sensitive=True,
+)
+def status(cur, **_):
+ # Create output buffers.
+ footer = []
+ footer.append("--------------")
+
+ # Output the litecli client information.
+ implementation = platform.python_implementation()
+ version = platform.python_version()
+ client_info = []
+ client_info.append("litecli {0},".format(__version__))
+ client_info.append("running on {0} {1}".format(implementation, version))
+ footer.append(" ".join(client_info))
+
+ # Build the output that will be displayed as a table.
+ query = "SELECT file from pragma_database_list() where name = 'main';"
+ log.debug(query)
+ cur.execute(query)
+ db = cur.fetchone()[0]
+ if db is None:
+ db = ""
+
+ footer.append("Current database: " + db)
+ if iocommands.is_pager_enabled():
+ if "PAGER" in os.environ:
+ pager = os.environ["PAGER"]
+ else:
+ pager = "System default"
+ else:
+ pager = "stdout"
+ footer.append("Current pager:" + pager)
+
+ footer.append("--------------")
+ return [(None, None, "", "\n".join(footer))]
+
+
+@special_command(
+ ".load",
+ ".load path",
+ "Load an extension library.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def load_extension(cur, arg, **_):
+ args = shlex.split(arg)
+ if len(args) != 1:
+ raise TypeError(".load accepts exactly one path")
+ path = args[0]
+ conn = cur.connection
+ conn.enable_load_extension(True)
+ conn.load_extension(path)
+ return [(None, None, None, "")]
+
+
+@special_command(
+ "describe",
+ "\\d [table]",
+ "Description of a table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\d", "describe", "desc"),
+)
+def describe(cur, arg, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ PRAGMA table_info({})
+ """.format(
+ arg
+ )
+ else:
+ raise ArgumentMissing("Table name required.")
+
+ log.debug(query)
+ cur.execute(query)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".read",
+ ".read path",
+ "Read input from path",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def read_script(cur, arg, **_):
+ args = shlex.split(arg)
+ if len(args) != 1:
+ raise TypeError(".read accepts exactly one path")
+ path = args[0]
+ with open(path, "r") as f:
+ script = f.read()
+ cur.executescript(script)
+ return [(None, None, None, "")]
+
+
+@special_command(
+ ".import",
+ ".import filename table",
+ "Import data from filename into an existing table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def import_file(cur, arg=None, **_):
+ def split(s):
+ # this is a modification of shlex.split function, just to make it support '`',
+ # because table name might contain '`' character.
+ lex = shlex.shlex(s, posix=True)
+ lex.whitespace_split = True
+ lex.commenters = ""
+ lex.quotes += "`"
+ return list(lex)
+
+ args = split(arg)
+ log.debug("[arg = %r], [args = %r]", arg, args)
+ if len(args) != 2:
+ raise TypeError("Usage: .import filename table")
+
+ filename, table = args
+ cur.execute('PRAGMA table_info("%s")' % table)
+ ncols = len(cur.fetchall())
+ insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1))
+
+ with open(filename, "r") as csvfile:
+ dialect = csv.Sniffer().sniff(csvfile.read(1024))
+ csvfile.seek(0)
+ reader = csv.reader(csvfile, dialect)
+
+ cur.execute("BEGIN")
+ ninserted, nignored = 0, 0
+ for i, row in enumerate(reader):
+ if len(row) != ncols:
+ print(
+ "%s:%d expected %d columns but found %d - ignored"
+ % (filename, i, ncols, len(row)),
+ file=sys.stderr,
+ )
+ nignored += 1
+ continue
+ cur.execute(insert_tmpl, row)
+ ninserted += 1
+ cur.execute("COMMIT")
+
+ status = "Inserted %d rows into %s" % (ninserted, table)
+ if nignored > 0:
+ status += " (%d rows are ignored)" % nignored
+ return [(None, None, None, status)]
diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..7da6fbf
--- /dev/null
+++ b/litecli/packages/special/favoritequeries.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+class FavoriteQueries(object):
+
+ section_name = "favorite_queries"
+
+ usage = """
+Favorite Queries are a way to save frequently used queries
+with a short name.
+Examples:
+
+ # Save a new favorite query.
+ > \\fs simple select * from abc where a is not Null;
+
+ # List all favorite queries.
+ > \\f
+ ╒════════╤═══════════════════════════════════════╕
+ │ Name │ Query │
+ ╞════════╪═══════════════════════════════════════╡
+ │ simple │ SELECT * FROM abc where a is not NULL │
+ ╘════════╧═══════════════════════════════════════╛
+
+ # Run a favorite query.
+ > \\f simple
+ ╒════════╤════════╕
+ │ a │ b │
+ ╞════════╪════════╡
+ │ 日本語 │ 日本語 │
+ ╘════════╧════════╛
+
+ # Delete a favorite query.
+ > \\fd simple
+ simple: Deleted
+"""
+
+ def __init__(self, config):
+ self.config = config
+
+ def list(self):
+ return self.config.get(self.section_name, [])
+
+ def get(self, name):
+ return self.config.get(self.section_name, {}).get(name, None)
+
+ def save(self, name, query):
+ if self.section_name not in self.config:
+ self.config[self.section_name] = {}
+ self.config[self.section_name][name] = query
+ self.config.write()
+
+ def delete(self, name):
+ try:
+ del self.config[self.section_name][name]
+ except KeyError:
+ return "%s: Not Found." % name
+ self.config.write()
+ return "%s: Deleted" % name
diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py
new file mode 100644
index 0000000..8940057
--- /dev/null
+++ b/litecli/packages/special/iocommands.py
@@ -0,0 +1,479 @@
+from __future__ import unicode_literals
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+from configobj import ConfigObj
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .utils import handle_cd_command
+from litecli.packages.prompt_utils import confirm_destructive_query
+
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = written_to_once_file = None
+favoritequeries = FavoriteQueries(ConfigObj())
+
+
+@export
+def set_favorite_queries(config):
+ global favoritequeries
+ favoritequeries = FavoriteQueries(config)
+
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+
+@export
+@special_command(
+ "pager",
+ "\\P [command]",
+ "Set PAGER. Print the query results via PAGER.",
+ arg_type=PARSED_QUERY,
+ aliases=("\\P",),
+ case_sensitive=True,
+)
+def set_pager(arg, **_):
+ if arg:
+ os.environ["PAGER"] = arg
+ msg = "PAGER set to %s." % arg
+ set_pager_enabled(True)
+ else:
+ if "PAGER" in os.environ:
+ msg = "PAGER set to %s." % os.environ["PAGER"]
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = "Pager enabled."
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+
+@export
+@special_command(
+ "nopager",
+ "\\n",
+ "Disable pager, print to stdout.",
+ arg_type=NO_QUERY,
+ aliases=("\\n",),
+ case_sensitive=True,
+)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, "Pager disabled.")]
+
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+
+_logger = logging.getLogger(__name__)
+
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith("\\e") or command.strip().startswith("\\e")
+
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith("\\e"):
+ command, _, filename = sql.partition(" ")
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile("(^\\\e|\\\e$)")
+ while pattern.search(sql):
+ sql = pattern.sub("", sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(" ", 1)[0] if filename else None
+
+ sql = sql or ""
+ MARKER = "# Type your query above this line.\n"
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(
+ "{sql}\n\n{marker}".format(sql=sql, marker=MARKER),
+ filename=filename,
+ extension=".sql",
+ )
+
+ if filename:
+ try:
+ with open(filename, encoding="utf-8") as f:
+ query = f.read()
+ except IOError:
+ message = "Error reading file: %s." % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip("\n")
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@special_command(
+ "\\f",
+ "\\f [name [args..]]",
+ "List or execute favorite queries.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def execute_favorite_query(cur, arg, verbose=False, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == "":
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(" ")
+ args = shlex.split(arg_str)
+
+ query = favoritequeries.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ elif "?" in query:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql, args)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]
+
+ if not rows:
+ status = "\nNo favorite queries found." + favoritequeries.usage
+ else:
+ status = ""
+ return [("", rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ shell_subst_var = "$" + str(idx + 1)
+ question_subst_var = "?"
+ if shell_subst_var in query:
+ query = query.replace(shell_subst_var, val)
+ elif question_subst_var in query:
+ query = query.replace(question_subst_var, val, 1)
+ else:
+ return [
+ None,
+ "Too many arguments.\nQuery does not have enough place holders to substitute.\n"
+ + query,
+ ]
+
+ match = re.search("\\?|\\$\d+", query)
+ if match:
+ return [
+ None,
+ "missing substitution for " + match.group(0) + " in query:\n " + query,
+ ]
+
+ return [query, None]
+
+
+@special_command("\\fs", "\\fs name query", "Save a favorite query.")
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(" ")
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None, usage + "Err: Both name and query are required.")]
+
+ favoritequeries.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query.
+ """
+ usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = favoritequeries.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command("system", "system [command]", "Execute a system shell commmand.")
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith("cd"):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, "")]
+
+ args = arg.split(" ")
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, "OSError: %s" % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith("-o "):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = "a"
+ filename = arg
+
+ if not filename:
+ raise TypeError("You must provide a filename.")
+
+ return {"file": os.path.expanduser(filename), "mode": mode}
+
+
+@special_command(
+ "tee",
+ "tee [-o] filename",
+ "Append all results to an output file (overwrite using -o).",
+)
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command("notee", "notee", "Stop writing results to an output file.")
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo("\n", file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command(
+ ".once",
+ "\\o [-o] filename",
+ "Append next result to an output file (overwrite using -o).",
+ aliases=("\\o", "\\once"),
+)
+def set_once(arg, **_):
+ global once_file
+
+ once_file = parseargfile(arg)
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ try:
+ f = open(**once_file)
+ except (IOError, OSError) as e:
+ once_file = None
+ raise OSError(
+ "Cannot write to file '{}': {}".format(e.filename, e.strerror)
+ )
+
+ with f:
+ click.echo(output, file=f, nl=False)
+ click.echo("\n", file=f, nl=False)
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file
+ if written_to_once_file:
+ once_file = None
+
+
+@special_command(
+ "watch",
+ "watch [seconds] [-c] query",
+ "Executes the query every [seconds] seconds (by default 5).",
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ raise StopIteration
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ raise StopIteration
+ (current_arg, _, arg) = arg.partition(" ")
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == "-c":
+ clear_screen = True
+ continue
+ statement = "{0!s} {1!s}".format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ raise StopIteration
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs["cur"]
+ sql_list = [
+ (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager its activated after every yield,
+ # so we disable it in every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character in its own line, which prevents
+ # to print a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ raise StopIteration
+ finally:
+ set_pager_enabled(old_pager_enabled)
diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py
new file mode 100644
index 0000000..3dd0e77
--- /dev/null
+++ b/litecli/packages/special/main.py
@@ -0,0 +1,160 @@
+from __future__ import unicode_literals
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple(
+ "SpecialCommand",
+ [
+ "handler",
+ "command",
+ "shortcut",
+ "description",
+ "arg_type",
+ "hidden",
+ "case_sensitive",
+ ],
+)
+
+COMMANDS = {}
+
+
+@export
+class ArgumentMissing(Exception):
+ pass
+
+
+@export
+class CommandNotFound(Exception):
+ pass
+
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(" ")
+ verbose = "+" in command
+ command = command.strip().replace("+", "")
+ return (command, verbose, arg.strip())
+
+
+@export
+def special_command(
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ def wrapper(wrapped):
+ register_special_command(
+ wrapped,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ hidden,
+ case_sensitive,
+ aliases,
+ )
+ return wrapped
+
+ return wrapper
+
+
+@export
+def register_special_command(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(
+ handler, command, shortcut, description, arg_type, hidden, case_sensitive
+ )
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ case_sensitive=case_sensitive,
+ hidden=True,
+ )
+
+
+@export
+def execute(cur, sql):
+ """Execute a special command and return the results. If the special command
+ is not supported a KeyError will be raised.
+ """
+ command, verbose, arg = parse_special_command(sql)
+
+ if (command not in COMMANDS) and (command.lower() not in COMMANDS):
+ raise CommandNotFound
+
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise CommandNotFound("Command not found: %s" % command)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+
+@special_command(
+ "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?")
+)
+def show_help(): # All the parameters are ignored.
+ headers = ["Command", "Shortcut", "Description"]
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.command, value.shortcut, value.description))
+ return [(None, result, headers, None)]
+
+
+@special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit"))
+@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY)
+def quit(*_args):
+ raise EOFError
+
+
+@special_command(
+ "\\e",
+ "\\e",
+ "Edit command with editor (uses $EDITOR).",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+@special_command(
+ "\\G",
+ "\\G",
+ "Display current query results vertically.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+def stub():
+ raise NotImplementedError
diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py
new file mode 100644
index 0000000..eed9306
--- /dev/null
+++ b/litecli/packages/special/utils.py
@@ -0,0 +1,48 @@
+import os
+import subprocess
+
+
+def handle_cd_command(arg):
+ """Handles a `cd` shell command by calling python's os.chdir."""
+ CD_CMD = "cd"
+ tokens = arg.split(CD_CMD + " ")
+ directory = tokens[-1] if len(tokens) > 1 else None
+ if not directory:
+ return False, "No folder name was provided."
+ try:
+ os.chdir(directory)
+ subprocess.call(["pwd"])
+ return True, None
+ except OSError as e:
+ return False, e.strerror
+
+
+def format_uptime(uptime_in_seconds):
+ """Format number of seconds into human-readable string.
+
+ :param uptime_in_seconds: The server uptime in seconds.
+ :returns: A human-readable string representing the uptime.
+
+ >>> uptime = format_uptime('56892')
+ >>> print(uptime)
+ 15 hours 48 min 12 sec
+ """
+
+ m, s = divmod(int(uptime_in_seconds), 60)
+ h, m = divmod(m, 60)
+ d, h = divmod(h, 24)
+
+ uptime_values = []
+
+ for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")):
+ if value == 0 and not uptime_values:
+ # Don't include a value/unit if the unit isn't applicable to
+ # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
+ continue
+ elif value == 1 and unit.endswith("s"):
+ # Remove the "s" if the unit is singular.
+ unit = unit[:-1]
+ uptime_values.append("{0} {1}".format(value, unit))
+
+ uptime = " ".join(uptime_values)
+ return uptime
diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py
new file mode 100644
index 0000000..64ca352
--- /dev/null
+++ b/litecli/sqlcompleter.py
@@ -0,0 +1,612 @@
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+from re import compile, escape
+from collections import Counter
+
+from prompt_toolkit.completion import Completer, Completion
+
+from .packages.completion_engine import suggest_type
+from .packages.parseutils import last_word
+from .packages.special.iocommands import favoritequeries
+from .packages.filepaths import parse_path, complete_path, suggest_path
+
+_logger = logging.getLogger(__name__)
+
+
+class SQLCompleter(Completer):
+ keywords = [
+ "ABORT",
+ "ACTION",
+ "ADD",
+ "AFTER",
+ "ALL",
+ "ALTER",
+ "ANALYZE",
+ "AND",
+ "AS",
+ "ASC",
+ "ATTACH",
+ "AUTOINCREMENT",
+ "BEFORE",
+ "BEGIN",
+ "BETWEEN",
+ "BIGINT",
+ "BLOB",
+ "BOOLEAN",
+ "BY",
+ "CASCADE",
+ "CASE",
+ "CAST",
+ "CHARACTER",
+ "CHECK",
+ "CLOB",
+ "COLLATE",
+ "COLUMN",
+ "COMMIT",
+ "CONFLICT",
+ "CONSTRAINT",
+ "CREATE",
+ "CROSS",
+ "CURRENT",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATABASE",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DEFERRED",
+ "DELETE",
+ "DETACH",
+ "DISTINCT",
+ "DO",
+ "DOUBLE PRECISION",
+ "DOUBLE",
+ "DROP",
+ "EACH",
+ "ELSE",
+ "END",
+ "ESCAPE",
+ "EXCEPT",
+ "EXCLUSIVE",
+ "EXISTS",
+ "EXPLAIN",
+ "FAIL",
+ "FILTER",
+ "FLOAT",
+ "FOLLOWING",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "FULL",
+ "GLOB",
+ "GROUP",
+ "HAVING",
+ "IF",
+ "IGNORE",
+ "IMMEDIATE",
+ "IN",
+ "INDEX",
+ "INDEXED",
+ "INITIALLY",
+ "INNER",
+ "INSERT",
+ "INSTEAD",
+ "INT",
+ "INT2",
+ "INT8",
+ "INTEGER",
+ "INTERSECT",
+ "INTO",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "KEY",
+ "LEFT",
+ "LIKE",
+ "LIMIT",
+ "MATCH",
+ "MEDIUMINT",
+ "NATIVE CHARACTER",
+ "NATURAL",
+ "NCHAR",
+ "NO",
+ "NOT",
+ "NOTHING",
+ "NULL",
+ "NULLS FIRST",
+ "NULLS LAST",
+ "NUMERIC",
+ "NVARCHAR",
+ "OF",
+ "OFFSET",
+ "ON",
+ "OR",
+ "ORDER BY",
+ "OUTER",
+ "OVER",
+ "PARTITION",
+ "PLAN",
+ "PRAGMA",
+ "PRECEDING",
+ "PRIMARY",
+ "QUERY",
+ "RAISE",
+ "RANGE",
+ "REAL",
+ "RECURSIVE",
+ "REFERENCES",
+ "REGEXP",
+ "REINDEX",
+ "RELEASE",
+ "RENAME",
+ "REPLACE",
+ "RESTRICT",
+ "RIGHT",
+ "ROLLBACK",
+ "ROW",
+ "ROWS",
+ "SAVEPOINT",
+ "SELECT",
+ "SET",
+ "SMALLINT",
+ "TABLE",
+ "TEMP",
+ "TEMPORARY",
+ "TEXT",
+ "THEN",
+ "TINYINT",
+ "TO",
+ "TRANSACTION",
+ "TRIGGER",
+ "UNBOUNDED",
+ "UNION",
+ "UNIQUE",
+ "UNSIGNED BIG INT",
+ "UPDATE",
+ "USING",
+ "VACUUM",
+ "VALUES",
+ "VARCHAR",
+ "VARYING CHARACTER",
+ "VIEW",
+ "VIRTUAL",
+ "WHEN",
+ "WHERE",
+ "WINDOW",
+ "WITH",
+ "WITHOUT",
+ ]
+
+ functions = [
+ "ABS",
+ "AVG",
+ "CHANGES",
+ "CHAR",
+ "COALESCE",
+ "COUNT",
+ "CUME_DIST",
+ "DATE",
+ "DATETIME",
+ "DENSE_RANK",
+ "GLOB",
+ "GROUP_CONCAT",
+ "HEX",
+ "IFNULL",
+ "INSTR",
+ "JSON",
+ "JSON_ARRAY",
+ "JSON_ARRAY_LENGTH",
+ "JSON_EACH",
+ "JSON_EXTRACT",
+ "JSON_GROUP_ARRAY",
+ "JSON_GROUP_OBJECT",
+ "JSON_INSERT",
+ "JSON_OBJECT",
+ "JSON_PATCH",
+ "JSON_QUOTE",
+ "JSON_REMOVE",
+ "JSON_REPLACE",
+ "JSON_SET",
+ "JSON_TREE",
+ "JSON_TYPE",
+ "JSON_VALID",
+ "JULIANDAY",
+ "LAG",
+ "LAST_INSERT_ROWID",
+ "LENGTH",
+ "LIKELIHOOD",
+ "LIKELY",
+ "LOAD_EXTENSION",
+ "LOWER",
+ "LTRIM",
+ "MAX",
+ "MIN",
+ "NTILE",
+ "NULLIF",
+ "PERCENT_RANK",
+ "PRINTF",
+ "QUOTE",
+ "RANDOM",
+ "RANDOMBLOB",
+ "RANK",
+ "REPLACE",
+ "ROUND",
+ "ROW_NUMBER",
+ "RTRIM",
+ "SOUNDEX",
+ "SQLITE_COMPILEOPTION_GET",
+ "SQLITE_COMPILEOPTION_USED",
+ "SQLITE_OFFSET",
+ "SQLITE_SOURCE_ID",
+ "SQLITE_VERSION",
+ "STRFTIME",
+ "SUBSTR",
+ "SUM",
+ "TIME",
+ "TOTAL",
+ "TOTAL_CHANGES",
+ "TRIM",
+ ]
+
+ def __init__(self, supported_formats=(), keyword_casing="auto"):
+ super(self.__class__, self).__init__()
+ self.reserved_words = set()
+ for x in self.keywords:
+ self.reserved_words.update(x.split())
+ self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
+
+ self.special_commands = []
+ self.table_formats = supported_formats
+ if keyword_casing not in ("upper", "lower", "auto"):
+ keyword_casing = "auto"
+ self.keyword_casing = keyword_casing
+ self.reset_completions()
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = "`%s`" % name
+
+ return name
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_special_commands(self, special_commands):
+ # Special commands are not part of all_completions since they can only
+ # be at the beginning of a line.
+ self.special_commands.extend(special_commands)
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schema):
+ if schema is None:
+ return
+ metadata = self.dbmetadata["tables"]
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ metadata[schema] = {}
+ self.all_completions.update(schema)
+
+ def extend_relations(self, data, kind):
+ """Extend metadata for tables or views
+
+ :param data: list of (rel_name, ) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'data' is a generator object. It can throw an exception while being
+ # consumed. This could happen if the user has launched the app without
+ # specifying a database name. This exception must be handled to prevent
+ # crashing.
+ try:
+ data = [self.escaped_names(d) for d in data]
+ except Exception:
+ data = []
+
+ # dbmetadata['tables'][$schema_name][$table_name] should be a list of
+ # column names. Default to an asterisk
+ metadata = self.dbmetadata[kind]
+ for relname in data:
+ try:
+ metadata[self.dbname][relname[0]] = ["*"]
+ except KeyError:
+ _logger.error(
+ "%r %r listed in unrecognized schema %r",
+ kind,
+ relname[0],
+ self.dbname,
+ )
+ self.all_completions.add(relname[0])
+
+ def extend_columns(self, column_data, kind):
+ """Extend column metadata
+
+ :param column_data: list of (rel_name, column_name) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'column_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ column_data = [self.escaped_names(d) for d in column_data]
+ except Exception:
+ column_data = []
+
+ metadata = self.dbmetadata[kind]
+ for relname, column in column_data:
+ metadata[self.dbname][relname].append(column)
+ self.all_completions.add(column)
+
+ def extend_functions(self, func_data):
+ # 'func_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ func_data = [self.escaped_names(d) for d in func_data]
+ except Exception:
+ func_data = []
+
+ # dbmetadata['functions'][$schema_name][$function_name] should return
+ # function metadata.
+ metadata = self.dbmetadata["functions"]
+
+ for func in func_data:
+ metadata[self.dbname][func[0]] = None
+ self.all_completions.add(func[0])
+
+ def set_dbname(self, dbname):
+ self.dbname = dbname
+
+ def reset_completions(self):
+ self.databases = []
+ self.dbname = ""
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ @staticmethod
+ def find_matches(
+ text,
+ collection,
+ start_only=False,
+ fuzzy=True,
+ casing=None,
+ punctuations="most_punctuations",
+ ):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ If `start_only` is True, the text will match an available
+ completion only at the beginning. Otherwise, a completion is
+ considered a match if the text appears anywhere within it.
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+ """
+ last = last_word(text, include=punctuations)
+ text = last.lower()
+
+ completions = []
+
+ if fuzzy:
+ regex = ".*?".join(map(escape, text))
+ pat = compile("(%s)" % regex)
+ for item in sorted(collection):
+ r = pat.search(item.lower())
+ if r:
+ completions.append((len(r.group()), r.start(), item))
+ else:
+ match_end_limit = len(text) if start_only else None
+ for item in sorted(collection):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ completions.append((len(text), match_point, item))
+
+ if casing == "auto":
+ casing = "lower" if last and last[-1].islower() else "upper"
+
+ def apply_case(kw):
+ if casing == "upper":
+ return kw.upper()
+ return kw.lower()
+
+ return (
+ Completion(z if casing is None else apply_case(z), -len(text))
+ for x, y, z in sorted(completions)
+ )
+
+ def get_completions(self, document, complete_event):
+ word_before_cursor = document.get_word_before_cursor(WORD=True)
+ completions = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+
+ _logger.debug("Suggestion type: %r", suggestion["type"])
+
+ if suggestion["type"] == "column":
+ tables = suggestion["tables"]
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables)
+ if suggestion.get("drop_unique"):
+ # drop_unique is used for 'tb11 JOIN tbl2 USING (...'
+ # which should suggest only columns that appear in more than
+ # one table
+ scoped_cols = [
+ col
+ for (col, count) in Counter(scoped_cols).items()
+ if count > 1 and col != "*"
+ ]
+
+ cols = self.find_matches(word_before_cursor, scoped_cols)
+ completions.extend(cols)
+
+ elif suggestion["type"] == "function":
+ # suggest user-defined functions using substring matching
+ funcs = self.populate_schema_objects(suggestion["schema"], "functions")
+ user_funcs = self.find_matches(word_before_cursor, funcs)
+ completions.extend(user_funcs)
+
+ # suggest hardcoded functions using startswith matching only if
+ # there is no schema qualifier. If a schema qualifier is
+ # present it probably denotes a table.
+ # eg: SELECT * FROM users u WHERE u.
+ if not suggestion["schema"]:
+ predefined_funcs = self.find_matches(
+ word_before_cursor,
+ self.functions,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing,
+ )
+ completions.extend(predefined_funcs)
+
+ elif suggestion["type"] == "table":
+ tables = self.populate_schema_objects(suggestion["schema"], "tables")
+ tables = self.find_matches(word_before_cursor, tables)
+ completions.extend(tables)
+
+ elif suggestion["type"] == "view":
+ views = self.populate_schema_objects(suggestion["schema"], "views")
+ views = self.find_matches(word_before_cursor, views)
+ completions.extend(views)
+
+ elif suggestion["type"] == "alias":
+ aliases = suggestion["aliases"]
+ aliases = self.find_matches(word_before_cursor, aliases)
+ completions.extend(aliases)
+
+ elif suggestion["type"] == "database":
+ dbs = self.find_matches(word_before_cursor, self.databases)
+ completions.extend(dbs)
+
+ elif suggestion["type"] == "keyword":
+ keywords = self.find_matches(
+ word_before_cursor,
+ self.keywords,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing,
+ punctuations="many_punctuations",
+ )
+ completions.extend(keywords)
+
+ elif suggestion["type"] == "special":
+ special = self.find_matches(
+ word_before_cursor,
+ self.special_commands,
+ start_only=True,
+ fuzzy=False,
+ punctuations="many_punctuations",
+ )
+ completions.extend(special)
+ elif suggestion["type"] == "favoritequery":
+ queries = self.find_matches(
+ word_before_cursor,
+ favoritequeries.list(),
+ start_only=False,
+ fuzzy=True,
+ )
+ completions.extend(queries)
+ elif suggestion["type"] == "table_format":
+ formats = self.find_matches(
+ word_before_cursor, self.table_formats, start_only=True, fuzzy=False
+ )
+ completions.extend(formats)
+ elif suggestion["type"] == "file_name":
+ file_names = self.find_files(word_before_cursor)
+ completions.extend(file_names)
+
+ return completions
+
+ def find_files(self, word):
+ """Yield matching directory or file names.
+
+ :param word:
+ :return: iterable
+
+ """
+ base_path, last_path, position = parse_path(word)
+ paths = suggest_path(word)
+ for name in sorted(paths):
+ suggestion = complete_path(name, last_path)
+ if suggestion:
+ yield Completion(suggestion, position)
+
+ def populate_scoped_cols(self, scoped_tbls):
+ """Find all columns in a set of scoped_tables
+ :param scoped_tbls: list of (schema, table, alias) tuples
+ :return: list of column names
+ """
+ columns = []
+ meta = self.dbmetadata
+
+ for tbl in scoped_tbls:
+ # A fully qualified schema.relname reference or default_schema
+ # DO NOT escape schema names.
+ schema = tbl[0] or self.dbname
+ relname = tbl[1]
+ escaped_relname = self.escape_name(tbl[1])
+
+ # We don't know if schema.relname is a table or view. Since
+ # tables and views cannot share the same name, we can check one
+ # at a time
+ try:
+ columns.extend(meta["tables"][schema][relname])
+
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ try:
+ columns.extend(meta["tables"][schema][escaped_relname])
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ pass
+
+ try:
+ columns.extend(meta["views"][schema][relname])
+ except KeyError:
+ pass
+
+ return columns
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns list of tables or functions for a (optional) schema"""
+ metadata = self.dbmetadata[obj_type]
+ schema = schema or self.dbname
+
+ try:
+ objects = metadata[schema].keys()
+ except KeyError:
+ # schema doesn't exist
+ objects = []
+
+ return objects
diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py
new file mode 100644
index 0000000..7ef103c
--- /dev/null
+++ b/litecli/sqlexecute.py
@@ -0,0 +1,212 @@
+import logging
+import sqlite3
+import uuid
+from contextlib import closing
+from sqlite3 import OperationalError
+
+import sqlparse
+import os.path
+
+from .packages import special
+
+_logger = logging.getLogger(__name__)
+
+# FIELD_TYPES = decoders.copy()
+# FIELD_TYPES.update({
+# FIELD_TYPE.NULL: type(None)
+# })
+
+
+class SQLExecute(object):
+
+ databases_query = """
+ PRAGMA database_list
+ """
+
+ tables_query = """
+ SELECT name
+ FROM sqlite_master
+ WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ table_columns_query = """
+ SELECT m.name as tableName, p.name as columnName
+ FROM sqlite_master m
+ LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name
+ WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%'
+ ORDER BY tableName, columnName
+ """
+
+ functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
+ WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
+
+ def __init__(self, database):
+ self.dbname = database
+ self._server_type = None
+ self.connection_id = None
+ self.conn = None
+ if not database:
+ _logger.debug("Database is not specified. Skip connection.")
+ return
+ self.connect()
+
+ def connect(self, database=None):
+ db = database or self.dbname
+ _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database)
+
+ db_name = os.path.expanduser(db)
+ db_dir_name = os.path.dirname(os.path.abspath(db_name))
+ if not os.path.exists(db_dir_name):
+ raise Exception("Path does not exist: {}".format(db_dir_name))
+
+ conn = sqlite3.connect(database=db_name, isolation_level=None)
+ if self.conn:
+ self.conn.close()
+
+ self.conn = conn
+ # Update them after the connection is made to ensure that it was a
+ # successful connection.
+ self.dbname = db
+ # retrieve connection id
+ self.reset_connection_id()
+
+ def run(self, statement):
+ """Execute the sql in the database and return the results. The results
+ are a list of tuples. Each tuple has 4 values
+ (title, rows, headers, status).
+ """
+ # Remove spaces and EOL
+ statement = statement.strip()
+ if not statement: # Empty string
+ yield (None, None, None, None)
+
+ # Split the sql into separate queries and run each one.
+ # Unless it's saving a favorite query, in which case we
+ # want to save them all together.
+ if statement.startswith("\\fs"):
+ components = [statement]
+ else:
+ components = sqlparse.split(statement)
+
+ for sql in components:
+ # Remove spaces, eol and semi-colons.
+ sql = sql.rstrip(";")
+
+ # \G is treated specially since we have to set the expanded output.
+ if sql.endswith("\\G"):
+ special.set_expanded_output(True)
+ sql = sql[:-2].strip()
+
+ if not self.conn and not (
+ sql.startswith(".open")
+ or sql.lower().startswith("use")
+ or sql.startswith("\\u")
+ or sql.startswith("\\?")
+ or sql.startswith("\\q")
+ or sql.startswith("help")
+ or sql.startswith("exit")
+ or sql.startswith("quit")
+ ):
+ _logger.debug(
+ "Not connected to database. Will not run statement: %s.", sql
+ )
+ raise OperationalError("Not connected to database.")
+ # yield ('Not connected to database', None, None, None)
+ # return
+
+ cur = self.conn.cursor() if self.conn else None
+ try: # Special command
+ _logger.debug("Trying a dbspecial command. sql: %r", sql)
+ for result in special.execute(cur, sql):
+ yield result
+ except special.CommandNotFound: # Regular SQL
+ _logger.debug("Regular sql statement. sql: %r", sql)
+ cur.execute(sql)
+ yield self.get_result(cur)
+
+ def get_result(self, cursor):
+ """Get the current result's data from the cursor."""
+ title = headers = None
+
+ # cursor.description is not None for queries that return result sets,
+ # e.g. SELECT.
+ if cursor.description is not None:
+ headers = [x[0] for x in cursor.description]
+ status = "{0} row{1} in set"
+ cursor = list(cursor)
+ rowcount = len(cursor)
+ else:
+ _logger.debug("No rows in result.")
+ status = "Query OK, {0} row{1} affected"
+ rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount
+ cursor = None
+
+ status = status.format(rowcount, "" if rowcount == 1 else "s")
+
+ return (title, cursor, headers, status)
+
+ def tables(self):
+ """Yields table names"""
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Tables Query. sql: %r", self.tables_query)
+ cur.execute(self.tables_query)
+ for row in cur:
+ yield row
+
+ def table_columns(self):
+ """Yields column names"""
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Columns Query. sql: %r", self.table_columns_query)
+ cur.execute(self.table_columns_query)
+ for row in cur:
+ yield row
+
+ def databases(self):
+ if not self.conn:
+ return
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Databases Query. sql: %r", self.databases_query)
+ for row in cur.execute(self.databases_query):
+ yield row[1]
+
+ def functions(self):
+ """Yields tuples of (schema_name, function_name)"""
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Functions Query. sql: %r", self.functions_query)
+ cur.execute(self.functions_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def show_candidates(self):
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Show Query. sql: %r", self.show_candidates_query)
+ try:
+ cur.execute(self.show_candidates_query)
+ except sqlite3.DatabaseError as e:
+ _logger.error("No show completions due to %r", e)
+ yield ""
+ else:
+ for row in cur:
+ yield (row[0].split(None, 1)[-1],)
+
+ def server_type(self):
+ self._server_type = ("sqlite3", "3")
+ return self._server_type
+
+ def get_connection_id(self):
+ if not self.connection_id:
+ self.reset_connection_id()
+ return self.connection_id
+
+ def reset_connection_id(self):
+ # Remember current connection id
+ _logger.debug("Get current connection id")
+ # res = self.run('select connection_id()')
+ self.connection_id = uuid.uuid4()
+ # for title, cur, headers, status in res:
+ # self.connection_id = cur.fetchone()[0]
+ _logger.debug("Current connection id: %s", self.connection_id)