From 4edd467b28c895483cd5468d51d1c6824a21715a Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 14 Aug 2020 18:58:23 +0200 Subject: Adding upstream version 1.5.0. Signed-off-by: Daniel Baumann --- .coveragerc | 3 + .github/PULL_REQUEST_TEMPLATE.md | 8 + .gitignore | 18 + .pre-commit-config.yaml | 6 + .travis.yml | 29 + CHANGELOG.md | 72 ++ CONTRIBUTING.md | 119 +++ LICENSE | 27 + MANIFEST.in | 8 + README.md | 54 ++ TODO | 3 + litecli/AUTHORS | 20 + litecli/__init__.py | 1 + litecli/clibuffer.py | 40 + litecli/clistyle.py | 114 +++ litecli/clitoolbar.py | 51 ++ litecli/compat.py | 9 + litecli/completion_refresher.py | 131 +++ litecli/config.py | 62 ++ litecli/encodingutils.py | 38 + litecli/key_bindings.py | 84 ++ litecli/lexer.py | 9 + litecli/liteclirc | 113 +++ litecli/main.py | 1008 +++++++++++++++++++++ litecli/packages/__init__.py | 0 litecli/packages/completion_engine.py | 331 +++++++ litecli/packages/filepaths.py | 88 ++ litecli/packages/parseutils.py | 227 +++++ litecli/packages/prompt_utils.py | 39 + litecli/packages/special/__init__.py | 12 + litecli/packages/special/dbcommands.py | 273 ++++++ litecli/packages/special/favoritequeries.py | 59 ++ litecli/packages/special/iocommands.py | 479 ++++++++++ litecli/packages/special/main.py | 160 ++++ litecli/packages/special/utils.py | 48 + litecli/sqlcompleter.py | 612 +++++++++++++ litecli/sqlexecute.py | 212 +++++ release.py | 130 +++ requirements-dev.txt | 9 + screenshots/litecli.gif | Bin 0 -> 742270 bytes screenshots/litecli.png | Bin 0 -> 109549 bytes setup.cfg | 18 + setup.py | 70 ++ tasks.py | 128 +++ tests/conftest.py | 40 + tests/data/import_data.csv | 2 + tests/liteclirc | 128 +++ tests/test.txt | 1 + tests/test_clistyle.py | 28 + tests/test_completion_engine.py | 655 +++++++++++++ tests/test_completion_refresher.py | 94 ++ tests/test_dbspecial.py | 65 ++ tests/test_main.py | 261 ++++++ tests/test_parseutils.py | 131 +++ tests/test_prompt_utils.py | 14 + tests/test_smart_completion_public_schema_only.py | 430 +++++++++ tests/test_sqlexecute.py | 392 ++++++++ tests/utils.py | 96 ++ tox.ini | 11 + 59 files changed, 7270 insertions(+) create mode 100644 .coveragerc create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 .travis.yml create mode 100644 CHANGELOG.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 TODO create mode 100644 litecli/AUTHORS create mode 100644 litecli/__init__.py create mode 100644 litecli/clibuffer.py create mode 100644 litecli/clistyle.py create mode 100644 litecli/clitoolbar.py create mode 100644 litecli/compat.py create mode 100644 litecli/completion_refresher.py create mode 100644 litecli/config.py create mode 100644 litecli/encodingutils.py create mode 100644 litecli/key_bindings.py create mode 100644 litecli/lexer.py create mode 100644 litecli/liteclirc create mode 100644 litecli/main.py create mode 100644 litecli/packages/__init__.py create mode 100644 litecli/packages/completion_engine.py create mode 100644 litecli/packages/filepaths.py create mode 100644 litecli/packages/parseutils.py create mode 100644 litecli/packages/prompt_utils.py create mode 100644 litecli/packages/special/__init__.py create mode 100644 litecli/packages/special/dbcommands.py create mode 100644 litecli/packages/special/favoritequeries.py create mode 100644 litecli/packages/special/iocommands.py create mode 100644 litecli/packages/special/main.py create mode 100644 litecli/packages/special/utils.py create mode 100644 litecli/sqlcompleter.py create mode 100644 litecli/sqlexecute.py create mode 100644 release.py create mode 100644 requirements-dev.txt create mode 100644 screenshots/litecli.gif create mode 100644 screenshots/litecli.png create mode 100644 setup.cfg create mode 100755 setup.py create mode 100644 tasks.py create mode 100644 tests/conftest.py create mode 100644 tests/data/import_data.csv create mode 100644 tests/liteclirc create mode 100644 tests/test.txt create mode 100644 tests/test_clistyle.py create mode 100644 tests/test_completion_engine.py create mode 100644 tests/test_completion_refresher.py create mode 100644 tests/test_dbspecial.py create mode 100644 tests/test_main.py create mode 100644 tests/test_parseutils.py create mode 100644 tests/test_prompt_utils.py create mode 100644 tests/test_smart_completion_public_schema_only.py create mode 100644 tests/test_sqlexecute.py create mode 100644 tests/utils.py create mode 100644 tox.ini diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..aea4994 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +parallel = True +source = litecli diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..3e14cc7 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +## Description + + + + +## Checklist + +- [ ] I've added this contribution to the `CHANGELOG.md` file. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..63c3eb6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +.idea/ +/build +/dist +/litecli.egg-info +/htmlcov +/src +/test/behave.ini +/litecli_env +/.venv +/.eggs + +.vagrant +*.pyc +*.deb +*.swp +.cache/ +.coverage +.coverage.* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b268b62 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/ambv/black + rev: stable + hooks: + - id: black + language_version: python3.7 diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..896e7f7 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,29 @@ +language: python +python: + - "3.6" + +matrix: + include: + - python: 3.7 + dist: xenial + sudo: true + +install: + - pip install -r requirements-dev.txt + - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then pip install black; fi + - pip install -e . + +script: + - ./setup.py test --pytest-args="--cov-report= --cov=litecli" + - coverage report + - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then ./setup.py lint; fi + +after_success: + - codecov + +notifications: + webhooks: + urls: + - YOUR_WEBHOOK_URL + on_success: change # options: [always|never|change] default: always + on_failure: always # options: [always|never|change] default: always diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..ec98b08 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,72 @@ +## Unreleased - TBD + + + +### Features + +- Add verbose feature to `favorite_query` command. (Thanks: [Zhaolong Zhu]) + - `\f query` does not show the full SQL. + - `\f+ query` shows the full SQL. + +### Bug Fixes + +- Fix compatibility with sqlparse >= 0.4.0. (Thanks: [chocolateboy]) + +## 1.4.1 - 2020-07-27 + +### Bug Fixes + +- Fix setup.py to set `long_description_content_type` as markdown. + +## 1.4.0 - 2020-07-27 + +### Features + +- Add NULLS FIRST and NULLS LAST to keywords. (Thanks: [Amjith]) + +## 1.3.2 - 2020-03-11 + +- Fix the completion engine to work with newer sqlparse. + +## 1.3.1 - 2020-03-11 + +- Remove the version pinning of sqlparse package. + +## 1.3.0 - 2020-02-11 + +### Features + +- Added `.import` command for importing data from file into table. (Thanks: [Zhaolong Zhu]) +- Upgraded to prompt-toolkit 3.x. + +## 1.2.0 - 2019-10-26 + +### Features + +- Enhance the `describe` command. (Thanks: [Amjith]) +- Autocomplete table names for special commands. (Thanks: [Amjith]) + +## 1.1.0 - 2019-07-14 + +### Features + +- Added `.read` command for reading scripts. +- Added `.load` command for loading extension libraries. (Thanks: [Zhiming Wang]) +- Add support for using `?` as a placeholder in the favorite queries. (Thanks: [Amjith]) +- Added shift-tab to select the previous entry in the completion menu. [Amjith] +- Added `describe` and `desc` keywords. + +### Bug Fixes + +- Clear error message when directory does not exist. (Thanks: [Irina Truong]) + +## 1.0.0 - 2019-01-04 + +- To new beginnings. :tada: + +[Amjith]: https://blog.amjith.com +[chocolateboy]: https://github.com/chocolateboy +[Irina Truong]: https://github.com/j-bennet +[Shawn Chapla]: https://github.com/shwnchpl +[Zhaolong Zhu]: https://github.com/zzl0 +[Zhiming Wang]: https://github.com/zmwangx diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..ad22cdf --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,119 @@ +# Development Guide + +This is a guide for developers who would like to contribute to this project. It is recommended to use Python 3.6 and above for development. + +If you're interested in contributing to litecli, thank you. We'd love your help! +You'll always get credit for your work. + +## GitHub Workflow + +1. [Fork the repository](https://github.com/dbcli/litecli) on GitHub. + +2. Clone your fork locally: + ```bash + $ git clone + ``` + +3. Add the official repository (`upstream`) as a remote repository: + ```bash + $ git remote add upstream git@github.com:dbcli/litecli.git + ``` + +4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs) + for development: + + ```bash + $ cd litecli + $ pip install virtualenv + $ virtualenv litecli_dev + ``` + + We've just created a virtual environment that we'll use to install all the dependencies + and tools we need to work on litecli. Whenever you want to work on litecli, you + need to activate the virtual environment: + + ```bash + $ source litecli_dev/bin/activate + ``` + + When you're done working, you can deactivate the virtual environment: + + ```bash + $ deactivate + ``` + +5. Install the dependencies and development tools: + + ```bash + $ pip install -r requirements-dev.txt + $ pip install --editable . + ``` + +6. Create a branch for your bugfix or feature based off the `master` branch: + + ```bash + $ git checkout -b + ``` + +7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: + + ```bash + $ git pull upstream master + ``` + +8. When your work is ready for the litecli team to review it, push your branch to your fork: + + ```bash + $ git push origin + ``` + +9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) on GitHub. + + +## Running the Tests + +While you work on litecli, it's important to run the tests to make sure your code +hasn't broken any existing functionality. To run the tests, just type in: + +```bash +$ ./setup.py test +``` + +litecli supports Python 2.7 and 3.4+. You can test against multiple versions of +Python by running tox: + +```bash +$ tox +``` + + +### CLI Tests + +Some CLI tests expect the program `ex` to be a symbolic link to `vim`. + +In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will +change the output and therefore make some tests fail. + +You can check this by running: +```bash +$ readlink -f $(which ex) +``` + + +## Coding Style + +litecli uses [black](https://github.com/ambv/black) to format the source code. Make sure to install black. + +It's easy to check the style of your code, just run: + +```bash +$ ./setup.py lint +``` + +If you see any style issues, you can automatically fix them by running: + +```bash +$ ./setup.py lint --fix +``` + +Be sure to commit and push any stylistic fixes. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..c3bbe96 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2018, dbcli +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of dbcli nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..f1ff0f6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,8 @@ +include *.txt *.py +include LICENSE CHANGELOG.md +include tox.ini +recursive-include tests *.py +recursive-include tests *.txt +recursive-include tests *.csv +recursive-include tests liteclirc +recursive-include litecli AUTHORS diff --git a/README.md b/README.md new file mode 100644 index 0000000..dfc995a --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ +# litecli + +[![Build Status](https://travis-ci.org/dbcli/litecli.svg?branch=master)](https://travis-ci.org/dbcli/litecli) + +[Docs](https://litecli.com) + +A command-line client for SQLite databases that has auto-completion and syntax highlighting. + +![Completion](screenshots/litecli.png) +![CompletionGif](screenshots/litecli.gif) + +## Installation + +If you already know how to install python packages, then you can install it via pip: + +You might need sudo on linux. + +``` +$ pip install -U litecli +``` + +The package is also available on Arch Linux through AUR in two versions: [litecli](https://aur.archlinux.org/packages/litecli/) is based the latest release (git tag) and [litecli-git](https://aur.archlinux.org/packages/litecli-git/) is based on the master branch of the git repo. You can install them manually or with an AUR helper such as `yay`: + +``` +$ yay -S litecli +``` +or + +``` +$ yay -S litecli-git +``` + +For MacOS users, you can also use Homebrew to install it: + +``` +$ brew install litecli +``` + +## Usage + +``` +$ litecli --help + +Usage: litecli [OPTIONS] [DATABASE] + +Examples: + - litecli sqlite_db_name +``` + +A config file is automatically created at `~/.config/litecli/config` at first launch. See the file itself for a description of all available options. + +## Docs + +Visit: [litecli.com/features](https://litecli.com/features) diff --git a/TODO b/TODO new file mode 100644 index 0000000..7c854dc --- /dev/null +++ b/TODO @@ -0,0 +1,3 @@ +* [] Sort by frecency. +* [] Add completions when an attach database command is run. +* [] Add behave tests. diff --git a/litecli/AUTHORS b/litecli/AUTHORS new file mode 100644 index 0000000..d5265de --- /dev/null +++ b/litecli/AUTHORS @@ -0,0 +1,20 @@ +Project Lead: +------------- + + * Delgermurun Purevkhu + + +Core Developers: +---------------- + + * Amjith Ramanujam + * Irina Truong + * Dick Marinus + +Contributors: +------------- + + * Thomas Roten + * Zhaolong Zhu + * Zhiming Wang + * Shawn M. Chapla diff --git a/litecli/__init__.py b/litecli/__init__.py new file mode 100644 index 0000000..5b60188 --- /dev/null +++ b/litecli/__init__.py @@ -0,0 +1 @@ +__version__ = "1.5.0" diff --git a/litecli/clibuffer.py b/litecli/clibuffer.py new file mode 100644 index 0000000..a57192a --- /dev/null +++ b/litecli/clibuffer.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals + +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app + + +def cli_is_multiline(cli): + @Condition + def cond(): + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + + if not cli.multi_line: + return False + else: + return not _multiline_exception(doc.text) + + return cond + + +def _multiline_exception(text): + orig = text + text = text.strip() + + # Multi-statement favorite query is a special case. Because there will + # be a semicolon separating statements, we can't consider semicolon an + # EOL. Let's consider an empty line an EOL instead. + if text.startswith("\\fs"): + return orig.endswith("\n") + + return ( + text.startswith("\\") # Special Command + or text.endswith(";") # Ended with a semi-colon + or text.endswith("\\g") # Ended with \g + or text.endswith("\\G") # Ended with \G + or (text == "exit") # Exit doesn't need semi-colon + or (text == "quit") # Quit doesn't need semi-colon + or (text == ":q") # To all the vim fans out there + or (text == "") # Just a plain enter without any text + ) diff --git a/litecli/clistyle.py b/litecli/clistyle.py new file mode 100644 index 0000000..7527315 --- /dev/null +++ b/litecli/clistyle.py @@ -0,0 +1,114 @@ +from __future__ import unicode_literals + +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). +TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Prompt: "prompt", + Token.Continuation: "continuation", +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} + + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError as err: + return token_type, style_dict[token_name] + + +def style_factory(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name) + except ClassNotFound: + style = pygments.styles.get_style_by_name("native") + + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith("Token."): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style(token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error("Unhandled style / class name: %s", token) + else: + # treat as prompt style name (2.0). See default style names here: + # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles( + [style_from_pygments_cls(style), override_style, Style(prompt_styles)] + ) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name("native").styles + + for token in cli_style: + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error("Unhandled style / class name: %s", token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style + + return OutputStyle diff --git a/litecli/clitoolbar.py b/litecli/clitoolbar.py new file mode 100644 index 0000000..05d0bfd --- /dev/null +++ b/litecli/clitoolbar.py @@ -0,0 +1,51 @@ +from __future__ import unicode_literals + +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.application import get_app + + +def create_toolbar_tokens_func(cli, show_fish_help): + """ + Return a function that generates the toolbar tokens. + """ + + def get_toolbar_tokens(): + result = [] + result.append(("class:bottom-toolbar", " ")) + + if cli.multi_line: + result.append( + ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") + ) + + if cli.multi_line: + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) + if cli.prompt_app.editing_mode == EditingMode.VI: + result.append( + ("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode())) + ) + + if show_fish_help(): + result.append( + ("class:bottom-toolbar", " Right-arrow to complete suggestion") + ) + + if cli.completion_refresher.is_refreshing(): + result.append(("class:bottom-toolbar", " Refreshing completions...")) + + return result + + return get_toolbar_tokens + + +def _get_vi_mode(): + """Get the current vi mode for display.""" + return { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", + }[get_app().vi_state.input_mode] diff --git a/litecli/compat.py b/litecli/compat.py new file mode 100644 index 0000000..7316261 --- /dev/null +++ b/litecli/compat.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +"""Platform and Python version compatibility support.""" + +import sys + + +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +WIN = sys.platform in ("win32", "cygwin") diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py new file mode 100644 index 0000000..9602070 --- /dev/null +++ b/litecli/completion_refresher.py @@ -0,0 +1,131 @@ +import threading +from .packages.special.main import COMMANDS +from collections import OrderedDict + +from .sqlcompleter import SQLCompleter +from .sqlexecute import SQLExecute + + +class CompletionRefresher(object): + + refreshers = OrderedDict() + + def __init__(self): + self._completer_thread = None + self._restart_refresh = threading.Event() + + def refresh(self, executor, callbacks, completer_options=None): + """Creates a SQLCompleter object and populates it with the relevant + completion suggestions in a background thread. + + executor - SQLExecute object, used to extract the credentials to connect + to the database. + callbacks - A function or a list of functions to call after the thread + has completed the refresh. The newly created completion + object will be passed in as an argument to each callback. + completer_options - dict of options to pass to SQLCompleter. + + """ + if completer_options is None: + completer_options = {} + + if self.is_refreshing(): + self._restart_refresh.set() + return [(None, None, None, "Auto-completion refresh restarted.")] + else: + if executor.dbname == ":memory:": + # if DB is memory, needed to use same connection + # So can't use same connection with different thread + self._bg_refresh(executor, callbacks, completer_options) + else: + self._completer_thread = threading.Thread( + target=self._bg_refresh, + args=(executor, callbacks, completer_options), + name="completion_refresh", + ) + self._completer_thread.setDaemon(True) + self._completer_thread.start() + return [ + ( + None, + None, + None, + "Auto-completion refresh started in the background.", + ) + ] + + def is_refreshing(self): + return self._completer_thread and self._completer_thread.is_alive() + + def _bg_refresh(self, sqlexecute, callbacks, completer_options): + completer = SQLCompleter(**completer_options) + + e = sqlexecute + if e.dbname == ":memory:": + # if DB is memory, needed to use same connection + executor = sqlexecute + else: + # Create a new sqlexecute method to popoulate the completions. + executor = SQLExecute(e.dbname) + + # If callbacks is a single function then push it into a list. + if callable(callbacks): + callbacks = [callbacks] + + while 1: + for refresher in self.refreshers.values(): + refresher(completer, executor) + if self._restart_refresh.is_set(): + self._restart_refresh.clear() + break + else: + # Break out of while loop if the for loop finishes natually + # without hitting the break statement. + break + + # Start over the refresh from the beginning if the for loop hit the + # break statement. + continue + + for callback in callbacks: + callback(completer) + + +def refresher(name, refreshers=CompletionRefresher.refreshers): + """Decorator to add the decorated function to the dictionary of + refreshers. Any function decorated with a @refresher will be executed as + part of the completion refresh routine.""" + + def wrapper(wrapped): + refreshers[name] = wrapped + return wrapped + + return wrapper + + +@refresher("databases") +def refresh_databases(completer, executor): + completer.extend_database_names(executor.databases()) + + +@refresher("schemata") +def refresh_schemata(completer, executor): + # name of the current database. + completer.extend_schemata(executor.dbname) + completer.set_dbname(executor.dbname) + + +@refresher("tables") +def refresh_tables(completer, executor): + completer.extend_relations(executor.tables(), kind="tables") + completer.extend_columns(executor.table_columns(), kind="tables") + + +@refresher("functions") +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) + + +@refresher("special_commands") +def refresh_special(completer, executor): + completer.extend_special_commands(COMMANDS.keys()) diff --git a/litecli/config.py b/litecli/config.py new file mode 100644 index 0000000..1c7fb25 --- /dev/null +++ b/litecli/config.py @@ -0,0 +1,62 @@ +import errno +import shutil +import os +import platform +from os.path import expanduser, exists, dirname +from configobj import ConfigObj + + +def config_location(): + if "XDG_CONFIG_HOME" in os.environ: + return "%s/litecli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) + elif platform.system() == "Windows": + return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\litecli\\" + else: + return expanduser("~/.config/litecli/") + + +def load_config(usr_cfg, def_cfg=None): + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + cfg.filename = expanduser(usr_cfg) + + return cfg + + +def ensure_dir_exists(path): + parent_dir = expanduser(dirname(path)) + try: + os.makedirs(parent_dir) + except OSError as exc: + # ignore existing destination (py2 has no exist_ok arg to makedirs) + if exc.errno != errno.EEXIST: + raise + + +def write_default_config(source, destination, overwrite=False): + destination = expanduser(destination) + if not overwrite and exists(destination): + return + + ensure_dir_exists(destination) + + shutil.copyfile(source, destination) + + +def upgrade_config(config, def_config): + cfg = load_config(config, def_config) + cfg.write() + + +def get_config(liteclirc_file=None): + from litecli import __file__ as package_root + + package_root = os.path.dirname(package_root) + + liteclirc_file = liteclirc_file or "%sconfig" % config_location() + + default_config = os.path.join(package_root, "liteclirc") + write_default_config(default_config, liteclirc_file) + + return load_config(liteclirc_file, default_config) diff --git a/litecli/encodingutils.py b/litecli/encodingutils.py new file mode 100644 index 0000000..6caf14d --- /dev/null +++ b/litecli/encodingutils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from litecli.compat import PY2 + + +if PY2: + binary_type = str + string_types = basestring + text_type = unicode +else: + binary_type = bytes + string_types = str + text_type = str + + +def unicode2utf8(arg): + """Convert strings to UTF8-encoded bytes. + + Only in Python 2. In Python 3 the args are expected as unicode. + + """ + + if PY2 and isinstance(arg, text_type): + return arg.encode("utf-8") + return arg + + +def utf8tounicode(arg): + """Convert UTF8-encoded bytes to strings. + + Only in Python 2. In Python 3 the errors are returned as strings. + + """ + + if PY2 and isinstance(arg, binary_type): + return arg.decode("utf-8") + return arg diff --git a/litecli/key_bindings.py b/litecli/key_bindings.py new file mode 100644 index 0000000..44d59d2 --- /dev/null +++ b/litecli/key_bindings.py @@ -0,0 +1,84 @@ +from __future__ import unicode_literals +import logging +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.filters import completion_is_selected +from prompt_toolkit.key_binding import KeyBindings + +_logger = logging.getLogger(__name__) + + +def cli_bindings(cli): + """Custom key bindings for cli.""" + kb = KeyBindings() + + @kb.add("f3") + def _(event): + """Enable/Disable Multiline Mode.""" + _logger.debug("Detected F3 key.") + cli.multi_line = not cli.multi_line + + @kb.add("f4") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F4 key.") + if cli.key_bindings == "vi": + event.app.editing_mode = EditingMode.EMACS + cli.key_bindings = "emacs" + else: + event.app.editing_mode = EditingMode.VI + cli.key_bindings = "vi" + + @kb.add("tab") + def _(event): + """Force autocompletion at cursor.""" + _logger.debug("Detected key.") + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=True) + + @kb.add("s-tab") + def _(event): + """Force autocompletion at cursor.""" + _logger.debug("Detected key.") + b = event.app.current_buffer + if b.complete_state: + b.complete_previous() + else: + b.start_completion(select_last=True) + + @kb.add("c-space") + def _(event): + """ + Initialize autocompletion at cursor. + + If the autocompletion menu is not showing, display it with the + appropriate completions for the context. + + If the menu is showing, select the next completion. + """ + _logger.debug("Detected key.") + + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=False) + + @kb.add("enter", filter=completion_is_selected) + def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + + """ + _logger.debug("Detected enter key.") + + event.current_buffer.complete_state = None + b = event.app.current_buffer + b.complete_state = None + + return kb diff --git a/litecli/lexer.py b/litecli/lexer.py new file mode 100644 index 0000000..678eb3f --- /dev/null +++ b/litecli/lexer.py @@ -0,0 +1,9 @@ +from pygments.lexer import inherit +from pygments.lexers.sql import MySqlLexer +from pygments.token import Keyword + + +class LiteCliLexer(MySqlLexer): + """Extends SQLite lexer to add keywords.""" + + tokens = {"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit]} diff --git a/litecli/liteclirc b/litecli/liteclirc new file mode 100644 index 0000000..e3331d1 --- /dev/null +++ b/litecli/liteclirc @@ -0,0 +1,113 @@ +# vi: ft=dosini +[main] + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# Destructive warning mode will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database" +# or "shutdown". +destructive_warning = True + +# log_file location. +# In Unix/Linux: ~/.config/litecli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Log every query and its results to a file. Enable this by uncommenting the +# line below. +# audit_log = ~/.litecli-audit.log + +# Default pager. +# By default '$PAGER' environment variable is used +# pager = less -SRXF + +# Table format. Possible values: +# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl, +# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira, +# vertical, tsv, csv. +# Recommended: ascii +table_format = ascii + +# Syntax coloring style. Possible values (many support the "-dark" suffix): +# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, +# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, +# fruity. +# Screenshots at http://mycli.net/syntax +syntax_style = default + +# Keybindings: Possible values: emacs, vi. +# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL. +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +key_bindings = emacs + +# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested. +wider_completion_menu = False + +# litecli prompt +# \D - The full current date +# \d - Database name +# \m - Minutes of the current time +# \n - Newline +# \P - AM/PM +# \R - The current time, in 24-hour military time (0-23) +# \r - The current time, standard 12-hour time (1-12) +# \s - Seconds of the current time +prompt = '\d> ' +prompt_continuation = '-> ' + +# Skip intro info on startup and outro info on exit +less_chatty = False + +# Use alias from --login-path instead of host name in prompt +login_path_as_host = False + +# Cause result sets to be displayed vertically if they are too wide for the current window, +# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.) +auto_vertical_output = False + +# keyword casing preference. Possible values "lower", "upper", "auto" +keyword_casing = auto + +# disabled pager on startup +enable_pager = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" + + +# Favorite queries. +[favorite_queries] diff --git a/litecli/main.py b/litecli/main.py new file mode 100644 index 0000000..5768851 --- /dev/null +++ b/litecli/main.py @@ -0,0 +1,1008 @@ +from __future__ import unicode_literals +from __future__ import print_function + +import os +import sys +import traceback +import logging +import threading +from time import time +from datetime import datetime +from io import open +from collections import namedtuple +from sqlite3 import OperationalError + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output import preprocessors +import click +import sqlparse +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.layout.processors import ( + HighlightMatchingBracketProcessor, + ConditionalProcessor, +) +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory + +from .packages.special.main import NO_QUERY +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages import special +from .sqlcompleter import SQLCompleter +from .clitoolbar import create_toolbar_tokens_func +from .clistyle import style_factory, style_factory_output +from .sqlexecute import SQLExecute +from .clibuffer import cli_is_multiline +from .completion_refresher import CompletionRefresher +from .config import config_location, ensure_dir_exists, get_config +from .key_bindings import cli_bindings +from .encodingutils import utf8tounicode, text_type +from .lexer import LiteCliLexer +from .__init__ import __version__ +from .packages.filepaths import dir_path_exists + +import itertools + +click.disable_unicode_literals_warning = True + +# Query tuples are used for maintaining history +Query = namedtuple("Query", ["query", "successful", "mutating"]) + +PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +class LiteCli(object): + + default_prompt = "\\d> " + max_len_prompt = 45 + + def __init__( + self, + sqlexecute=None, + prompt=None, + logfile=None, + auto_vertical_output=False, + warn=None, + liteclirc=None, + ): + self.sqlexecute = sqlexecute + self.logfile = logfile + + # Load config. + c = self.config = get_config(liteclirc) + + self.multi_line = c["main"].as_bool("multi_line") + self.key_bindings = c["main"]["key_bindings"] + special.set_favorite_queries(self.config) + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + self.formatter.litecli = self + self.syntax_style = c["main"]["syntax_style"] + self.less_chatty = c["main"].as_bool("less_chatty") + self.cli_style = c["colors"] + self.output_style = style_factory_output(self.syntax_style, self.cli_style) + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") + self.destructive_warning = c_dest_warning if warn is None else warn + self.login_path_as_host = c["main"].as_bool("login_path_as_host") + + # read from cli argument or user config file + self.auto_vertical_output = auto_vertical_output or c["main"].as_bool( + "auto_vertical_output" + ) + + # audit log + if self.logfile is None and "audit_log" in c["main"]: + try: + self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a") + except (IOError, OSError): + self.echo( + "Error: Unable to open the audit log file. Your queries will not be logged.", + err=True, + fg="red", + ) + self.logfile = False + + self.completion_refresher = CompletionRefresher() + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"] + self.prompt_format = ( + prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt + ) + self.prompt_continuation_format = c["main"]["prompt_continuation"] + keyword_casing = c["main"].get("keyword_casing", "auto") + + self.query_history = [] + + # Initialize completer. + self.completer = SQLCompleter( + supported_formats=self.formatter.supported_formats, + keyword_casing=keyword_casing, + ) + self._completer_lock = threading.Lock() + + # Register custom special commands. + self.register_special_commands() + + self.prompt_app = None + + def register_special_commands(self): + special.register_special_command( + self.change_db, + ".open", + ".open", + "Change to a new database.", + aliases=("use", "\\u"), + ) + special.register_special_command( + self.refresh_completions, + "rehash", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + aliases=("\\#",), + ) + special.register_special_command( + self.change_table_format, + ".mode", + "\\T", + "Change the table format used to output results.", + aliases=("tableformat", "\\T"), + case_sensitive=True, + ) + special.register_special_command( + self.execute_from_file, + "source", + "\\. filename", + "Execute commands from file.", + aliases=("\\.",), + ) + special.register_special_command( + self.change_prompt_format, + "prompt", + "\\R", + "Change prompt format.", + aliases=("\\R",), + case_sensitive=True, + ) + + def change_table_format(self, arg, **_): + try: + self.formatter.format_name = arg + yield (None, None, None, "Changed table format to {}".format(arg)) + except ValueError: + msg = "Table format {} not recognized. Allowed formats:".format(arg) + for table_type in self.formatter.supported_formats: + msg += "\n\t{}".format(table_type) + yield (None, None, None, msg) + + def change_db(self, arg, **_): + if arg is None: + self.sqlexecute.connect() + else: + self.sqlexecute.connect(database=arg) + + self.refresh_completions() + yield ( + None, + None, + None, + 'You are now connected to database "%s"' % (self.sqlexecute.dbname), + ) + + def execute_from_file(self, arg, **_): + if not arg: + message = "Missing required argument, filename." + return [(None, None, None, message)] + try: + with open(os.path.expanduser(arg), encoding="utf-8") as f: + query = f.read() + except IOError as e: + return [(None, None, None, str(e))] + + if self.destructive_warning and confirm_destructive_query(query) is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] + + return self.sqlexecute.run(query) + + def change_prompt_format(self, arg, **_): + """ + Change the prompt format. + """ + if not arg: + message = "Missing required argument, format." + return [(None, None, None, message)] + + self.prompt_format = self.get_prompt(arg) + return [(None, None, None, "Changed prompt format to %s" % arg)] + + def initialize_logging(self): + + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" + ensure_dir_exists(log_file) + + log_level = self.config["main"]["log_level"] + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + } + + # Disable logging if value is NONE by switching to a no-op handler + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + log_level = "CRITICAL" + elif dir_path_exists(log_file): + handler = logging.FileHandler(log_file) + else: + self.echo( + 'Error: Unable to open the log file "{}".'.format(log_file), + err=True, + fg="red", + ) + return + + formatter = logging.Formatter( + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger("litecli") + root_logger.addHandler(handler) + root_logger.setLevel(level_map[log_level.upper()]) + + logging.captureWarnings(True) + + root_logger.debug("Initializing litecli logging.") + root_logger.debug("Log file %r.", log_file) + + def read_my_cnf_files(self, keys): + """ + Reads a list of config files and merges them. The last one will win. + :param files: list of files to read + :param keys: list of keys to retrieve + :returns: tuple, with None for missing keys. + """ + cnf = self.config + + sections = ["main"] + + def get(key): + result = None + for sect in cnf: + if sect in sections and key in cnf[sect]: + result = cnf[sect][key] + return result + + return {x: get(x) for x in keys} + + def connect(self, database=""): + + cnf = {"database": None} + + cnf = self.read_my_cnf_files(cnf.keys()) + + # Fall back to config values only if user did not specify a value. + + database = database or cnf["database"] + + # Connect to the database. + + def _connect(): + self.sqlexecute = SQLExecute(database) + + try: + _connect() + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + exit(1) + + def handle_editor_command(self, text): + """Editor command is any query that is prefixed or suffixed by a '\e'. + The reason for a while loop is because a user might edit a query + multiple times. For eg: + + "select * from \e" to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + + """ + + while special.editor_command(text): + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + continue + return text + + def run_cli(self): + iterations = 0 + sqlexecute = self.sqlexecute + logger = self.logger + self.configure_pager() + self.refresh_completions() + + history_file = config_location() + "history" + if dir_path_exists(history_file): + history = FileHistory(history_file) + else: + history = None + self.echo( + 'Error: Unable to open the history file "{}". ' + "Your query history will not be saved.".format(history_file), + err=True, + fg="red", + ) + + key_bindings = cli_bindings(self) + + if not self.less_chatty: + print("Version:", __version__) + print("Mail: https://groups.google.com/forum/#!forum/litecli-users") + print("GitHub: https://github.com/dbcli/litecli") + # print("Home: https://litecli.com") + + def get_message(): + prompt = self.get_prompt(self.prompt_format) + if ( + self.prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + return [("class:prompt", prompt)] + + def get_continuation(width, line_number, is_soft_wrap): + continuation = " " * (width - 1) + " " + return [("class:continuation", continuation)] + + def show_suggestion_tip(): + return iterations < 2 + + def one_iteration(text=None): + if text is None: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return + + if not text.strip(): + return + + if self.destructive_warning: + destroy = confirm_destructive_query(text) + if destroy is None: + pass # Query was not destructive. Nothing to do here. + elif destroy is True: + self.echo("Your call!") + else: + self.echo("Wise choice!") + return + + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = False + + try: + logger.debug("sql: %r", text) + + special.write_tee(self.get_prompt(self.prompt_format) + text) + if self.logfile: + self.logfile.write("\n# %s\n" % datetime.now()) + self.logfile.write(text) + self.logfile.write("\n") + + successful = False + start = time() + res = sqlexecute.run(text) + self.formatter.query = text + successful = True + result_count = 0 + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if is_select(status) and cur and cur.rowcount > threshold: + self.echo( + "The result set has more than {} rows.".format(threshold), + fg="red", + ) + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") + break + + if self.auto_vertical_output: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output( + title, cur, headers, special.is_expanded_output(), max_width + ) + + t = time() - start + try: + if result_count > 0: + self.echo("") + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + self.echo("Time: %0.03fs" % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or is_mutating(status) + special.unset_once_if_written() + except EOFError as e: + raise e + except KeyboardInterrupt: + # get last connection id + connection_id_to_kill = sqlexecute.connection_id + logger.debug("connection id to kill: %r", connection_id_to_kill) + # Restart connection to the database + sqlexecute.connect() + try: + for title, cur, headers, status in sqlexecute.run( + "kill %s" % connection_id_to_kill + ): + status_str = str(status).lower() + if status_str.find("ok") > -1: + logger.debug( + "cancelled query, connection id: %r, sql: %r", + connection_id_to_kill, + text, + ) + self.echo("cancelled query", err=True, fg="red") + except Exception as e: + self.echo( + "Encountered error while cancelling query: {}".format(e), + err=True, + fg="red", + ) + except NotImplementedError: + self.echo("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.debug("Exception: %r", e) + if e.args[0] in (2003, 2006, 2013): + logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + try: + sqlexecute.connect() + logger.debug("Reconnected successfully.") + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. + except OperationalError as e: + logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + # If reconnection failed, don't proceed further. + return + else: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + else: + # Refresh the table names and column names if necessary. + if need_completion_refresh(text): + self.refresh_completions(reset=need_completion_reset(text)) + finally: + if self.logfile is False: + self.echo("Warning: This query was not logged.", err=True, fg="red") + query = Query(text, successful, mutating) + self.query_history.append(query) + + get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + + if self.key_bindings == "vi": + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + self.prompt_app = PromptSession( + lexer=PygmentsLexer(LiteCliLexer), + reserve_space_for_menu=self.get_reserved_space(), + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ) + ], + tempfile_suffix=".sql", + completer=DynamicCompleter(lambda: self.completer), + history=history, + auto_suggest=AutoSuggestFromHistory(), + complete_while_typing=True, + multiline=cli_is_multiline(self), + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True, + ) + + try: + while True: + one_iteration() + iterations += 1 + except EOFError: + special.close_tee() + if not self.less_chatty: + self.echo("Goodbye!") + + def log_output(self, output): + """Log the output in the audit log, if it's enabled.""" + if self.logfile: + click.echo(utf8tounicode(output), file=self.logfile) + + def echo(self, s, **kwargs): + """Print a message to stdout. + + The message will be logged in the audit log, if enabled. + + All keyword arguments are passed to click.echo(). + + """ + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status=None): + """Get the output margin (number of rows for the prompt, footer and + timing message.""" + margin = ( + self.get_reserved_space() + + self.get_prompt(self.prompt_format).count("\n") + + 2 + ) + if status: + margin += 1 + status.count("\n") + + return margin + + def output(self, output, status=None): + """Output text to stdout or a pager command. + + The status text is not outputted to pager or files. + + The message will be logged in the audit log, if enabled. The + message will be written to the tee file, if enabled. The + message will be written to the output file, if enabled. + + """ + if output: + size = self.prompt_app.output.get_size() + + margin = self.get_output_margin(status) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + + if fits or output_via_pager: + # buffering + buf.append(line) + if len(line) > size.columns or i > (size.rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + # doesn't fit, use pager + output_via_pager = True + + if not output_via_pager: + # doesn't fit, flush buffer + for line in buf: + click.secho(line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + # sadly click.echo_via_pager doesn't accept generators + click.echo_via_pager("\n".join(buf)) + else: + for line in buf: + click.secho(line) + + if status: + self.log_output(status) + click.secho(status) + + def configure_pager(self): + # Provide sane defaults for less if they are empty. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-RXF" + + cnf = self.read_my_cnf_files(["pager", "skip-pager"]) + if cnf["pager"]: + special.set_pager(cnf["pager"]) + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): + special.disable_pager() + + def refresh_completions(self, reset=False): + if reset: + with self._completer_lock: + self.completer.reset_completions() + self.completion_refresher.refresh( + self.sqlexecute, + self._on_completions_refreshed, + { + "supported_formats": self.formatter.supported_formats, + "keyword_casing": self.completer.keyword_casing, + }, + ) + + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def _on_completions_refreshed(self, new_completer): + """Swap the completer object in cli with the newly created completer. + """ + with self._completer_lock: + self.completer = new_completer + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + self.logger.debug("Getting prompt") + sqlexecute = self.sqlexecute + now = datetime.now() + string = string.replace("\\d", sqlexecute.dbname or "(none)") + string = string.replace("\\n", "\n") + string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) + string = string.replace("\\m", now.strftime("%M")) + string = string.replace("\\P", now.strftime("%p")) + string = string.replace("\\R", now.strftime("%H")) + string = string.replace("\\r", now.strftime("%I")) + string = string.replace("\\s", now.strftime("%S")) + string = string.replace("\\_", " ") + return string + + def run_query(self, query, new_line=True): + """Runs *query*.""" + results = self.sqlexecute.run(query) + for result in results: + title, cur, headers, status = result + self.formatter.query = query + output = self.format_output(title, cur, headers) + for line in output: + click.echo(line, nl=new_line) + + def format_output(self, title, cur, headers, expanded=False, max_width=None): + expanded = expanded or self.formatter.format_name == "vertical" + output = [] + + output_kwargs = { + "dialect": "unix", + "disable_numparse": True, + "preserve_whitespace": True, + "preprocessors": (preprocessors.align_decimals,), + "style": self.output_style, + } + + if title: # Only print the title if it's not None. + output = itertools.chain(output, [title]) + + if cur: + column_types = None + if hasattr(cur, "description"): + + def get_col_type(col): + # col_type = FIELD_TYPES.get(col[1], text_type) + # return col_type if type(col_type) is type else text_type + return text_type + + column_types = [get_col_type(col) for col in cur.description] + + if max_width is not None: + cur = list(cur) + + formatted = self.formatter.format_output( + cur, + headers, + format_name="vertical" if expanded else None, + column_types=column_types, + **output_kwargs + ) + + if isinstance(formatted, (text_type)): + formatted = formatted.splitlines() + formatted = iter(formatted) + + first_line = next(formatted) + formatted = itertools.chain([first_line], formatted) + + if ( + not expanded + and max_width + and headers + and cur + and len(first_line) > max_width + ): + formatted = self.formatter.format_output( + cur, + headers, + format_name="vertical", + column_types=column_types, + **output_kwargs + ) + if isinstance(formatted, (text_type)): + formatted = iter(formatted.splitlines()) + + output = itertools.chain(output, formatted) + + return output + + def get_reserved_space(self): + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = 0.45 + max_reserved_space = 8 + _, height = click.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + +@click.command() +@click.option("-V", "--version", is_flag=True, help="Output litecli's version.") +@click.option("-D", "--database", "dbname", help="Database to use.") +@click.option( + "-R", + "--prompt", + "prompt", + help='Prompt format (Default: "{0}").'.format(LiteCli.default_prompt), +) +@click.option( + "-l", + "--logfile", + type=click.File(mode="a", encoding="utf-8"), + help="Log every query and its results to a file.", +) +@click.option( + "--liteclirc", + default=config_location() + "config", + help="Location of liteclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "-t", "--table", is_flag=True, help="Display batch output in table format." +) +@click.option("--csv", is_flag=True, help="Display batch output in CSV format.") +@click.option( + "--warn/--no-warn", default=None, help="Warn before running a destructive query." +) +@click.option("-e", "--execute", type=str, help="Execute command and quit.") +@click.argument("database", default="", nargs=1) +def cli( + database, + dbname, + version, + prompt, + logfile, + auto_vertical_output, + table, + csv, + warn, + execute, + liteclirc, +): + """A SQLite terminal client with auto-completion and syntax highlighting. + + \b + Examples: + - litecli lite_database + + """ + + if version: + print("Version:", __version__) + sys.exit(0) + + litecli = LiteCli( + prompt=prompt, + logfile=logfile, + auto_vertical_output=auto_vertical_output, + warn=warn, + liteclirc=liteclirc, + ) + + # Choose which ever one has a valid value. + database = database or dbname + + litecli.connect(database) + + litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database) + + # --execute argument + if execute: + try: + if csv: + litecli.formatter.format_name = "csv" + elif not table: + litecli.formatter.format_name = "tsv" + + litecli.run_query(execute) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg="red") + exit(1) + + if sys.stdin.isatty(): + litecli.run_cli() + else: + stdin = click.get_text_stream("stdin") + stdin_text = stdin.read() + + try: + sys.stdin = open("/dev/tty") + except (FileNotFoundError, OSError): + litecli.logger.warning("Unable to open TTY as stdin.") + + if ( + litecli.destructive_warning + and confirm_destructive_query(stdin_text) is False + ): + exit(0) + try: + new_line = True + + if csv: + litecli.formatter.format_name = "csv" + elif not table: + litecli.formatter.format_name = "tsv" + + litecli.run_query(stdin_text, new_line=new_line) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg="red") + exit(1) + + +def need_completion_refresh(queries): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop or change db.""" + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ( + "alter", + "create", + "use", + "\\r", + "\\u", + "connect", + "drop", + ): + return True + except Exception: + return False + + +def need_completion_reset(queries): + """Determines if the statement is a database switch such as 'use' or '\\u'. + When a database is changed the existing completions must be reset before we + start the completion refresh for the new database. + """ + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\u"): + return True + except Exception: + return False + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = set( + [ + "insert", + "update", + "delete", + "alter", + "create", + "drop", + "replace", + "truncate", + "load", + ] + ) + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +if __name__ == "__main__": + cli() diff --git a/litecli/packages/__init__.py b/litecli/packages/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py new file mode 100644 index 0000000..0397857 --- /dev/null +++ b/litecli/packages/completion_engine.py @@ -0,0 +1,331 @@ +from __future__ import print_function +import sys +import sqlparse +from sqlparse.sql import Comparison, Identifier, Where +from litecli.encodingutils import string_types, text_type +from .parseutils import last_word, extract_tables, find_prev_keyword +from .special import parse_special_command + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + word_before_cursor = last_word(text_before_cursor, include="many_punctuations") + + identifier = None + + # here should be removed once sqlparse has been fixed + try: + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. + if word_before_cursor: + if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"): + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)]) + + # word_before_cursor may include a schema qualification, like + # "schema_name.partial_name" or "schema_name.", so parse it + # separately + p = sqlparse.parse(word_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[0], Identifier): + identifier = p.tokens[0] + else: + parsed = sqlparse.parse(text_before_cursor) + except (TypeError, AttributeError): + return [{"type": "keyword"}] + + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds the + # current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(text_type(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + statement = None + + # Check for special commands and handle those separately + if statement: + # Be careful here because trivial whitespace is parsed as a statement, + # but the statement won't have a first token + tok1 = statement.token_first() + if tok1 and tok1.value.startswith("."): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("\\"): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("source"): + return suggest_special(text_before_cursor) + elif text_before_cursor and text_before_cursor.startswith(".open "): + return suggest_special(text_before_cursor) + + last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" + + return suggest_based_on_last_token( + last_token, text_before_cursor, full_text, identifier + ) + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return [{"type": "special"}] + + if cmd in ("\\u", "\\r"): + return [{"type": "database"}] + + if cmd in ("\\T"): + return [{"type": "table_format"}] + + if cmd in ["\\f", "\\fs", "\\fd"]: + return [{"type": "favoritequery"}] + + if cmd in ["\\d", "\\dt", "\\dt+", ".schema"]: + return [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + + if cmd in ["\\.", "source", ".open"]: + return [{"type": "file_name"}] + + if cmd in [".import"]: + # Usage: .import filename table + if _expecting_arg_idx(arg, text) == 1: + return [{"type": "file_name"}] + else: + return [{"type": "table", "schema": []}] + + return [{"type": "keyword"}, {"type": "special"}] + + +def _expecting_arg_idx(arg, text): + """Return the index of expecting argument. + + >>> _expecting_arg_idx("./da", ".import ./da") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ") + 2 + >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t") + 2 + """ + args = arg.split() + return len(args) + int(text[-1].isspace()) + + +def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): + if isinstance(token, string_types): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier + ) + else: + token_v = token.value.lower() + + is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) + + if not token: + return [{"type": "keyword"}, {"type": "special"}] + elif token_v.endswith("("): + p = sqlparse.parse(text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token( + "where", text_before_cursor, full_text, identifier + ) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + idx, prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return [{"type": "keyword"}] + else: + return column_suggestions + + # Get the token before the parens + idx, prev_tok = p.token_prev(len(p.tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == "using": + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = extract_tables(full_text) + + # suggest columns that are present in more than one table + return [{"type": "column", "tables": tables, "drop_unique": True}] + elif p.token_first().value.lower() == "select": + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(text_before_cursor, "all_punctuations").startswith("("): + return [{"type": "keyword"}] + elif p.token_first().value.lower() == "show": + return [{"type": "show"}] + + # We're probably in a function argument list + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v in ("set", "order by", "distinct"): + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v == "as": + # Don't suggest anything for an alias + return [] + elif token_v in ("show"): + return [{"type": "show"}] + elif token_v in ("to",): + p = sqlparse.parse(text_before_cursor)[0] + if p.token_first().value.lower() == "change": + return [{"type": "change"}] + else: + return [{"type": "user"}] + elif token_v in ("user", "for"): + return [{"type": "user"}] + elif token_v in ("select", "where", "having"): + # Check for a table alias or schema qualification + parent = (identifier and identifier.get_parent_name()) or [] + + tables = extract_tables(full_text) + if parent: + tables = [t for t in tables if identifies(parent, *t)] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] + else: + aliases = [alias or table for (schema, table, alias) in tables] + return [ + {"type": "column", "tables": tables}, + {"type": "function", "schema": []}, + {"type": "alias", "aliases": aliases}, + {"type": "keyword"}, + ] + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v + in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") + ): + schema = (identifier and identifier.get_parent_name()) or [] + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [{"type": "table", "schema": schema}] + + if not schema: + # Suggest schemas + suggest.insert(0, {"type": "schema"}) + + # Only tables can be TRUNCATED, otherwise suggest views + if token_v != "truncate": + suggest.append({"type": "view", "schema": schema}) + + return suggest + + elif token_v in ("table", "view", "function"): + # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' + rel_type = token_v + schema = (identifier and identifier.get_parent_name()) or [] + if schema: + return [{"type": rel_type, "schema": schema}] + else: + return [{"type": "schema"}, {"type": rel_type, "schema": []}] + elif token_v == "on": + tables = extract_tables(full_text) # [(schema, table, alias), ...] + parent = (identifier and identifier.get_parent_name()) or [] + if parent: + # "ON parent." + # parent can be either a schema name or table alias + tables = [t for t in tables if identifies(parent, *t)] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = [alias or table for (schema, table, alias) in tables] + suggest = [{"type": "alias", "aliases": aliases}] + + # The lists of 'aliases' could be empty if we're trying to complete + # a GRANT query. eg: GRANT SELECT, INSERT ON + # In that case we just suggest all tables. + if not aliases: + suggest.append({"type": "table", "schema": parent}) + return suggest + + elif token_v in ("use", "database", "template", "connect"): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return [{"type": "database"}] + elif token_v == "tableformat": + return [{"type": "table_format"}] + elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]: + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + if prev_keyword: + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier + ) + else: + return [] + else: + return [{"type": "keyword"}] + + +def identifies(id, schema, table, alias): + return id == alias or id == table or (schema and (id == schema + "." + table)) diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py new file mode 100644 index 0000000..2f01046 --- /dev/null +++ b/litecli/packages/filepaths.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 + +from __future__ import unicode_literals + +from litecli.encodingutils import text_type +import os + + +def list_path(root_dir): + """List directory if exists. + + :param dir: str + :return: list + + """ + res = [] + if os.path.isdir(root_dir): + for name in os.listdir(root_dir): + res.append(name) + return res + + +def complete_path(curr_dir, last_dir): + """Return the path to complete that matches the last entered component. + + If the last entered component is ~, expanded path would not + match, so return all of the available paths. + + :param curr_dir: str + :param last_dir: str + :return: str + + """ + if not last_dir or curr_dir.startswith(last_dir): + return curr_dir + elif last_dir == "~": + return os.path.join(last_dir, curr_dir) + + +def parse_path(root_dir): + """Split path into head and last component for the completer. + + Also return position where last component starts. + + :param root_dir: str path + :return: tuple of (string, string, int) + + """ + base_dir, last_dir, position = "", "", 0 + if root_dir: + base_dir, last_dir = os.path.split(root_dir) + position = -len(last_dir) if last_dir else 0 + return base_dir, last_dir, position + + +def suggest_path(root_dir): + """List all files and subdirectories in a directory. + + If the directory is not specified, suggest root directory, + user directory, current and parent directory. + + :param root_dir: string: directory to list + :return: list + + """ + if not root_dir: + return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir]) + + if "~" in root_dir: + root_dir = text_type(os.path.expanduser(root_dir)) + + if not os.path.exists(root_dir): + root_dir, _ = os.path.split(root_dir) + + return list_path(root_dir) + + +def dir_path_exists(path): + """Check if the directory path exists for a given file. + + For example, for a file /home/user/.cache/litecli/log, check if + /home/user/.cache/litecli exists. + + :param str path: The file path. + :return: Whether or not the directory path exists. + + """ + return os.path.exists(os.path.dirname(path)) diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py new file mode 100644 index 0000000..92fe365 --- /dev/null +++ b/litecli/packages/parseutils.py @@ -0,0 +1,227 @@ +from __future__ import print_function +import re +import sqlparse +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +cleanup_regex = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile("([^\s]+)$"), +} + + +def last_word(text, include="alphanum_underscore"): + """ + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +# This code is borrowed from sqlparse example script. +# +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): + return True + return False + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # `return`. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): + return + else: + yield item + elif ( + item.ttype is Keyword or item.ttype is Keyword.DML + ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream): + """yields tuples of (schema_name, table_name, table_alias)""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + except AttributeError: + continue + if real_name: + yield (schema_name, real_name, identifier.get_alias()) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + + if real_name: + yield (schema_name, real_name, item.get_alias()) + else: + name = item.get_name() + yield (None, name, item.get_alias() or name) + elif isinstance(item, Function): + yield (None, item.get_name(), item.get_name()) + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statment. + + Returns a list of (schema, table, alias) tuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + return list(extract_table_identifiers(stream)) + + +def find_prev_keyword(sql): + """ Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ("drop", "shutdown", "delete", "truncate", "alter") + return queries_start_with(queries, keywords) + + +if __name__ == "__main__": + sql = "select * from (select t. from tabl t" + print(extract_tables(sql)) diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py new file mode 100644 index 0000000..d9ad2b6 --- /dev/null +++ b/litecli/packages/prompt_utils.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + + +import sys +import click +from .parseutils import is_destructive + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ( + "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" + ) + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=bool) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py new file mode 100644 index 0000000..fd2b18c --- /dev/null +++ b/litecli/packages/special/__init__.py @@ -0,0 +1,12 @@ +__all__ = [] + + +def export(defn): + """Decorator to explicitly mark functions that are exposed in a lib.""" + globals()[defn.__name__] = defn + __all__.append(defn.__name__) + return defn + + +from . import dbcommands +from . import iocommands diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py new file mode 100644 index 0000000..a7eaa0c --- /dev/null +++ b/litecli/packages/special/dbcommands.py @@ -0,0 +1,273 @@ +from __future__ import unicode_literals, print_function +import csv +import logging +import os +import sys +import platform +import shlex +from sqlite3 import ProgrammingError + +from litecli import __version__ +from litecli.packages.special import iocommands +from litecli.packages.special.utils import format_uptime +from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing + +log = logging.getLogger(__name__) + + +@special_command( + ".tables", + "\\dt", + "List tables.", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\dt",), +) +def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): + if arg: + args = ("{0}%".format(arg),) + query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + else: + args = tuple() + query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + log.debug(query) + cur.execute(query, args) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + # if verbose and arg: + # query = "SELECT sql FROM sqlite_master WHERE name LIKE ?" + # log.debug(query) + # cur.execute(query) + # status = cur.fetchone()[1] + + return [(None, tables, headers, status)] + + +@special_command( + ".schema", + ".schema[+] [table]", + "The complete schema for the database or a single table", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def show_schema(cur, arg=None, **_): + if arg: + args = (arg,) + query = """ + SELECT sql FROM sqlite_master + WHERE name==? + ORDER BY tbl_name, type DESC, name + """ + else: + args = tuple() + query = """ + SELECT sql FROM sqlite_master + ORDER BY tbl_name, type DESC, name + """ + + log.debug(query) + cur.execute(query, args) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + return [(None, tables, headers, status)] + + +@special_command( + ".databases", + ".databases", + "List databases.", + arg_type=RAW_QUERY, + case_sensitive=True, + aliases=("\\l",), +) +def list_databases(cur, **_): + query = "PRAGMA database_list" + log.debug(query) + cur.execute(query) + if cur.description: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, "")] + else: + return [(None, None, None, "")] + + +@special_command( + ".status", + "\\s", + "Show current settings.", + arg_type=RAW_QUERY, + aliases=("\\s",), + case_sensitive=True, +) +def status(cur, **_): + # Create output buffers. + footer = [] + footer.append("--------------") + + # Output the litecli client information. + implementation = platform.python_implementation() + version = platform.python_version() + client_info = [] + client_info.append("litecli {0},".format(__version__)) + client_info.append("running on {0} {1}".format(implementation, version)) + footer.append(" ".join(client_info)) + + # Build the output that will be displayed as a table. + query = "SELECT file from pragma_database_list() where name = 'main';" + log.debug(query) + cur.execute(query) + db = cur.fetchone()[0] + if db is None: + db = "" + + footer.append("Current database: " + db) + if iocommands.is_pager_enabled(): + if "PAGER" in os.environ: + pager = os.environ["PAGER"] + else: + pager = "System default" + else: + pager = "stdout" + footer.append("Current pager:" + pager) + + footer.append("--------------") + return [(None, None, "", "\n".join(footer))] + + +@special_command( + ".load", + ".load path", + "Load an extension library.", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def load_extension(cur, arg, **_): + args = shlex.split(arg) + if len(args) != 1: + raise TypeError(".load accepts exactly one path") + path = args[0] + conn = cur.connection + conn.enable_load_extension(True) + conn.load_extension(path) + return [(None, None, None, "")] + + +@special_command( + "describe", + "\\d [table]", + "Description of a table", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\d", "describe", "desc"), +) +def describe(cur, arg, **_): + if arg: + args = (arg,) + query = """ + PRAGMA table_info({}) + """.format( + arg + ) + else: + raise ArgumentMissing("Table name required.") + + log.debug(query) + cur.execute(query) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + return [(None, tables, headers, status)] + + +@special_command( + ".read", + ".read path", + "Read input from path", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def read_script(cur, arg, **_): + args = shlex.split(arg) + if len(args) != 1: + raise TypeError(".read accepts exactly one path") + path = args[0] + with open(path, "r") as f: + script = f.read() + cur.executescript(script) + return [(None, None, None, "")] + + +@special_command( + ".import", + ".import filename table", + "Import data from filename into an existing table", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def import_file(cur, arg=None, **_): + def split(s): + # this is a modification of shlex.split function, just to make it support '`', + # because table name might contain '`' character. + lex = shlex.shlex(s, posix=True) + lex.whitespace_split = True + lex.commenters = "" + lex.quotes += "`" + return list(lex) + + args = split(arg) + log.debug("[arg = %r], [args = %r]", arg, args) + if len(args) != 2: + raise TypeError("Usage: .import filename table") + + filename, table = args + cur.execute('PRAGMA table_info("%s")' % table) + ncols = len(cur.fetchall()) + insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1)) + + with open(filename, "r") as csvfile: + dialect = csv.Sniffer().sniff(csvfile.read(1024)) + csvfile.seek(0) + reader = csv.reader(csvfile, dialect) + + cur.execute("BEGIN") + ninserted, nignored = 0, 0 + for i, row in enumerate(reader): + if len(row) != ncols: + print( + "%s:%d expected %d columns but found %d - ignored" + % (filename, i, ncols, len(row)), + file=sys.stderr, + ) + nignored += 1 + continue + cur.execute(insert_tmpl, row) + ninserted += 1 + cur.execute("COMMIT") + + status = "Inserted %d rows into %s" % (ninserted, table) + if nignored > 0: + status += " (%d rows are ignored)" % nignored + return [(None, None, None, status)] diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py new file mode 100644 index 0000000..7da6fbf --- /dev/null +++ b/litecli/packages/special/favoritequeries.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + + +class FavoriteQueries(object): + + section_name = "favorite_queries" + + usage = """ +Favorite Queries are a way to save frequently used queries +with a short name. +Examples: + + # Save a new favorite query. + > \\fs simple select * from abc where a is not Null; + + # List all favorite queries. + > \\f + ╒════════╤═══════════════════════════════════════╕ + │ Name │ Query │ + ╞════════╪═══════════════════════════════════════╡ + │ simple │ SELECT * FROM abc where a is not NULL │ + ╘════════╧═══════════════════════════════════════╛ + + # Run a favorite query. + > \\f simple + ╒════════╤════════╕ + │ a │ b │ + ╞════════╪════════╡ + │ 日本語 │ 日本語 │ + ╘════════╧════════╛ + + # Delete a favorite query. + > \\fd simple + simple: Deleted +""" + + def __init__(self, config): + self.config = config + + def list(self): + return self.config.get(self.section_name, []) + + def get(self, name): + return self.config.get(self.section_name, {}).get(name, None) + + def save(self, name, query): + if self.section_name not in self.config: + self.config[self.section_name] = {} + self.config[self.section_name][name] = query + self.config.write() + + def delete(self, name): + try: + del self.config[self.section_name][name] + except KeyError: + return "%s: Not Found." % name + self.config.write() + return "%s: Deleted" % name diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py new file mode 100644 index 0000000..8940057 --- /dev/null +++ b/litecli/packages/special/iocommands.py @@ -0,0 +1,479 @@ +from __future__ import unicode_literals +import os +import re +import locale +import logging +import subprocess +import shlex +from io import open +from time import sleep + +import click +import sqlparse +from configobj import ConfigObj + +from . import export +from .main import special_command, NO_QUERY, PARSED_QUERY +from .favoritequeries import FavoriteQueries +from .utils import handle_cd_command +from litecli.packages.prompt_utils import confirm_destructive_query + +use_expanded_output = False +PAGER_ENABLED = True +tee_file = None +once_file = written_to_once_file = None +favoritequeries = FavoriteQueries(ConfigObj()) + + +@export +def set_favorite_queries(config): + global favoritequeries + favoritequeries = FavoriteQueries(config) + + +@export +def set_pager_enabled(val): + global PAGER_ENABLED + PAGER_ENABLED = val + + +@export +def is_pager_enabled(): + return PAGER_ENABLED + + +@export +@special_command( + "pager", + "\\P [command]", + "Set PAGER. Print the query results via PAGER.", + arg_type=PARSED_QUERY, + aliases=("\\P",), + case_sensitive=True, +) +def set_pager(arg, **_): + if arg: + os.environ["PAGER"] = arg + msg = "PAGER set to %s." % arg + set_pager_enabled(True) + else: + if "PAGER" in os.environ: + msg = "PAGER set to %s." % os.environ["PAGER"] + else: + # This uses click's default per echo_via_pager. + msg = "Pager enabled." + set_pager_enabled(True) + + return [(None, None, None, msg)] + + +@export +@special_command( + "nopager", + "\\n", + "Disable pager, print to stdout.", + arg_type=NO_QUERY, + aliases=("\\n",), + case_sensitive=True, +) +def disable_pager(): + set_pager_enabled(False) + return [(None, None, None, "Pager disabled.")] + + +@export +def set_expanded_output(val): + global use_expanded_output + use_expanded_output = val + + +@export +def is_expanded_output(): + return use_expanded_output + + +_logger = logging.getLogger(__name__) + + +@export +def editor_command(command): + """ + Is this an external editor command? + :param command: string + """ + # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check + # for both conditions. + return command.strip().endswith("\\e") or command.strip().startswith("\\e") + + +@export +def get_filename(sql): + if sql.strip().startswith("\\e"): + command, _, filename = sql.partition(" ") + return filename.strip() or None + + +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\e') is that it strips characters, + # not a substring. So it'll strip "e" in the end of the sql also! + # Ex: "select * from style\e" -> "select * from styl". + pattern = re.compile("(^\\\e|\\\e$)") + while pattern.search(sql): + sql = pattern.sub("", sql) + + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + + message = None + filename = filename.strip().split(" ", 1)[0] if filename else None + + sql = sql or "" + MARKER = "# Type your query above this line.\n" + + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit( + "{sql}\n\n{marker}".format(sql=sql, marker=MARKER), + filename=filename, + extension=".sql", + ) + + if filename: + try: + with open(filename, encoding="utf-8") as f: + query = f.read() + except IOError: + message = "Error reading file: %s." % filename + + if query is not None: + query = query.split(MARKER, 1)[0].rstrip("\n") + else: + # Don't return None for the caller to deal with. + # Empty string is ok. + query = sql + + return (query, message) + + +@special_command( + "\\f", + "\\f [name [args..]]", + "List or execute favorite queries.", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def execute_favorite_query(cur, arg, verbose=False, **_): + """Returns (title, rows, headers, status)""" + if arg == "": + for result in list_favorite_queries(): + yield result + + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(" ") + args = shlex.split(arg_str) + + query = favoritequeries.get(name) + if query is None: + message = "No favorite query: %s" % (name) + yield (None, None, None, message) + elif "?" in query: + for sql in sqlparse.split(query): + sql = sql.rstrip(";") + title = "> %s" % (sql) if verbose else None + cur.execute(sql, args) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + else: + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(";") + title = "> %s" % (sql) if verbose else None + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + + +def list_favorite_queries(): + """List of all favorite queries. + Returns (title, rows, headers, status)""" + + headers = ["Name", "Query"] + rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()] + + if not rows: + status = "\nNo favorite queries found." + favoritequeries.usage + else: + status = "" + return [("", rows, headers, status)] + + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + shell_subst_var = "$" + str(idx + 1) + question_subst_var = "?" + if shell_subst_var in query: + query = query.replace(shell_subst_var, val) + elif question_subst_var in query: + query = query.replace(question_subst_var, val, 1) + else: + return [ + None, + "Too many arguments.\nQuery does not have enough place holders to substitute.\n" + + query, + ] + + match = re.search("\\?|\\$\d+", query) + if match: + return [ + None, + "missing substitution for " + match.group(0) + " in query:\n " + query, + ] + + return [query, None] + + +@special_command("\\fs", "\\fs name query", "Save a favorite query.") +def save_favorite_query(arg, **_): + """Save a new favorite query. + Returns (title, rows, headers, status)""" + + usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage + if not arg: + return [(None, None, None, usage)] + + name, _, query = arg.partition(" ") + + # If either name or query is missing then print the usage and complain. + if (not name) or (not query): + return [(None, None, None, usage + "Err: Both name and query are required.")] + + favoritequeries.save(name, query) + return [(None, None, None, "Saved.")] + + +@special_command("\\fd", "\\fd [name]", "Delete a favorite query.") +def delete_favorite_query(arg, **_): + """Delete an existing favorite query. + """ + usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage + if not arg: + return [(None, None, None, usage)] + + status = favoritequeries.delete(arg) + + return [(None, None, None, status)] + + +@special_command("system", "system [command]", "Execute a system shell commmand.") +def execute_system_command(arg, **_): + """Execute a system shell command.""" + usage = "Syntax: system [command].\n" + + if not arg: + return [(None, None, None, usage)] + + try: + command = arg.strip() + if command.startswith("cd"): + ok, error_message = handle_cd_command(arg) + if not ok: + return [(None, None, None, error_message)] + return [(None, None, None, "")] + + args = arg.split(" ") + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + response = output if not error else error + + # Python 3 returns bytes. This needs to be decoded to a string. + if isinstance(response, bytes): + encoding = locale.getpreferredencoding(False) + response = response.decode(encoding) + + return [(None, None, None, response)] + except OSError as e: + return [(None, None, None, "OSError: %s" % e.strerror)] + + +def parseargfile(arg): + if arg.startswith("-o "): + mode = "w" + filename = arg[3:] + else: + mode = "a" + filename = arg + + if not filename: + raise TypeError("You must provide a filename.") + + return {"file": os.path.expanduser(filename), "mode": mode} + + +@special_command( + "tee", + "tee [-o] filename", + "Append all results to an output file (overwrite using -o).", +) +def set_tee(arg, **_): + global tee_file + + try: + tee_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + + return [(None, None, None, "")] + + +@export +def close_tee(): + global tee_file + if tee_file: + tee_file.close() + tee_file = None + + +@special_command("notee", "notee", "Stop writing results to an output file.") +def no_tee(arg, **_): + close_tee() + return [(None, None, None, "")] + + +@export +def write_tee(output): + global tee_file + if tee_file: + click.echo(output, file=tee_file, nl=False) + click.echo("\n", file=tee_file, nl=False) + tee_file.flush() + + +@special_command( + ".once", + "\\o [-o] filename", + "Append next result to an output file (overwrite using -o).", + aliases=("\\o", "\\once"), +) +def set_once(arg, **_): + global once_file + + once_file = parseargfile(arg) + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + try: + f = open(**once_file) + except (IOError, OSError) as e: + once_file = None + raise OSError( + "Cannot write to file '{}': {}".format(e.filename, e.strerror) + ) + + with f: + click.echo(output, file=f, nl=False) + click.echo("\n", file=f, nl=False) + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file + if written_to_once_file: + once_file = None + + +@special_command( + "watch", + "watch [seconds] [-c] query", + "Executes the query every [seconds] seconds (by default 5).", +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + raise StopIteration + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + raise StopIteration + (current_arg, _, arg) = arg.partition(" ") + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == "-c": + clear_screen = True + continue + statement = "{0!s} {1!s}".format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + raise StopIteration + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs["cur"] + sql_list = [ + (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + raise StopIteration + finally: + set_pager_enabled(old_pager_enabled) diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py new file mode 100644 index 0000000..3dd0e77 --- /dev/null +++ b/litecli/packages/special/main.py @@ -0,0 +1,160 @@ +from __future__ import unicode_literals +import logging +from collections import namedtuple + +from . import export + +log = logging.getLogger(__name__) + +NO_QUERY = 0 +PARSED_QUERY = 1 +RAW_QUERY = 2 + +SpecialCommand = namedtuple( + "SpecialCommand", + [ + "handler", + "command", + "shortcut", + "description", + "arg_type", + "hidden", + "case_sensitive", + ], +) + +COMMANDS = {} + + +@export +class ArgumentMissing(Exception): + pass + + +@export +class CommandNotFound(Exception): + pass + + +@export +def parse_special_command(sql): + command, _, arg = sql.partition(" ") + verbose = "+" in command + command = command.strip().replace("+", "") + return (command, verbose, arg.strip()) + + +@export +def special_command( + command, + shortcut, + description, + arg_type=PARSED_QUERY, + hidden=False, + case_sensitive=False, + aliases=(), +): + def wrapper(wrapped): + register_special_command( + wrapped, + command, + shortcut, + description, + arg_type, + hidden, + case_sensitive, + aliases, + ) + return wrapped + + return wrapper + + +@export +def register_special_command( + handler, + command, + shortcut, + description, + arg_type=PARSED_QUERY, + hidden=False, + case_sensitive=False, + aliases=(), +): + cmd = command.lower() if not case_sensitive else command + COMMANDS[cmd] = SpecialCommand( + handler, command, shortcut, description, arg_type, hidden, case_sensitive + ) + for alias in aliases: + cmd = alias.lower() if not case_sensitive else alias + COMMANDS[cmd] = SpecialCommand( + handler, + command, + shortcut, + description, + arg_type, + case_sensitive=case_sensitive, + hidden=True, + ) + + +@export +def execute(cur, sql): + """Execute a special command and return the results. If the special command + is not supported a KeyError will be raised. + """ + command, verbose, arg = parse_special_command(sql) + + if (command not in COMMANDS) and (command.lower() not in COMMANDS): + raise CommandNotFound + + try: + special_cmd = COMMANDS[command] + except KeyError: + special_cmd = COMMANDS[command.lower()] + if special_cmd.case_sensitive: + raise CommandNotFound("Command not found: %s" % command) + + if special_cmd.arg_type == NO_QUERY: + return special_cmd.handler() + elif special_cmd.arg_type == PARSED_QUERY: + return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + elif special_cmd.arg_type == RAW_QUERY: + return special_cmd.handler(cur=cur, query=sql) + + +@special_command( + "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?") +) +def show_help(): # All the parameters are ignored. + headers = ["Command", "Shortcut", "Description"] + result = [] + + for _, value in sorted(COMMANDS.items()): + if not value.hidden: + result.append((value.command, value.shortcut, value.description)) + return [(None, result, headers, None)] + + +@special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit")) +@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY) +def quit(*_args): + raise EOFError + + +@special_command( + "\\e", + "\\e", + "Edit command with editor (uses $EDITOR).", + arg_type=NO_QUERY, + case_sensitive=True, +) +@special_command( + "\\G", + "\\G", + "Display current query results vertically.", + arg_type=NO_QUERY, + case_sensitive=True, +) +def stub(): + raise NotImplementedError diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py new file mode 100644 index 0000000..eed9306 --- /dev/null +++ b/litecli/packages/special/utils.py @@ -0,0 +1,48 @@ +import os +import subprocess + + +def handle_cd_command(arg): + """Handles a `cd` shell command by calling python's os.chdir.""" + CD_CMD = "cd" + tokens = arg.split(CD_CMD + " ") + directory = tokens[-1] if len(tokens) > 1 else None + if not directory: + return False, "No folder name was provided." + try: + os.chdir(directory) + subprocess.call(["pwd"]) + return True, None + except OSError as e: + return False, e.strerror + + +def format_uptime(uptime_in_seconds): + """Format number of seconds into human-readable string. + + :param uptime_in_seconds: The server uptime in seconds. + :returns: A human-readable string representing the uptime. + + >>> uptime = format_uptime('56892') + >>> print(uptime) + 15 hours 48 min 12 sec + """ + + m, s = divmod(int(uptime_in_seconds), 60) + h, m = divmod(m, 60) + d, h = divmod(h, 24) + + uptime_values = [] + + for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")): + if value == 0 and not uptime_values: + # Don't include a value/unit if the unit isn't applicable to + # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. + continue + elif value == 1 and unit.endswith("s"): + # Remove the "s" if the unit is singular. + unit = unit[:-1] + uptime_values.append("{0} {1}".format(value, unit)) + + uptime = " ".join(uptime_values) + return uptime diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py new file mode 100644 index 0000000..64ca352 --- /dev/null +++ b/litecli/sqlcompleter.py @@ -0,0 +1,612 @@ +from __future__ import print_function +from __future__ import unicode_literals +import logging +from re import compile, escape +from collections import Counter + +from prompt_toolkit.completion import Completer, Completion + +from .packages.completion_engine import suggest_type +from .packages.parseutils import last_word +from .packages.special.iocommands import favoritequeries +from .packages.filepaths import parse_path, complete_path, suggest_path + +_logger = logging.getLogger(__name__) + + +class SQLCompleter(Completer): + keywords = [ + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BIGINT", + "BLOB", + "BOOLEAN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHARACTER", + "CHECK", + "CLOB", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DATE", + "DATETIME", + "DECIMAL", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DETACH", + "DISTINCT", + "DO", + "DOUBLE PRECISION", + "DOUBLE", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FLOAT", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GLOB", + "GROUP", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INT", + "INT2", + "INT8", + "INTEGER", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MEDIUMINT", + "NATIVE CHARACTER", + "NATURAL", + "NCHAR", + "NO", + "NOT", + "NOTHING", + "NULL", + "NULLS FIRST", + "NULLS LAST", + "NUMERIC", + "NVARCHAR", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER BY", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "REAL", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "SMALLINT", + "TABLE", + "TEMP", + "TEMPORARY", + "TEXT", + "THEN", + "TINYINT", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UNSIGNED BIG INT", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VARCHAR", + "VARYING CHARACTER", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT", + ] + + functions = [ + "ABS", + "AVG", + "CHANGES", + "CHAR", + "COALESCE", + "COUNT", + "CUME_DIST", + "DATE", + "DATETIME", + "DENSE_RANK", + "GLOB", + "GROUP_CONCAT", + "HEX", + "IFNULL", + "INSTR", + "JSON", + "JSON_ARRAY", + "JSON_ARRAY_LENGTH", + "JSON_EACH", + "JSON_EXTRACT", + "JSON_GROUP_ARRAY", + "JSON_GROUP_OBJECT", + "JSON_INSERT", + "JSON_OBJECT", + "JSON_PATCH", + "JSON_QUOTE", + "JSON_REMOVE", + "JSON_REPLACE", + "JSON_SET", + "JSON_TREE", + "JSON_TYPE", + "JSON_VALID", + "JULIANDAY", + "LAG", + "LAST_INSERT_ROWID", + "LENGTH", + "LIKELIHOOD", + "LIKELY", + "LOAD_EXTENSION", + "LOWER", + "LTRIM", + "MAX", + "MIN", + "NTILE", + "NULLIF", + "PERCENT_RANK", + "PRINTF", + "QUOTE", + "RANDOM", + "RANDOMBLOB", + "RANK", + "REPLACE", + "ROUND", + "ROW_NUMBER", + "RTRIM", + "SOUNDEX", + "SQLITE_COMPILEOPTION_GET", + "SQLITE_COMPILEOPTION_USED", + "SQLITE_OFFSET", + "SQLITE_SOURCE_ID", + "SQLITE_VERSION", + "STRFTIME", + "SUBSTR", + "SUM", + "TIME", + "TOTAL", + "TOTAL_CHANGES", + "TRIM", + ] + + def __init__(self, supported_formats=(), keyword_casing="auto"): + super(self.__class__, self).__init__() + self.reserved_words = set() + for x in self.keywords: + self.reserved_words.update(x.split()) + self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + + self.special_commands = [] + self.table_formats = supported_formats + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "auto" + self.keyword_casing = keyword_casing + self.reset_completions() + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = "`%s`" % name + + return name + + def unescape_name(self, name): + """Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_special_commands(self, special_commands): + # Special commands are not part of all_completions since they can only + # be at the beginning of a line. + self.special_commands.extend(special_commands) + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schema): + if schema is None: + return + metadata = self.dbmetadata["tables"] + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + metadata[schema] = {} + self.all_completions.update(schema) + + def extend_relations(self, data, kind): + """Extend metadata for tables or views + + :param data: list of (rel_name, ) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'data' is a generator object. It can throw an exception while being + # consumed. This could happen if the user has launched the app without + # specifying a database name. This exception must be handled to prevent + # crashing. + try: + data = [self.escaped_names(d) for d in data] + except Exception: + data = [] + + # dbmetadata['tables'][$schema_name][$table_name] should be a list of + # column names. Default to an asterisk + metadata = self.dbmetadata[kind] + for relname in data: + try: + metadata[self.dbname][relname[0]] = ["*"] + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", + kind, + relname[0], + self.dbname, + ) + self.all_completions.add(relname[0]) + + def extend_columns(self, column_data, kind): + """Extend column metadata + + :param column_data: list of (rel_name, column_name) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'column_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + column_data = [self.escaped_names(d) for d in column_data] + except Exception: + column_data = [] + + metadata = self.dbmetadata[kind] + for relname, column in column_data: + metadata[self.dbname][relname].append(column) + self.all_completions.add(column) + + def extend_functions(self, func_data): + # 'func_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + func_data = [self.escaped_names(d) for d in func_data] + except Exception: + func_data = [] + + # dbmetadata['functions'][$schema_name][$function_name] should return + # function metadata. + metadata = self.dbmetadata["functions"] + + for func in func_data: + metadata[self.dbname][func[0]] = None + self.all_completions.add(func[0]) + + def set_dbname(self, dbname): + self.dbname = dbname + + def reset_completions(self): + self.databases = [] + self.dbname = "" + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}} + self.all_completions = set(self.keywords + self.functions) + + @staticmethod + def find_matches( + text, + collection, + start_only=False, + fuzzy=True, + casing=None, + punctuations="most_punctuations", + ): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + If `start_only` is True, the text will match an available + completion only at the beginning. Otherwise, a completion is + considered a match if the text appears anywhere within it. + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + """ + last = last_word(text, include=punctuations) + text = last.lower() + + completions = [] + + if fuzzy: + regex = ".*?".join(map(escape, text)) + pat = compile("(%s)" % regex) + for item in sorted(collection): + r = pat.search(item.lower()) + if r: + completions.append((len(r.group()), r.start(), item)) + else: + match_end_limit = len(text) if start_only else None + for item in sorted(collection): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + completions.append((len(text), match_point, item)) + + if casing == "auto": + casing = "lower" if last and last[-1].islower() else "upper" + + def apply_case(kw): + if casing == "upper": + return kw.upper() + return kw.lower() + + return ( + Completion(z if casing is None else apply_case(z), -len(text)) + for x, y, z in sorted(completions) + ) + + def get_completions(self, document, complete_event): + word_before_cursor = document.get_word_before_cursor(WORD=True) + completions = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + + _logger.debug("Suggestion type: %r", suggestion["type"]) + + if suggestion["type"] == "column": + tables = suggestion["tables"] + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables) + if suggestion.get("drop_unique"): + # drop_unique is used for 'tb11 JOIN tbl2 USING (...' + # which should suggest only columns that appear in more than + # one table + scoped_cols = [ + col + for (col, count) in Counter(scoped_cols).items() + if count > 1 and col != "*" + ] + + cols = self.find_matches(word_before_cursor, scoped_cols) + completions.extend(cols) + + elif suggestion["type"] == "function": + # suggest user-defined functions using substring matching + funcs = self.populate_schema_objects(suggestion["schema"], "functions") + user_funcs = self.find_matches(word_before_cursor, funcs) + completions.extend(user_funcs) + + # suggest hardcoded functions using startswith matching only if + # there is no schema qualifier. If a schema qualifier is + # present it probably denotes a table. + # eg: SELECT * FROM users u WHERE u. + if not suggestion["schema"]: + predefined_funcs = self.find_matches( + word_before_cursor, + self.functions, + start_only=True, + fuzzy=False, + casing=self.keyword_casing, + ) + completions.extend(predefined_funcs) + + elif suggestion["type"] == "table": + tables = self.populate_schema_objects(suggestion["schema"], "tables") + tables = self.find_matches(word_before_cursor, tables) + completions.extend(tables) + + elif suggestion["type"] == "view": + views = self.populate_schema_objects(suggestion["schema"], "views") + views = self.find_matches(word_before_cursor, views) + completions.extend(views) + + elif suggestion["type"] == "alias": + aliases = suggestion["aliases"] + aliases = self.find_matches(word_before_cursor, aliases) + completions.extend(aliases) + + elif suggestion["type"] == "database": + dbs = self.find_matches(word_before_cursor, self.databases) + completions.extend(dbs) + + elif suggestion["type"] == "keyword": + keywords = self.find_matches( + word_before_cursor, + self.keywords, + start_only=True, + fuzzy=False, + casing=self.keyword_casing, + punctuations="many_punctuations", + ) + completions.extend(keywords) + + elif suggestion["type"] == "special": + special = self.find_matches( + word_before_cursor, + self.special_commands, + start_only=True, + fuzzy=False, + punctuations="many_punctuations", + ) + completions.extend(special) + elif suggestion["type"] == "favoritequery": + queries = self.find_matches( + word_before_cursor, + favoritequeries.list(), + start_only=False, + fuzzy=True, + ) + completions.extend(queries) + elif suggestion["type"] == "table_format": + formats = self.find_matches( + word_before_cursor, self.table_formats, start_only=True, fuzzy=False + ) + completions.extend(formats) + elif suggestion["type"] == "file_name": + file_names = self.find_files(word_before_cursor) + completions.extend(file_names) + + return completions + + def find_files(self, word): + """Yield matching directory or file names. + + :param word: + :return: iterable + + """ + base_path, last_path, position = parse_path(word) + paths = suggest_path(word) + for name in sorted(paths): + suggestion = complete_path(name, last_path) + if suggestion: + yield Completion(suggestion, position) + + def populate_scoped_cols(self, scoped_tbls): + """Find all columns in a set of scoped_tables + :param scoped_tbls: list of (schema, table, alias) tuples + :return: list of column names + """ + columns = [] + meta = self.dbmetadata + + for tbl in scoped_tbls: + # A fully qualified schema.relname reference or default_schema + # DO NOT escape schema names. + schema = tbl[0] or self.dbname + relname = tbl[1] + escaped_relname = self.escape_name(tbl[1]) + + # We don't know if schema.relname is a table or view. Since + # tables and views cannot share the same name, we can check one + # at a time + try: + columns.extend(meta["tables"][schema][relname]) + + # Table exists, so don't bother checking for a view + continue + except KeyError: + try: + columns.extend(meta["tables"][schema][escaped_relname]) + # Table exists, so don't bother checking for a view + continue + except KeyError: + pass + + try: + columns.extend(meta["views"][schema][relname]) + except KeyError: + pass + + return columns + + def populate_schema_objects(self, schema, obj_type): + """Returns list of tables or functions for a (optional) schema""" + metadata = self.dbmetadata[obj_type] + schema = schema or self.dbname + + try: + objects = metadata[schema].keys() + except KeyError: + # schema doesn't exist + objects = [] + + return objects diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py new file mode 100644 index 0000000..7ef103c --- /dev/null +++ b/litecli/sqlexecute.py @@ -0,0 +1,212 @@ +import logging +import sqlite3 +import uuid +from contextlib import closing +from sqlite3 import OperationalError + +import sqlparse +import os.path + +from .packages import special + +_logger = logging.getLogger(__name__) + +# FIELD_TYPES = decoders.copy() +# FIELD_TYPES.update({ +# FIELD_TYPE.NULL: type(None) +# }) + + +class SQLExecute(object): + + databases_query = """ + PRAGMA database_list + """ + + tables_query = """ + SELECT name + FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + table_columns_query = """ + SELECT m.name as tableName, p.name as columnName + FROM sqlite_master m + LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name + WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' + ORDER BY tableName, columnName + """ + + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + + def __init__(self, database): + self.dbname = database + self._server_type = None + self.connection_id = None + self.conn = None + if not database: + _logger.debug("Database is not specified. Skip connection.") + return + self.connect() + + def connect(self, database=None): + db = database or self.dbname + _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database) + + db_name = os.path.expanduser(db) + db_dir_name = os.path.dirname(os.path.abspath(db_name)) + if not os.path.exists(db_dir_name): + raise Exception("Path does not exist: {}".format(db_dir_name)) + + conn = sqlite3.connect(database=db_name, isolation_level=None) + if self.conn: + self.conn.close() + + self.conn = conn + # Update them after the connection is made to ensure that it was a + # successful connection. + self.dbname = db + # retrieve connection id + self.reset_connection_id() + + def run(self, statement): + """Execute the sql in the database and return the results. The results + are a list of tuples. Each tuple has 4 values + (title, rows, headers, status). + """ + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None) + + # Split the sql into separate queries and run each one. + # Unless it's saving a favorite query, in which case we + # want to save them all together. + if statement.startswith("\\fs"): + components = [statement] + else: + components = sqlparse.split(statement) + + for sql in components: + # Remove spaces, eol and semi-colons. + sql = sql.rstrip(";") + + # \G is treated specially since we have to set the expanded output. + if sql.endswith("\\G"): + special.set_expanded_output(True) + sql = sql[:-2].strip() + + if not self.conn and not ( + sql.startswith(".open") + or sql.lower().startswith("use") + or sql.startswith("\\u") + or sql.startswith("\\?") + or sql.startswith("\\q") + or sql.startswith("help") + or sql.startswith("exit") + or sql.startswith("quit") + ): + _logger.debug( + "Not connected to database. Will not run statement: %s.", sql + ) + raise OperationalError("Not connected to database.") + # yield ('Not connected to database', None, None, None) + # return + + cur = self.conn.cursor() if self.conn else None + try: # Special command + _logger.debug("Trying a dbspecial command. sql: %r", sql) + for result in special.execute(cur, sql): + yield result + except special.CommandNotFound: # Regular SQL + _logger.debug("Regular sql statement. sql: %r", sql) + cur.execute(sql) + yield self.get_result(cur) + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" + title = headers = None + + # cursor.description is not None for queries that return result sets, + # e.g. SELECT. + if cursor.description is not None: + headers = [x[0] for x in cursor.description] + status = "{0} row{1} in set" + cursor = list(cursor) + rowcount = len(cursor) + else: + _logger.debug("No rows in result.") + status = "Query OK, {0} row{1} affected" + rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount + cursor = None + + status = status.format(rowcount, "" if rowcount == 1 else "s") + + return (title, cursor, headers, status) + + def tables(self): + """Yields table names""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Tables Query. sql: %r", self.tables_query) + cur.execute(self.tables_query) + for row in cur: + yield row + + def table_columns(self): + """Yields column names""" + with closing(self.conn.cursor()) as cur: + _logger.debug("Columns Query. sql: %r", self.table_columns_query) + cur.execute(self.table_columns_query) + for row in cur: + yield row + + def databases(self): + if not self.conn: + return + + with closing(self.conn.cursor()) as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + for row in cur.execute(self.databases_query): + yield row[1] + + def functions(self): + """Yields tuples of (schema_name, function_name)""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Functions Query. sql: %r", self.functions_query) + cur.execute(self.functions_query % self.dbname) + for row in cur: + yield row + + def show_candidates(self): + with closing(self.conn.cursor()) as cur: + _logger.debug("Show Query. sql: %r", self.show_candidates_query) + try: + cur.execute(self.show_candidates_query) + except sqlite3.DatabaseError as e: + _logger.error("No show completions due to %r", e) + yield "" + else: + for row in cur: + yield (row[0].split(None, 1)[-1],) + + def server_type(self): + self._server_type = ("sqlite3", "3") + return self._server_type + + def get_connection_id(self): + if not self.connection_id: + self.reset_connection_id() + return self.connection_id + + def reset_connection_id(self): + # Remember current connection id + _logger.debug("Get current connection id") + # res = self.run('select connection_id()') + self.connection_id = uuid.uuid4() + # for title, cur, headers, status in res: + # self.connection_id = cur.fetchone()[0] + _logger.debug("Current connection id: %s", self.connection_id) diff --git a/release.py b/release.py new file mode 100644 index 0000000..264a4c3 --- /dev/null +++ b/release.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +"""A script to publish a release of litecli to PyPI.""" + +from __future__ import print_function +import io +from optparse import OptionParser +import re +import subprocess +import sys + +import click + +DEBUG = False +CONFIRM_STEPS = False +DRY_RUN = False + + +def skip_step(): + """ + Asks for user's response whether to run a step. Default is yes. + :return: boolean + """ + global CONFIRM_STEPS + + if CONFIRM_STEPS: + return not click.confirm("--- Run this step?", default=True) + return False + + +def run_step(*args): + """ + Prints out the command and asks if it should be run. + If yes (default), runs it. + :param args: list of strings (command and args) + """ + global DRY_RUN + + cmd = args + print(" ".join(cmd)) + if skip_step(): + print("--- Skipping...") + elif DRY_RUN: + print("--- Pretending to run...") + else: + subprocess.check_output(cmd) + + +def version(version_file): + _version_re = re.compile( + r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)' + ) + + with io.open(version_file, encoding="utf-8") as f: + ver = _version_re.search(f.read()).group("version") + + return ver + + +def commit_for_release(version_file, ver): + run_step("git", "reset") + run_step("git", "add", version_file) + run_step("git", "commit", "--message", "Releasing version {}".format(ver)) + + +def create_git_tag(tag_name): + run_step("git", "tag", tag_name) + + +def create_distribution_files(): + run_step("python", "setup.py", "sdist", "bdist_wheel") + + +def upload_distribution_files(): + run_step("twine", "upload", "dist/*") + + +def push_to_github(): + run_step("git", "push", "origin", "master") + + +def push_tags_to_github(): + run_step("git", "push", "--tags", "origin") + + +def checklist(questions): + for question in questions: + if not click.confirm("--- {}".format(question), default=False): + sys.exit(1) + + +if __name__ == "__main__": + if DEBUG: + subprocess.check_output = lambda x: x + + ver = version("litecli/__init__.py") + print("Releasing Version:", ver) + + parser = OptionParser() + parser.add_option( + "-c", + "--confirm-steps", + action="store_true", + dest="confirm_steps", + default=False, + help=( + "Confirm every step. If the step is not " "confirmed, it will be skipped." + ), + ) + parser.add_option( + "-d", + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help="Print out, but not actually run any steps.", + ) + + popts, pargs = parser.parse_args() + CONFIRM_STEPS = popts.confirm_steps + DRY_RUN = popts.dry_run + + if not click.confirm("Are you sure?", default=False): + sys.exit(1) + + commit_for_release("litecli/__init__.py", ver) + create_git_tag("v{}".format(ver)) + create_distribution_files() + push_to_github() + push_tags_to_github() + upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..b95211a --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,9 @@ +mock +pytest>=3.6 +pytest-cov +tox +behave +pexpect +coverage +codecov +click diff --git a/screenshots/litecli.gif b/screenshots/litecli.gif new file mode 100644 index 0000000..9cfd80c Binary files /dev/null and b/screenshots/litecli.gif differ diff --git a/screenshots/litecli.png b/screenshots/litecli.png new file mode 100644 index 0000000..6ca999e Binary files /dev/null and b/screenshots/litecli.png differ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..40eab0a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,18 @@ +[bdist_wheel] +universal = 1 + +[tool:pytest] +addopts = --capture=sys + --showlocals + --doctest-modules + --doctest-ignore-import-errors + --ignore=setup.py + --ignore=litecli/magic.py + --ignore=litecli/packages/parseutils.py + --ignore=test/features + +[pep8] +rev = master +docformatter = True +diff = True +error-status = True diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..acbb0d9 --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import ast +from io import open +import re +from setuptools import setup, find_packages + +_version_re = re.compile(r"__version__\s+=\s+(.*)") + +with open("litecli/__init__.py", "rb") as f: + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) + + +def open_file(filename): + """Open and read the file *filename*.""" + with open(filename) as f: + return f.read() + + +readme = open_file("README.md") + +install_requirements = [ + "click >= 4.1", + "Pygments >= 1.6", + "prompt_toolkit>=3.0.3,<4.0.0", + "sqlparse", + "configobj >= 5.0.5", + "cli_helpers[styles] >= 1.0.1", +] + + +setup( + name="litecli", + author="dbcli", + author_email="litecli-users@googlegroups.com", + license="BSD", + version=version, + url="https://github.com/dbcli/litecli", + packages=find_packages(), + package_data={"litecli": ["liteclirc", "AUTHORS"]}, + description="CLI for SQLite Databases with auto-completion and syntax " + "highlighting.", + long_description=readme, + long_description_content_type="text/markdown", + install_requires=install_requirements, + # cmdclass={"test": test, "lint": lint}, + entry_points={ + "console_scripts": ["litecli = litecli.main:cli"], + "distutils.commands": ["lint = tasks:lint", "test = tasks:test"], + }, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tasks.py b/tasks.py new file mode 100644 index 0000000..5e68107 --- /dev/null +++ b/tasks.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +"""Common development tasks for setup.py to use.""" + +import re +import subprocess +import sys + +from setuptools import Command +from setuptools.command.test import test as TestCommand + + +class BaseCommand(Command, object): + """The base command for project tasks.""" + + user_options = [] + + default_cmd_options = ("verbose", "quiet", "dry_run") + + def __init__(self, *args, **kwargs): + super(BaseCommand, self).__init__(*args, **kwargs) + self.verbose = False + + def initialize_options(self): + """Override the distutils abstract method.""" + pass + + def finalize_options(self): + """Override the distutils abstract method.""" + # Distutils uses incrementing integers for verbosity. + self.verbose = bool(self.verbose) + + def call_and_exit(self, cmd, shell=True): + """Run the *cmd* and exit with the proper exit code.""" + sys.exit(subprocess.call(cmd, shell=shell)) + + def call_in_sequence(self, cmds, shell=True): + """Run multiple commmands in a row, exiting if one fails.""" + for cmd in cmds: + if subprocess.call(cmd, shell=shell) == 1: + sys.exit(1) + + def apply_options(self, cmd, options=()): + """Apply command-line options.""" + for option in self.default_cmd_options + options: + cmd = self.apply_option(cmd, option, active=getattr(self, option, False)) + return cmd + + def apply_option(self, cmd, option, active=True): + """Apply a command-line option.""" + return re.sub( + r"{{{}\:(?P