diff options
Diffstat (limited to '')
60 files changed, 7409 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..aea4994 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +parallel = True +source = litecli diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..c433202 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Black +f767afc80bd5bcc8f1b1cc1a134babc2dec4d239
\ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..3e14cc7 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,8 @@ +## Description +<!--- Describe your changes in detail. --> + + + +## Checklist +<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. --> +- [ ] I've added this contribution to the `CHANGELOG.md` file. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..9ee36cf --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,44 @@ +name: litecli + +on: + pull_request: + paths-ignore: + - '**.md' + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10"] + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install requirements + run: | + python -m pip install -U pip setuptools + pip install --no-cache-dir -e . + pip install -r requirements-dev.txt -U --upgrade-strategy=only-if-needed + + - name: Run unit tests + env: + PYTEST_PASSWORD: root + run: | + ./setup.py test --pytest-args="--cov-report= --cov=litecli" + + - name: Run Black + run: | + ./setup.py lint + if: matrix.python-version == '3.7' + + - name: Coverage + run: | + coverage report + codecov diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..63c3eb6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +.idea/ +/build +/dist +/litecli.egg-info +/htmlcov +/src +/test/behave.ini +/litecli_env +/.venv +/.eggs + +.vagrant +*.pyc +*.deb +*.swp +.cache/ +.coverage +.coverage.* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..67ba03d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,5 @@ +repos: +- repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..afeebd1 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,110 @@ +## 1.9.0 - 2022-06-06 + +### Features + +* Add support for ANSI escape sequences for coloring the prompt. +* Add support for `.indexes` command. +* Add an option to turn off the auto-completion menu. Completion menu can be + triggered by pressed the `<tab>` key when this option is set to False. Fixes + [#105](https://github.com/dbcli/litecli/issues/105). + +### Bug Fixes + +* Fix [#120](https://github.com/dbcli/litecli/issues/120). Make the `.read` command actually read and execute the commands from a file. +* Fix [#96](https://github.com/dbcli/litecli/issues/96) the crash in VI mode when pressing `r`. + +## 1.8.0 - 2022-03-29 + +### Features + +* Update compatible Python versions. (Thanks: [blazewicz]) +* Add support for Python 3.10. (Thanks: [blazewicz]) +* Drop support for Python 3.6. (Thanks: [blazewicz]) + +### Bug Fixes + +* Upgrade cli_helpers to workaround Pygments regression. +* Use get_terminal_size from shutil instead of click. + +## 1.7.0 - 2022-01-11 + +### Features + +* Add config option show_bottom_toolbar. + +### Bug Fixes + +* Pin pygments version to prevent breaking change. + +## 1.6.0 - 2021-03-15 + +### Features + +- Add verbose feature to `favorite_query` command. (Thanks: [Zhaolong Zhu]) + - `\f query` does not show the full SQL. + - `\f+ query` shows the full SQL. +- Add prompt format of file's basename. (Thanks: [elig0n]) + +### Bug Fixes + +- Fix compatibility with sqlparse >= 0.4.0. (Thanks: [chocolateboy]) +- Fix invalid utf-8 exception. (Thanks: [Amjith]) + +## 1.4.1 - 2020-07-27 + +### Bug Fixes + +- Fix setup.py to set `long_description_content_type` as markdown. + +## 1.4.0 - 2020-07-27 + +### Features + +- Add NULLS FIRST and NULLS LAST to keywords. (Thanks: [Amjith]) + +## 1.3.2 - 2020-03-11 + +- Fix the completion engine to work with newer sqlparse. + +## 1.3.1 - 2020-03-11 + +- Remove the version pinning of sqlparse package. + +## 1.3.0 - 2020-02-11 + +### Features + +- Added `.import` command for importing data from file into table. (Thanks: [Zhaolong Zhu]) +- Upgraded to prompt-toolkit 3.x. + +## 1.2.0 - 2019-10-26 + +### Features + +- Enhance the `describe` command. (Thanks: [Amjith]) +- Autocomplete table names for special commands. (Thanks: [Amjith]) + +## 1.1.0 - 2019-07-14 + +### Features + +- Added `.read` command for reading scripts. +- Added `.load` command for loading extension libraries. (Thanks: [Zhiming Wang]) +- Add support for using `?` as a placeholder in the favorite queries. (Thanks: [Amjith]) +- Added shift-tab to select the previous entry in the completion menu. [Amjith] +- Added `describe` and `desc` keywords. + +### Bug Fixes + +- Clear error message when directory does not exist. (Thanks: [Irina Truong]) + +## 1.0.0 - 2019-01-04 + +- To new beginnings. :tada: + +[Amjith]: https://blog.amjith.com +[chocolateboy]: https://github.com/chocolateboy +[Irina Truong]: https://github.com/j-bennet +[Shawn Chapla]: https://github.com/shwnchpl +[Zhaolong Zhu]: https://github.com/zzl0 +[Zhiming Wang]: https://github.com/zmwangx diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9cd868a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,119 @@ +# Development Guide + +This is a guide for developers who would like to contribute to this project. It is recommended to use Python 3.7 and above for development. + +If you're interested in contributing to litecli, thank you. We'd love your help! +You'll always get credit for your work. + +## GitHub Workflow + +1. [Fork the repository](https://github.com/dbcli/litecli) on GitHub. + +2. Clone your fork locally: + ```bash + $ git clone <url-for-your-fork> + ``` + +3. Add the official repository (`upstream`) as a remote repository: + ```bash + $ git remote add upstream git@github.com:dbcli/litecli.git + ``` + +4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs) + for development: + + ```bash + $ cd litecli + $ pip install virtualenv + $ virtualenv litecli_dev + ``` + + We've just created a virtual environment that we'll use to install all the dependencies + and tools we need to work on litecli. Whenever you want to work on litecli, you + need to activate the virtual environment: + + ```bash + $ source litecli_dev/bin/activate + ``` + + When you're done working, you can deactivate the virtual environment: + + ```bash + $ deactivate + ``` + +5. Install the dependencies and development tools: + + ```bash + $ pip install -r requirements-dev.txt + $ pip install --editable . + ``` + +6. Create a branch for your bugfix or feature based off the `master` branch: + + ```bash + $ git checkout -b <name-of-bugfix-or-feature> + ``` + +7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: + + ```bash + $ git pull upstream master + ``` + +8. When your work is ready for the litecli team to review it, push your branch to your fork: + + ```bash + $ git push origin <name-of-bugfix-or-feature> + ``` + +9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) on GitHub. + + +## Running the Tests + +While you work on litecli, it's important to run the tests to make sure your code +hasn't broken any existing functionality. To run the tests, just type in: + +```bash +$ ./setup.py test +``` + +litecli supports Python 3.7+. You can test against multiple versions of +Python by running tox: + +```bash +$ tox +``` + + +### CLI Tests + +Some CLI tests expect the program `ex` to be a symbolic link to `vim`. + +In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will +change the output and therefore make some tests fail. + +You can check this by running: +```bash +$ readlink -f $(which ex) +``` + + +## Coding Style + +litecli uses [black](https://github.com/ambv/black) to format the source code. Make sure to install black. + +It's easy to check the style of your code, just run: + +```bash +$ ./setup.py lint +``` + +If you see any style issues, you can automatically fix them by running: + +```bash +$ ./setup.py lint --fix +``` + +Be sure to commit and push any stylistic fixes. @@ -0,0 +1,27 @@ +Copyright (c) 2018, dbcli +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of dbcli nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..f1ff0f6 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,8 @@ +include *.txt *.py +include LICENSE CHANGELOG.md +include tox.ini +recursive-include tests *.py +recursive-include tests *.txt +recursive-include tests *.csv +recursive-include tests liteclirc +recursive-include litecli AUTHORS diff --git a/README.md b/README.md new file mode 100644 index 0000000..649481d --- /dev/null +++ b/README.md @@ -0,0 +1,54 @@ +# litecli + +[![Build Status](https://travis-ci.org/dbcli/litecli.svg?branch=master)](https://travis-ci.org/dbcli/litecli) + +[Docs](https://litecli.com) + +A command-line client for SQLite databases that has auto-completion and syntax highlighting. + +![Completion](screenshots/litecli.png) +![CompletionGif](screenshots/litecli.gif) + +## Installation + +If you already know how to install python packages, then you can install it via pip: + +You might need sudo on linux. + +``` +$ pip install -U litecli +``` + +The package is also available on Arch Linux through AUR in two versions: [litecli](https://aur.archlinux.org/packages/litecli/) is based the latest release (git tag) and [litecli-git](https://aur.archlinux.org/packages/litecli-git/) is based on the master branch of the git repo. You can install them manually or with an AUR helper such as `yay`: + +``` +$ yay -S litecli +``` +or + +``` +$ yay -S litecli-git +``` + +For MacOS users, you can also use Homebrew to install it: + +``` +$ brew install litecli +``` + +## Usage + +``` +$ litecli --help + +Usage: litecli [OPTIONS] [DATABASE] + +Examples: + - litecli sqlite_db_name +``` + +A config file is automatically created at `~/.config/litecli/config` at first launch. For Windows machines a config file is created at `~\AppData\Local\dbcli\litecli\config` at first launch. See the file itself for a description of all available options. + +## Docs + +Visit: [litecli.com/features](https://litecli.com/features) @@ -0,0 +1,3 @@ +* [] Sort by frecency. +* [] Add completions when an attach database command is run. +* [] Add behave tests. diff --git a/litecli/AUTHORS b/litecli/AUTHORS new file mode 100644 index 0000000..194cdc7 --- /dev/null +++ b/litecli/AUTHORS @@ -0,0 +1,21 @@ +Project Lead: +------------- + + * Delgermurun Purevkhu + + +Core Developers: +---------------- + + * Amjith Ramanujam + * Irina Truong + * Dick Marinus + +Contributors: +------------- + + * Thomas Roten + * Zhaolong Zhu + * Zhiming Wang + * Shawn M. Chapla + * Paweł Sacawa diff --git a/litecli/__init__.py b/litecli/__init__.py new file mode 100644 index 0000000..0a0a43a --- /dev/null +++ b/litecli/__init__.py @@ -0,0 +1 @@ +__version__ = "1.9.0" diff --git a/litecli/clibuffer.py b/litecli/clibuffer.py new file mode 100644 index 0000000..a57192a --- /dev/null +++ b/litecli/clibuffer.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals + +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app + + +def cli_is_multiline(cli): + @Condition + def cond(): + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + + if not cli.multi_line: + return False + else: + return not _multiline_exception(doc.text) + + return cond + + +def _multiline_exception(text): + orig = text + text = text.strip() + + # Multi-statement favorite query is a special case. Because there will + # be a semicolon separating statements, we can't consider semicolon an + # EOL. Let's consider an empty line an EOL instead. + if text.startswith("\\fs"): + return orig.endswith("\n") + + return ( + text.startswith("\\") # Special Command + or text.endswith(";") # Ended with a semi-colon + or text.endswith("\\g") # Ended with \g + or text.endswith("\\G") # Ended with \G + or (text == "exit") # Exit doesn't need semi-colon + or (text == "quit") # Quit doesn't need semi-colon + or (text == ":q") # To all the vim fans out there + or (text == "") # Just a plain enter without any text + ) diff --git a/litecli/clistyle.py b/litecli/clistyle.py new file mode 100644 index 0000000..7527315 --- /dev/null +++ b/litecli/clistyle.py @@ -0,0 +1,114 @@ +from __future__ import unicode_literals + +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). +TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Prompt: "prompt", + Token.Continuation: "continuation", +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} + + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError as err: + return token_type, style_dict[token_name] + + +def style_factory(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name) + except ClassNotFound: + style = pygments.styles.get_style_by_name("native") + + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith("Token."): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style(token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error("Unhandled style / class name: %s", token) + else: + # treat as prompt style name (2.0). See default style names here: + # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles( + [style_from_pygments_cls(style), override_style, Style(prompt_styles)] + ) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name("native").styles + + for token in cli_style: + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error("Unhandled style / class name: %s", token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style + + return OutputStyle diff --git a/litecli/clitoolbar.py b/litecli/clitoolbar.py new file mode 100644 index 0000000..1e28784 --- /dev/null +++ b/litecli/clitoolbar.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals + +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.application import get_app + + +def create_toolbar_tokens_func(cli, show_fish_help): + """ + Return a function that generates the toolbar tokens. + """ + + def get_toolbar_tokens(): + result = [] + result.append(("class:bottom-toolbar", " ")) + + if cli.multi_line: + result.append( + ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ") + ) + + if cli.multi_line: + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) + else: + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) + if cli.prompt_app.editing_mode == EditingMode.VI: + result.append( + ("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode())) + ) + + if show_fish_help(): + result.append( + ("class:bottom-toolbar", " Right-arrow to complete suggestion") + ) + + if cli.completion_refresher.is_refreshing(): + result.append(("class:bottom-toolbar", " Refreshing completions...")) + + return result + + return get_toolbar_tokens + + +def _get_vi_mode(): + """Get the current vi mode for display.""" + return { + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.INSERT_MULTIPLE: "M", + InputMode.REPLACE_SINGLE: "R", + }[get_app().vi_state.input_mode] diff --git a/litecli/compat.py b/litecli/compat.py new file mode 100644 index 0000000..7316261 --- /dev/null +++ b/litecli/compat.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +"""Platform and Python version compatibility support.""" + +import sys + + +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +WIN = sys.platform in ("win32", "cygwin") diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py new file mode 100644 index 0000000..9602070 --- /dev/null +++ b/litecli/completion_refresher.py @@ -0,0 +1,131 @@ +import threading +from .packages.special.main import COMMANDS +from collections import OrderedDict + +from .sqlcompleter import SQLCompleter +from .sqlexecute import SQLExecute + + +class CompletionRefresher(object): + + refreshers = OrderedDict() + + def __init__(self): + self._completer_thread = None + self._restart_refresh = threading.Event() + + def refresh(self, executor, callbacks, completer_options=None): + """Creates a SQLCompleter object and populates it with the relevant + completion suggestions in a background thread. + + executor - SQLExecute object, used to extract the credentials to connect + to the database. + callbacks - A function or a list of functions to call after the thread + has completed the refresh. The newly created completion + object will be passed in as an argument to each callback. + completer_options - dict of options to pass to SQLCompleter. + + """ + if completer_options is None: + completer_options = {} + + if self.is_refreshing(): + self._restart_refresh.set() + return [(None, None, None, "Auto-completion refresh restarted.")] + else: + if executor.dbname == ":memory:": + # if DB is memory, needed to use same connection + # So can't use same connection with different thread + self._bg_refresh(executor, callbacks, completer_options) + else: + self._completer_thread = threading.Thread( + target=self._bg_refresh, + args=(executor, callbacks, completer_options), + name="completion_refresh", + ) + self._completer_thread.setDaemon(True) + self._completer_thread.start() + return [ + ( + None, + None, + None, + "Auto-completion refresh started in the background.", + ) + ] + + def is_refreshing(self): + return self._completer_thread and self._completer_thread.is_alive() + + def _bg_refresh(self, sqlexecute, callbacks, completer_options): + completer = SQLCompleter(**completer_options) + + e = sqlexecute + if e.dbname == ":memory:": + # if DB is memory, needed to use same connection + executor = sqlexecute + else: + # Create a new sqlexecute method to popoulate the completions. + executor = SQLExecute(e.dbname) + + # If callbacks is a single function then push it into a list. + if callable(callbacks): + callbacks = [callbacks] + + while 1: + for refresher in self.refreshers.values(): + refresher(completer, executor) + if self._restart_refresh.is_set(): + self._restart_refresh.clear() + break + else: + # Break out of while loop if the for loop finishes natually + # without hitting the break statement. + break + + # Start over the refresh from the beginning if the for loop hit the + # break statement. + continue + + for callback in callbacks: + callback(completer) + + +def refresher(name, refreshers=CompletionRefresher.refreshers): + """Decorator to add the decorated function to the dictionary of + refreshers. Any function decorated with a @refresher will be executed as + part of the completion refresh routine.""" + + def wrapper(wrapped): + refreshers[name] = wrapped + return wrapped + + return wrapper + + +@refresher("databases") +def refresh_databases(completer, executor): + completer.extend_database_names(executor.databases()) + + +@refresher("schemata") +def refresh_schemata(completer, executor): + # name of the current database. + completer.extend_schemata(executor.dbname) + completer.set_dbname(executor.dbname) + + +@refresher("tables") +def refresh_tables(completer, executor): + completer.extend_relations(executor.tables(), kind="tables") + completer.extend_columns(executor.table_columns(), kind="tables") + + +@refresher("functions") +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) + + +@refresher("special_commands") +def refresh_special(completer, executor): + completer.extend_special_commands(COMMANDS.keys()) diff --git a/litecli/config.py b/litecli/config.py new file mode 100644 index 0000000..1c7fb25 --- /dev/null +++ b/litecli/config.py @@ -0,0 +1,62 @@ +import errno +import shutil +import os +import platform +from os.path import expanduser, exists, dirname +from configobj import ConfigObj + + +def config_location(): + if "XDG_CONFIG_HOME" in os.environ: + return "%s/litecli/" % expanduser(os.environ["XDG_CONFIG_HOME"]) + elif platform.system() == "Windows": + return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\litecli\\" + else: + return expanduser("~/.config/litecli/") + + +def load_config(usr_cfg, def_cfg=None): + cfg = ConfigObj() + cfg.merge(ConfigObj(def_cfg, interpolation=False)) + cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")) + cfg.filename = expanduser(usr_cfg) + + return cfg + + +def ensure_dir_exists(path): + parent_dir = expanduser(dirname(path)) + try: + os.makedirs(parent_dir) + except OSError as exc: + # ignore existing destination (py2 has no exist_ok arg to makedirs) + if exc.errno != errno.EEXIST: + raise + + +def write_default_config(source, destination, overwrite=False): + destination = expanduser(destination) + if not overwrite and exists(destination): + return + + ensure_dir_exists(destination) + + shutil.copyfile(source, destination) + + +def upgrade_config(config, def_config): + cfg = load_config(config, def_config) + cfg.write() + + +def get_config(liteclirc_file=None): + from litecli import __file__ as package_root + + package_root = os.path.dirname(package_root) + + liteclirc_file = liteclirc_file or "%sconfig" % config_location() + + default_config = os.path.join(package_root, "liteclirc") + write_default_config(default_config, liteclirc_file) + + return load_config(liteclirc_file, default_config) diff --git a/litecli/encodingutils.py b/litecli/encodingutils.py new file mode 100644 index 0000000..6caf14d --- /dev/null +++ b/litecli/encodingutils.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from litecli.compat import PY2 + + +if PY2: + binary_type = str + string_types = basestring + text_type = unicode +else: + binary_type = bytes + string_types = str + text_type = str + + +def unicode2utf8(arg): + """Convert strings to UTF8-encoded bytes. + + Only in Python 2. In Python 3 the args are expected as unicode. + + """ + + if PY2 and isinstance(arg, text_type): + return arg.encode("utf-8") + return arg + + +def utf8tounicode(arg): + """Convert UTF8-encoded bytes to strings. + + Only in Python 2. In Python 3 the errors are returned as strings. + + """ + + if PY2 and isinstance(arg, binary_type): + return arg.decode("utf-8") + return arg diff --git a/litecli/key_bindings.py b/litecli/key_bindings.py new file mode 100644 index 0000000..44d59d2 --- /dev/null +++ b/litecli/key_bindings.py @@ -0,0 +1,84 @@ +from __future__ import unicode_literals +import logging +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.filters import completion_is_selected +from prompt_toolkit.key_binding import KeyBindings + +_logger = logging.getLogger(__name__) + + +def cli_bindings(cli): + """Custom key bindings for cli.""" + kb = KeyBindings() + + @kb.add("f3") + def _(event): + """Enable/Disable Multiline Mode.""" + _logger.debug("Detected F3 key.") + cli.multi_line = not cli.multi_line + + @kb.add("f4") + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug("Detected F4 key.") + if cli.key_bindings == "vi": + event.app.editing_mode = EditingMode.EMACS + cli.key_bindings = "emacs" + else: + event.app.editing_mode = EditingMode.VI + cli.key_bindings = "vi" + + @kb.add("tab") + def _(event): + """Force autocompletion at cursor.""" + _logger.debug("Detected <Tab> key.") + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=True) + + @kb.add("s-tab") + def _(event): + """Force autocompletion at cursor.""" + _logger.debug("Detected <Tab> key.") + b = event.app.current_buffer + if b.complete_state: + b.complete_previous() + else: + b.start_completion(select_last=True) + + @kb.add("c-space") + def _(event): + """ + Initialize autocompletion at cursor. + + If the autocompletion menu is not showing, display it with the + appropriate completions for the context. + + If the menu is showing, select the next completion. + """ + _logger.debug("Detected <C-Space> key.") + + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=False) + + @kb.add("enter", filter=completion_is_selected) + def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + + """ + _logger.debug("Detected enter key.") + + event.current_buffer.complete_state = None + b = event.app.current_buffer + b.complete_state = None + + return kb diff --git a/litecli/lexer.py b/litecli/lexer.py new file mode 100644 index 0000000..678eb3f --- /dev/null +++ b/litecli/lexer.py @@ -0,0 +1,9 @@ +from pygments.lexer import inherit +from pygments.lexers.sql import MySqlLexer +from pygments.token import Keyword + + +class LiteCliLexer(MySqlLexer): + """Extends SQLite lexer to add keywords.""" + + tokens = {"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit]} diff --git a/litecli/liteclirc b/litecli/liteclirc new file mode 100644 index 0000000..4db6f3a --- /dev/null +++ b/litecli/liteclirc @@ -0,0 +1,122 @@ +# vi: ft=dosini +[main] + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# Destructive warning mode will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database" +# or "shutdown". +destructive_warning = True + +# log_file location. +# In Unix/Linux: ~/.config/litecli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Log every query and its results to a file. Enable this by uncommenting the +# line below. +# audit_log = ~/.litecli-audit.log + +# Default pager. +# By default '$PAGER' environment variable is used +# pager = less -SRXF + +# Table format. Possible values: +# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl, +# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira, +# vertical, tsv, csv. +# Recommended: ascii +table_format = ascii + +# Syntax coloring style. Possible values (many support the "-dark" suffix): +# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, +# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, +# fruity. +# Screenshots at http://mycli.net/syntax +syntax_style = default + +# Keybindings: Possible values: emacs, vi. +# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL. +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +key_bindings = emacs + +# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested. +wider_completion_menu = False + +# Autocompletion is on by default. This can be truned off by setting this +# option to False. Pressing tab will still trigger completion. +autocompletion = True + +# litecli prompt +# \D - The full current date +# \d - Database name +# \f - File basename of the "main" database +# \m - Minutes of the current time +# \n - Newline +# \P - AM/PM +# \R - The current time, in 24-hour military time (0-23) +# \r - The current time, standard 12-hour time (1-12) +# \s - Seconds of the current time +# \x1b[...m - insert ANSI escape sequence +prompt = '\d> ' +prompt_continuation = '-> ' + +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + +# Skip intro info on startup and outro info on exit +less_chatty = False + +# Use alias from --login-path instead of host name in prompt +login_path_as_host = False + +# Cause result sets to be displayed vertically if they are too wide for the current window, +# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.) +auto_vertical_output = False + +# keyword casing preference. Possible values "lower", "upper", "auto" +keyword_casing = auto + +# disabled pager on startup +enable_pager = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" + + +# Favorite queries. +[favorite_queries] diff --git a/litecli/main.py b/litecli/main.py new file mode 100644 index 0000000..de279f6 --- /dev/null +++ b/litecli/main.py @@ -0,0 +1,1017 @@ +from __future__ import unicode_literals +from __future__ import print_function + +import os +import sys +import traceback +import logging +import threading +from time import time +from datetime import datetime +from io import open +from collections import namedtuple +from sqlite3 import OperationalError +import shutil + +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output import preprocessors +import click +import sqlparse +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ANSI +from prompt_toolkit.layout.processors import ( + HighlightMatchingBracketProcessor, + ConditionalProcessor, +) +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory + +from .packages.special.main import NO_QUERY +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages import special +from .sqlcompleter import SQLCompleter +from .clitoolbar import create_toolbar_tokens_func +from .clistyle import style_factory, style_factory_output +from .sqlexecute import SQLExecute +from .clibuffer import cli_is_multiline +from .completion_refresher import CompletionRefresher +from .config import config_location, ensure_dir_exists, get_config +from .key_bindings import cli_bindings +from .encodingutils import utf8tounicode, text_type +from .lexer import LiteCliLexer +from .__init__ import __version__ +from .packages.filepaths import dir_path_exists + +import itertools + +click.disable_unicode_literals_warning = True + +# Query tuples are used for maintaining history +Query = namedtuple("Query", ["query", "successful", "mutating"]) + +PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +class LiteCli(object): + + default_prompt = "\\d> " + max_len_prompt = 45 + + def __init__( + self, + sqlexecute=None, + prompt=None, + logfile=None, + auto_vertical_output=False, + warn=None, + liteclirc=None, + ): + self.sqlexecute = sqlexecute + self.logfile = logfile + + # Load config. + c = self.config = get_config(liteclirc) + + self.multi_line = c["main"].as_bool("multi_line") + self.key_bindings = c["main"]["key_bindings"] + special.set_favorite_queries(self.config) + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + self.formatter.litecli = self + self.syntax_style = c["main"]["syntax_style"] + self.less_chatty = c["main"].as_bool("less_chatty") + self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar") + self.cli_style = c["colors"] + self.output_style = style_factory_output(self.syntax_style, self.cli_style) + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + self.autocompletion = c["main"].as_bool("autocompletion") + c_dest_warning = c["main"].as_bool("destructive_warning") + self.destructive_warning = c_dest_warning if warn is None else warn + self.login_path_as_host = c["main"].as_bool("login_path_as_host") + + # read from cli argument or user config file + self.auto_vertical_output = auto_vertical_output or c["main"].as_bool( + "auto_vertical_output" + ) + + # audit log + if self.logfile is None and "audit_log" in c["main"]: + try: + self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a") + except (IOError, OSError): + self.echo( + "Error: Unable to open the audit log file. Your queries will not be logged.", + err=True, + fg="red", + ) + self.logfile = False + + self.completion_refresher = CompletionRefresher() + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"] + self.prompt_format = ( + prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt + ) + self.prompt_continuation_format = c["main"]["prompt_continuation"] + keyword_casing = c["main"].get("keyword_casing", "auto") + + self.query_history = [] + + # Initialize completer. + self.completer = SQLCompleter( + supported_formats=self.formatter.supported_formats, + keyword_casing=keyword_casing, + ) + self._completer_lock = threading.Lock() + + # Register custom special commands. + self.register_special_commands() + + self.prompt_app = None + + def register_special_commands(self): + special.register_special_command( + self.change_db, + ".open", + ".open", + "Change to a new database.", + aliases=("use", "\\u"), + ) + special.register_special_command( + self.refresh_completions, + "rehash", + "\\#", + "Refresh auto-completions.", + arg_type=NO_QUERY, + aliases=("\\#",), + ) + special.register_special_command( + self.change_table_format, + ".mode", + "\\T", + "Change the table format used to output results.", + aliases=("tableformat", "\\T"), + case_sensitive=True, + ) + special.register_special_command( + self.execute_from_file, + ".read", + "\\. filename", + "Execute commands from file.", + case_sensitive=True, + aliases=("\\.", "source"), + ) + special.register_special_command( + self.change_prompt_format, + "prompt", + "\\R", + "Change prompt format.", + aliases=("\\R",), + case_sensitive=True, + ) + + def change_table_format(self, arg, **_): + try: + self.formatter.format_name = arg + yield (None, None, None, "Changed table format to {}".format(arg)) + except ValueError: + msg = "Table format {} not recognized. Allowed formats:".format(arg) + for table_type in self.formatter.supported_formats: + msg += "\n\t{}".format(table_type) + yield (None, None, None, msg) + + def change_db(self, arg, **_): + if arg is None: + self.sqlexecute.connect() + else: + self.sqlexecute.connect(database=arg) + + self.refresh_completions() + yield ( + None, + None, + None, + 'You are now connected to database "%s"' % (self.sqlexecute.dbname), + ) + + def execute_from_file(self, arg, **_): + if not arg: + message = "Missing required argument, filename." + return [(None, None, None, message)] + try: + with open(os.path.expanduser(arg), encoding="utf-8") as f: + query = f.read() + except IOError as e: + return [(None, None, None, str(e))] + + if self.destructive_warning and confirm_destructive_query(query) is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] + + return self.sqlexecute.run(query) + + def change_prompt_format(self, arg, **_): + """ + Change the prompt format. + """ + if not arg: + message = "Missing required argument, format." + return [(None, None, None, message)] + + self.prompt_format = self.get_prompt(arg) + return [(None, None, None, "Changed prompt format to %s" % arg)] + + def initialize_logging(self): + + log_file = self.config["main"]["log_file"] + if log_file == "default": + log_file = config_location() + "log" + ensure_dir_exists(log_file) + + log_level = self.config["main"]["log_level"] + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + } + + # Disable logging if value is NONE by switching to a no-op handler + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + log_level = "CRITICAL" + elif dir_path_exists(log_file): + handler = logging.FileHandler(log_file) + else: + self.echo( + 'Error: Unable to open the log file "{}".'.format(log_file), + err=True, + fg="red", + ) + return + + formatter = logging.Formatter( + "%(asctime)s (%(process)d/%(threadName)s) " + "%(name)s %(levelname)s - %(message)s" + ) + + handler.setFormatter(formatter) + + root_logger = logging.getLogger("litecli") + root_logger.addHandler(handler) + root_logger.setLevel(level_map[log_level.upper()]) + + logging.captureWarnings(True) + + root_logger.debug("Initializing litecli logging.") + root_logger.debug("Log file %r.", log_file) + + def read_my_cnf_files(self, keys): + """ + Reads a list of config files and merges them. The last one will win. + :param files: list of files to read + :param keys: list of keys to retrieve + :returns: tuple, with None for missing keys. + """ + cnf = self.config + + sections = ["main"] + + def get(key): + result = None + for sect in cnf: + if sect in sections and key in cnf[sect]: + result = cnf[sect][key] + return result + + return {x: get(x) for x in keys} + + def connect(self, database=""): + + cnf = {"database": None} + + cnf = self.read_my_cnf_files(cnf.keys()) + + # Fall back to config values only if user did not specify a value. + + database = database or cnf["database"] + + # Connect to the database. + + def _connect(): + self.sqlexecute = SQLExecute(database) + + try: + _connect() + except Exception as e: # Connecting to a database could fail. + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + exit(1) + + def handle_editor_command(self, text): + """Editor command is any query that is prefixed or suffixed by a '\e'. + The reason for a while loop is because a user might edit a query + multiple times. For eg: + + "select * from \e"<enter> to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + + """ + + while special.editor_command(text): + filename = special.get_filename(text) + query = special.get_editor_query(text) or self.get_last_query() + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + continue + return text + + def run_cli(self): + iterations = 0 + sqlexecute = self.sqlexecute + logger = self.logger + self.configure_pager() + self.refresh_completions() + + history_file = config_location() + "history" + if dir_path_exists(history_file): + history = FileHistory(history_file) + else: + history = None + self.echo( + 'Error: Unable to open the history file "{}". ' + "Your query history will not be saved.".format(history_file), + err=True, + fg="red", + ) + + key_bindings = cli_bindings(self) + + if not self.less_chatty: + print("Version:", __version__) + print("Mail: https://groups.google.com/forum/#!forum/litecli-users") + print("GitHub: https://github.com/dbcli/litecli") + # print("Home: https://litecli.com") + + def get_message(): + prompt = self.get_prompt(self.prompt_format) + if ( + self.prompt_format == self.default_prompt + and len(prompt) > self.max_len_prompt + ): + prompt = self.get_prompt("\\d> ") + prompt = prompt.replace("\\x1b", "\x1b") + return ANSI(prompt) + + def get_continuation(width, line_number, is_soft_wrap): + continuation = " " * (width - 1) + " " + return [("class:continuation", continuation)] + + def show_suggestion_tip(): + return iterations < 2 + + def one_iteration(text=None): + if text is None: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + return + + if not text.strip(): + return + + if self.destructive_warning: + destroy = confirm_destructive_query(text) + if destroy is None: + pass # Query was not destructive. Nothing to do here. + elif destroy is True: + self.echo("Your call!") + else: + self.echo("Wise choice!") + return + + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = False + + try: + logger.debug("sql: %r", text) + + special.write_tee(self.get_prompt(self.prompt_format) + text) + if self.logfile: + self.logfile.write("\n# %s\n" % datetime.now()) + self.logfile.write(text) + self.logfile.write("\n") + + successful = False + start = time() + res = sqlexecute.run(text) + self.formatter.query = text + successful = True + result_count = 0 + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if is_select(status) and cur and cur.rowcount > threshold: + self.echo( + "The result set has more than {} rows.".format(threshold), + fg="red", + ) + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") + break + + if self.auto_vertical_output: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output( + title, cur, headers, special.is_expanded_output(), max_width + ) + + t = time() - start + try: + if result_count > 0: + self.echo("") + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + self.echo("Time: %0.03fs" % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or is_mutating(status) + special.unset_once_if_written() + except EOFError as e: + raise e + except KeyboardInterrupt: + # get last connection id + connection_id_to_kill = sqlexecute.connection_id + logger.debug("connection id to kill: %r", connection_id_to_kill) + # Restart connection to the database + sqlexecute.connect() + try: + for title, cur, headers, status in sqlexecute.run( + "kill %s" % connection_id_to_kill + ): + status_str = str(status).lower() + if status_str.find("ok") > -1: + logger.debug( + "cancelled query, connection id: %r, sql: %r", + connection_id_to_kill, + text, + ) + self.echo("cancelled query", err=True, fg="red") + except Exception as e: + self.echo( + "Encountered error while cancelling query: {}".format(e), + err=True, + fg="red", + ) + except NotImplementedError: + self.echo("Not Yet Implemented.", fg="yellow") + except OperationalError as e: + logger.debug("Exception: %r", e) + if e.args[0] in (2003, 2006, 2013): + logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + try: + sqlexecute.connect() + logger.debug("Reconnected successfully.") + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. + except OperationalError as e: + logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + # If reconnection failed, don't proceed further. + return + else: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg="red") + else: + # Refresh the table names and column names if necessary. + if need_completion_refresh(text): + self.refresh_completions(reset=need_completion_reset(text)) + finally: + if self.logfile is False: + self.echo("Warning: This query was not logged.", err=True, fg="red") + query = Query(text, successful, mutating) + self.query_history.append(query) + + get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip) + + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + if not self.autocompletion: + complete_style = CompleteStyle.READLINE_LIKE + + with self._completer_lock: + + if self.key_bindings == "vi": + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + self.prompt_app = PromptSession( + lexer=PygmentsLexer(LiteCliLexer), + reserve_space_for_menu=self.get_reserved_space(), + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None, + complete_style=complete_style, + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ) + ], + tempfile_suffix=".sql", + completer=DynamicCompleter(lambda: self.completer), + history=history, + auto_suggest=AutoSuggestFromHistory(), + complete_while_typing=True, + multiline=cli_is_multiline(self), + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True, + ) + + try: + while True: + one_iteration() + iterations += 1 + except EOFError: + special.close_tee() + if not self.less_chatty: + self.echo("Goodbye!") + + def log_output(self, output): + """Log the output in the audit log, if it's enabled.""" + if self.logfile: + click.echo(utf8tounicode(output), file=self.logfile) + + def echo(self, s, **kwargs): + """Print a message to stdout. + + The message will be logged in the audit log, if enabled. + + All keyword arguments are passed to click.echo(). + + """ + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status=None): + """Get the output margin (number of rows for the prompt, footer and + timing message.""" + margin = ( + self.get_reserved_space() + + self.get_prompt(self.prompt_format).count("\n") + + 2 + ) + if status: + margin += 1 + status.count("\n") + + return margin + + def output(self, output, status=None): + """Output text to stdout or a pager command. + + The status text is not outputted to pager or files. + + The message will be logged in the audit log, if enabled. The + message will be written to the tee file, if enabled. The + message will be written to the output file, if enabled. + + """ + if output: + size = self.prompt_app.output.get_size() + + margin = self.get_output_margin(status) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + + if fits or output_via_pager: + # buffering + buf.append(line) + if len(line) > size.columns or i > (size.rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + # doesn't fit, use pager + output_via_pager = True + + if not output_via_pager: + # doesn't fit, flush buffer + for line in buf: + click.secho(line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + # sadly click.echo_via_pager doesn't accept generators + click.echo_via_pager("\n".join(buf)) + else: + for line in buf: + click.secho(line) + + if status: + self.log_output(status) + click.secho(status) + + def configure_pager(self): + # Provide sane defaults for less if they are empty. + if not os.environ.get("LESS"): + os.environ["LESS"] = "-RXF" + + cnf = self.read_my_cnf_files(["pager", "skip-pager"]) + if cnf["pager"]: + special.set_pager(cnf["pager"]) + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): + special.disable_pager() + + def refresh_completions(self, reset=False): + if reset: + with self._completer_lock: + self.completer.reset_completions() + self.completion_refresher.refresh( + self.sqlexecute, + self._on_completions_refreshed, + { + "supported_formats": self.formatter.supported_formats, + "keyword_casing": self.completer.keyword_casing, + }, + ) + + return [ + (None, None, None, "Auto-completion refresh started in the background.") + ] + + def _on_completions_refreshed(self, new_completer): + """Swap the completer object in cli with the newly created completer.""" + with self._completer_lock: + self.completer = new_completer + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None + ) + + def get_prompt(self, string): + self.logger.debug("Getting prompt") + sqlexecute = self.sqlexecute + now = datetime.now() + string = string.replace("\\d", sqlexecute.dbname or "(none)") + string = string.replace("\\f", os.path.basename(sqlexecute.dbname or "(none)")) + string = string.replace("\\n", "\n") + string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) + string = string.replace("\\m", now.strftime("%M")) + string = string.replace("\\P", now.strftime("%p")) + string = string.replace("\\R", now.strftime("%H")) + string = string.replace("\\r", now.strftime("%I")) + string = string.replace("\\s", now.strftime("%S")) + string = string.replace("\\_", " ") + return string + + def run_query(self, query, new_line=True): + """Runs *query*.""" + results = self.sqlexecute.run(query) + for result in results: + title, cur, headers, status = result + self.formatter.query = query + output = self.format_output(title, cur, headers) + for line in output: + click.echo(line, nl=new_line) + + def format_output(self, title, cur, headers, expanded=False, max_width=None): + expanded = expanded or self.formatter.format_name == "vertical" + output = [] + + output_kwargs = { + "dialect": "unix", + "disable_numparse": True, + "preserve_whitespace": True, + "preprocessors": (preprocessors.align_decimals,), + "style": self.output_style, + } + + if title: # Only print the title if it's not None. + output = itertools.chain(output, [title]) + + if cur: + column_types = None + if hasattr(cur, "description"): + + def get_col_type(col): + # col_type = FIELD_TYPES.get(col[1], text_type) + # return col_type if type(col_type) is type else text_type + return text_type + + column_types = [get_col_type(col) for col in cur.description] + + if max_width is not None: + cur = list(cur) + + formatted = self.formatter.format_output( + cur, + headers, + format_name="vertical" if expanded else None, + column_types=column_types, + **output_kwargs + ) + + if isinstance(formatted, (text_type)): + formatted = formatted.splitlines() + formatted = iter(formatted) + + first_line = next(formatted) + formatted = itertools.chain([first_line], formatted) + + if ( + not expanded + and max_width + and headers + and cur + and len(first_line) > max_width + ): + formatted = self.formatter.format_output( + cur, + headers, + format_name="vertical", + column_types=column_types, + **output_kwargs + ) + if isinstance(formatted, (text_type)): + formatted = iter(formatted.splitlines()) + + output = itertools.chain(output, formatted) + + return output + + def get_reserved_space(self): + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = 0.45 + max_reserved_space = 8 + _, height = shutil.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + +@click.command() +@click.option("-V", "--version", is_flag=True, help="Output litecli's version.") +@click.option("-D", "--database", "dbname", help="Database to use.") +@click.option( + "-R", + "--prompt", + "prompt", + help='Prompt format (Default: "{0}").'.format(LiteCli.default_prompt), +) +@click.option( + "-l", + "--logfile", + type=click.File(mode="a", encoding="utf-8"), + help="Log every query and its results to a file.", +) +@click.option( + "--liteclirc", + default=config_location() + "config", + help="Location of liteclirc file.", + type=click.Path(dir_okay=False), +) +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option( + "-t", "--table", is_flag=True, help="Display batch output in table format." +) +@click.option("--csv", is_flag=True, help="Display batch output in CSV format.") +@click.option( + "--warn/--no-warn", default=None, help="Warn before running a destructive query." +) +@click.option("-e", "--execute", type=str, help="Execute command and quit.") +@click.argument("database", default="", nargs=1) +def cli( + database, + dbname, + version, + prompt, + logfile, + auto_vertical_output, + table, + csv, + warn, + execute, + liteclirc, +): + """A SQLite terminal client with auto-completion and syntax highlighting. + + \b + Examples: + - litecli lite_database + + """ + + if version: + print("Version:", __version__) + sys.exit(0) + + litecli = LiteCli( + prompt=prompt, + logfile=logfile, + auto_vertical_output=auto_vertical_output, + warn=warn, + liteclirc=liteclirc, + ) + + # Choose which ever one has a valid value. + database = database or dbname + + litecli.connect(database) + + litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database) + + # --execute argument + if execute: + try: + if csv: + litecli.formatter.format_name = "csv" + elif not table: + litecli.formatter.format_name = "tsv" + + litecli.run_query(execute) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg="red") + exit(1) + + if sys.stdin.isatty(): + litecli.run_cli() + else: + stdin = click.get_text_stream("stdin") + stdin_text = stdin.read() + + try: + sys.stdin = open("/dev/tty") + except (FileNotFoundError, OSError): + litecli.logger.warning("Unable to open TTY as stdin.") + + if ( + litecli.destructive_warning + and confirm_destructive_query(stdin_text) is False + ): + exit(0) + try: + new_line = True + + if csv: + litecli.formatter.format_name = "csv" + elif not table: + litecli.formatter.format_name = "tsv" + + litecli.run_query(stdin_text, new_line=new_line) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg="red") + exit(1) + + +def need_completion_refresh(queries): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop or change db.""" + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ( + "alter", + "create", + "use", + "\\r", + "\\u", + "connect", + "drop", + ): + return True + except Exception: + return False + + +def need_completion_reset(queries): + """Determines if the statement is a database switch such as 'use' or '\\u'. + When a database is changed the existing completions must be reset before we + start the completion refresh for the new database. + """ + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ("use", "\\u"): + return True + except Exception: + return False + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = set( + [ + "insert", + "update", + "delete", + "alter", + "create", + "drop", + "replace", + "truncate", + "load", + ] + ) + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == "select" + + +if __name__ == "__main__": + cli() diff --git a/litecli/packages/__init__.py b/litecli/packages/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/litecli/packages/__init__.py diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py new file mode 100644 index 0000000..0e2a30f --- /dev/null +++ b/litecli/packages/completion_engine.py @@ -0,0 +1,331 @@ +from __future__ import print_function +import sys +import sqlparse +from sqlparse.sql import Comparison, Identifier, Where +from litecli.encodingutils import string_types, text_type +from .parseutils import last_word, extract_tables, find_prev_keyword +from .special import parse_special_command + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + word_before_cursor = last_word(text_before_cursor, include="many_punctuations") + + identifier = None + + # here should be removed once sqlparse has been fixed + try: + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. + if word_before_cursor: + if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"): + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)]) + + # word_before_cursor may include a schema qualification, like + # "schema_name.partial_name" or "schema_name.", so parse it + # separately + p = sqlparse.parse(word_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[0], Identifier): + identifier = p.tokens[0] + else: + parsed = sqlparse.parse(text_before_cursor) + except (TypeError, AttributeError): + return [{"type": "keyword"}] + + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds the + # current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(text_type(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + statement = None + + # Check for special commands and handle those separately + if statement: + # Be careful here because trivial whitespace is parsed as a statement, + # but the statement won't have a first token + tok1 = statement.token_first() + if tok1 and tok1.value.startswith("."): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("\\"): + return suggest_special(text_before_cursor) + elif tok1 and tok1.value.startswith("source"): + return suggest_special(text_before_cursor) + elif text_before_cursor and text_before_cursor.startswith(".open "): + return suggest_special(text_before_cursor) + + last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" + + return suggest_based_on_last_token( + last_token, text_before_cursor, full_text, identifier + ) + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return [{"type": "special"}] + + if cmd in ("\\u", "\\r"): + return [{"type": "database"}] + + if cmd in ("\\T"): + return [{"type": "table_format"}] + + if cmd in ["\\f", "\\fs", "\\fd"]: + return [{"type": "favoritequery"}] + + if cmd in ["\\d", "\\dt", "\\dt+", ".schema", ".indexes"]: + return [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + + if cmd in ["\\.", "source", ".open", ".read"]: + return [{"type": "file_name"}] + + if cmd in [".import"]: + # Usage: .import filename table + if _expecting_arg_idx(arg, text) == 1: + return [{"type": "file_name"}] + else: + return [{"type": "table", "schema": []}] + + return [{"type": "keyword"}, {"type": "special"}] + + +def _expecting_arg_idx(arg, text): + """Return the index of expecting argument. + + >>> _expecting_arg_idx("./da", ".import ./da") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv") + 1 + >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ") + 2 + >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t") + 2 + """ + args = arg.split() + return len(args) + int(text[-1].isspace()) + + +def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): + if isinstance(token, string_types): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier + ) + else: + token_v = token.value.lower() + + is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) + + if not token: + return [{"type": "keyword"}, {"type": "special"}] + elif token_v.endswith("("): + p = sqlparse.parse(text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token( + "where", text_before_cursor, full_text, identifier + ) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + idx, prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return [{"type": "keyword"}] + else: + return column_suggestions + + # Get the token before the parens + idx, prev_tok = p.token_prev(len(p.tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == "using": + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = extract_tables(full_text) + + # suggest columns that are present in more than one table + return [{"type": "column", "tables": tables, "drop_unique": True}] + elif p.token_first().value.lower() == "select": + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(text_before_cursor, "all_punctuations").startswith("("): + return [{"type": "keyword"}] + elif p.token_first().value.lower() == "show": + return [{"type": "show"}] + + # We're probably in a function argument list + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v in ("set", "order by", "distinct"): + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v == "as": + # Don't suggest anything for an alias + return [] + elif token_v in ("show"): + return [{"type": "show"}] + elif token_v in ("to",): + p = sqlparse.parse(text_before_cursor)[0] + if p.token_first().value.lower() == "change": + return [{"type": "change"}] + else: + return [{"type": "user"}] + elif token_v in ("user", "for"): + return [{"type": "user"}] + elif token_v in ("select", "where", "having"): + # Check for a table alias or schema qualification + parent = (identifier and identifier.get_parent_name()) or [] + + tables = extract_tables(full_text) + if parent: + tables = [t for t in tables if identifies(parent, *t)] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] + else: + aliases = [alias or table for (schema, table, alias) in tables] + return [ + {"type": "column", "tables": tables}, + {"type": "function", "schema": []}, + {"type": "alias", "aliases": aliases}, + {"type": "keyword"}, + ] + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v + in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") + ): + schema = (identifier and identifier.get_parent_name()) or [] + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [{"type": "table", "schema": schema}] + + if not schema: + # Suggest schemas + suggest.insert(0, {"type": "schema"}) + + # Only tables can be TRUNCATED, otherwise suggest views + if token_v != "truncate": + suggest.append({"type": "view", "schema": schema}) + + return suggest + + elif token_v in ("table", "view", "function"): + # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>' + rel_type = token_v + schema = (identifier and identifier.get_parent_name()) or [] + if schema: + return [{"type": rel_type, "schema": schema}] + else: + return [{"type": "schema"}, {"type": rel_type, "schema": []}] + elif token_v == "on": + tables = extract_tables(full_text) # [(schema, table, alias), ...] + parent = (identifier and identifier.get_parent_name()) or [] + if parent: + # "ON parent.<suggestion>" + # parent can be either a schema name or table alias + tables = [t for t in tables if identifies(parent, *t)] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] + else: + # ON <suggestion> + # Use table alias if there is one, otherwise the table name + aliases = [alias or table for (schema, table, alias) in tables] + suggest = [{"type": "alias", "aliases": aliases}] + + # The lists of 'aliases' could be empty if we're trying to complete + # a GRANT query. eg: GRANT SELECT, INSERT ON <tab> + # In that case we just suggest all tables. + if not aliases: + suggest.append({"type": "table", "schema": parent}) + return suggest + + elif token_v in ("use", "database", "template", "connect"): + # "\c <db", "use <db>", "DROP DATABASE <db>", + # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" + return [{"type": "database"}] + elif token_v == "tableformat": + return [{"type": "table_format"}] + elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]: + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + if prev_keyword: + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier + ) + else: + return [] + else: + return [{"type": "keyword"}] + + +def identifies(id, schema, table, alias): + return id == alias or id == table or (schema and (id == schema + "." + table)) diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py new file mode 100644 index 0000000..2f01046 --- /dev/null +++ b/litecli/packages/filepaths.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 + +from __future__ import unicode_literals + +from litecli.encodingutils import text_type +import os + + +def list_path(root_dir): + """List directory if exists. + + :param dir: str + :return: list + + """ + res = [] + if os.path.isdir(root_dir): + for name in os.listdir(root_dir): + res.append(name) + return res + + +def complete_path(curr_dir, last_dir): + """Return the path to complete that matches the last entered component. + + If the last entered component is ~, expanded path would not + match, so return all of the available paths. + + :param curr_dir: str + :param last_dir: str + :return: str + + """ + if not last_dir or curr_dir.startswith(last_dir): + return curr_dir + elif last_dir == "~": + return os.path.join(last_dir, curr_dir) + + +def parse_path(root_dir): + """Split path into head and last component for the completer. + + Also return position where last component starts. + + :param root_dir: str path + :return: tuple of (string, string, int) + + """ + base_dir, last_dir, position = "", "", 0 + if root_dir: + base_dir, last_dir = os.path.split(root_dir) + position = -len(last_dir) if last_dir else 0 + return base_dir, last_dir, position + + +def suggest_path(root_dir): + """List all files and subdirectories in a directory. + + If the directory is not specified, suggest root directory, + user directory, current and parent directory. + + :param root_dir: string: directory to list + :return: list + + """ + if not root_dir: + return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir]) + + if "~" in root_dir: + root_dir = text_type(os.path.expanduser(root_dir)) + + if not os.path.exists(root_dir): + root_dir, _ = os.path.split(root_dir) + + return list_path(root_dir) + + +def dir_path_exists(path): + """Check if the directory path exists for a given file. + + For example, for a file /home/user/.cache/litecli/log, check if + /home/user/.cache/litecli exists. + + :param str path: The file path. + :return: Whether or not the directory path exists. + + """ + return os.path.exists(os.path.dirname(path)) diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py new file mode 100644 index 0000000..3f5ca61 --- /dev/null +++ b/litecli/packages/parseutils.py @@ -0,0 +1,227 @@ +from __future__ import print_function +import re +import sqlparse +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +cleanup_regex = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile("([^\s]+)$"), +} + + +def last_word(text, include="alphanum_underscore"): + """ + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +# This code is borrowed from sqlparse example script. +# <url> +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): + return True + return False + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # `return`. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): + return + else: + yield item + elif ( + item.ttype is Keyword or item.ttype is Keyword.DML + ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream): + """yields tuples of (schema_name, table_name, table_alias)""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + except AttributeError: + continue + if real_name: + yield (schema_name, real_name, identifier.get_alias()) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + + if real_name: + yield (schema_name, real_name, item.get_alias()) + else: + name = item.get_name() + yield (None, name, item.get_alias() or name) + elif isinstance(item, Function): + yield (None, item.get_name(), item.get_name()) + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statment. + + Returns a list of (schema, table, alias) tuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + return list(extract_table_identifiers(stream)) + + +def find_prev_keyword(sql): + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ("drop", "shutdown", "delete", "truncate", "alter") + return queries_start_with(queries, keywords) + + +if __name__ == "__main__": + sql = "select * from (select t. from tabl t" + print(extract_tables(sql)) diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py new file mode 100644 index 0000000..d9ad2b6 --- /dev/null +++ b/litecli/packages/prompt_utils.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + + +import sys +import click +from .parseutils import is_destructive + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ( + "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" + ) + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=bool) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py new file mode 100644 index 0000000..fd2b18c --- /dev/null +++ b/litecli/packages/special/__init__.py @@ -0,0 +1,12 @@ +__all__ = [] + + +def export(defn): + """Decorator to explicitly mark functions that are exposed in a lib.""" + globals()[defn.__name__] = defn + __all__.append(defn.__name__) + return defn + + +from . import dbcommands +from . import iocommands diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py new file mode 100644 index 0000000..203e1a8 --- /dev/null +++ b/litecli/packages/special/dbcommands.py @@ -0,0 +1,290 @@ +from __future__ import unicode_literals, print_function +import csv +import logging +import os +import sys +import platform +import shlex +from sqlite3 import ProgrammingError + +from litecli import __version__ +from litecli.packages.special import iocommands +from litecli.packages.special.utils import format_uptime +from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing + +log = logging.getLogger(__name__) + + +@special_command( + ".tables", + "\\dt", + "List tables.", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\dt",), +) +def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): + if arg: + args = ("{0}%".format(arg),) + query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + else: + args = tuple() + query = """ + SELECT name FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + log.debug(query) + cur.execute(query, args) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + # if verbose and arg: + # query = "SELECT sql FROM sqlite_master WHERE name LIKE ?" + # log.debug(query) + # cur.execute(query) + # status = cur.fetchone()[1] + + return [(None, tables, headers, status)] + + +@special_command( + ".schema", + ".schema[+] [table]", + "The complete schema for the database or a single table", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def show_schema(cur, arg=None, **_): + if arg: + args = (arg,) + query = """ + SELECT sql FROM sqlite_master + WHERE name==? + ORDER BY tbl_name, type DESC, name + """ + else: + args = tuple() + query = """ + SELECT sql FROM sqlite_master + ORDER BY tbl_name, type DESC, name + """ + + log.debug(query) + cur.execute(query, args) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + return [(None, tables, headers, status)] + + +@special_command( + ".databases", + ".databases", + "List databases.", + arg_type=RAW_QUERY, + case_sensitive=True, + aliases=("\\l",), +) +def list_databases(cur, **_): + query = "PRAGMA database_list" + log.debug(query) + cur.execute(query) + if cur.description: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, "")] + else: + return [(None, None, None, "")] + + +@special_command( + ".indexes", + ".indexes [tablename]", + "List indexes.", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\di",), +) +def list_indexes(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): + if arg: + args = ("{0}%".format(arg),) + query = """ + SELECT name FROM sqlite_master + WHERE type = 'index' AND tbl_name LIKE ? AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + else: + args = tuple() + query = """ + SELECT name FROM sqlite_master + WHERE type = 'index' AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + log.debug(query) + cur.execute(query, args) + indexes = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + return [(None, indexes, headers, status)] + + +@special_command( + ".status", + "\\s", + "Show current settings.", + arg_type=RAW_QUERY, + aliases=("\\s",), + case_sensitive=True, +) +def status(cur, **_): + # Create output buffers. + footer = [] + footer.append("--------------") + + # Output the litecli client information. + implementation = platform.python_implementation() + version = platform.python_version() + client_info = [] + client_info.append("litecli {0},".format(__version__)) + client_info.append("running on {0} {1}".format(implementation, version)) + footer.append(" ".join(client_info)) + + # Build the output that will be displayed as a table. + query = "SELECT file from pragma_database_list() where name = 'main';" + log.debug(query) + cur.execute(query) + db = cur.fetchone()[0] + if db is None: + db = "" + + footer.append("Current database: " + db) + if iocommands.is_pager_enabled(): + if "PAGER" in os.environ: + pager = os.environ["PAGER"] + else: + pager = "System default" + else: + pager = "stdout" + footer.append("Current pager:" + pager) + + footer.append("--------------") + return [(None, None, "", "\n".join(footer))] + + +@special_command( + ".load", + ".load path", + "Load an extension library.", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def load_extension(cur, arg, **_): + args = shlex.split(arg) + if len(args) != 1: + raise TypeError(".load accepts exactly one path") + path = args[0] + conn = cur.connection + conn.enable_load_extension(True) + conn.load_extension(path) + return [(None, None, None, "")] + + +@special_command( + "describe", + "\\d [table]", + "Description of a table", + arg_type=PARSED_QUERY, + case_sensitive=True, + aliases=("\\d", "describe", "desc"), +) +def describe(cur, arg, **_): + if arg: + args = (arg,) + query = """ + PRAGMA table_info({}) + """.format( + arg + ) + else: + raise ArgumentMissing("Table name required.") + + log.debug(query) + cur.execute(query) + tables = cur.fetchall() + status = "" + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, "")] + + return [(None, tables, headers, status)] + + +@special_command( + ".import", + ".import filename table", + "Import data from filename into an existing table", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def import_file(cur, arg=None, **_): + def split(s): + # this is a modification of shlex.split function, just to make it support '`', + # because table name might contain '`' character. + lex = shlex.shlex(s, posix=True) + lex.whitespace_split = True + lex.commenters = "" + lex.quotes += "`" + return list(lex) + + args = split(arg) + log.debug("[arg = %r], [args = %r]", arg, args) + if len(args) != 2: + raise TypeError("Usage: .import filename table") + + filename, table = args + cur.execute('PRAGMA table_info("%s")' % table) + ncols = len(cur.fetchall()) + insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1)) + + with open(filename, "r") as csvfile: + dialect = csv.Sniffer().sniff(csvfile.read(1024)) + csvfile.seek(0) + reader = csv.reader(csvfile, dialect) + + cur.execute("BEGIN") + ninserted, nignored = 0, 0 + for i, row in enumerate(reader): + if len(row) != ncols: + print( + "%s:%d expected %d columns but found %d - ignored" + % (filename, i, ncols, len(row)), + file=sys.stderr, + ) + nignored += 1 + continue + cur.execute(insert_tmpl, row) + ninserted += 1 + cur.execute("COMMIT") + + status = "Inserted %d rows into %s" % (ninserted, table) + if nignored > 0: + status += " (%d rows are ignored)" % nignored + return [(None, None, None, status)] diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py new file mode 100644 index 0000000..7da6fbf --- /dev/null +++ b/litecli/packages/special/favoritequeries.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + + +class FavoriteQueries(object): + + section_name = "favorite_queries" + + usage = """ +Favorite Queries are a way to save frequently used queries +with a short name. +Examples: + + # Save a new favorite query. + > \\fs simple select * from abc where a is not Null; + + # List all favorite queries. + > \\f + ╒════════╤═══════════════════════════════════════╕ + │ Name │ Query │ + ╞════════╪═══════════════════════════════════════╡ + │ simple │ SELECT * FROM abc where a is not NULL │ + ╘════════╧═══════════════════════════════════════╛ + + # Run a favorite query. + > \\f simple + ╒════════╤════════╕ + │ a │ b │ + ╞════════╪════════╡ + │ 日本語 │ 日本語 │ + ╘════════╧════════╛ + + # Delete a favorite query. + > \\fd simple + simple: Deleted +""" + + def __init__(self, config): + self.config = config + + def list(self): + return self.config.get(self.section_name, []) + + def get(self, name): + return self.config.get(self.section_name, {}).get(name, None) + + def save(self, name, query): + if self.section_name not in self.config: + self.config[self.section_name] = {} + self.config[self.section_name][name] = query + self.config.write() + + def delete(self, name): + try: + del self.config[self.section_name][name] + except KeyError: + return "%s: Not Found." % name + self.config.write() + return "%s: Deleted" % name diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py new file mode 100644 index 0000000..43c3577 --- /dev/null +++ b/litecli/packages/special/iocommands.py @@ -0,0 +1,478 @@ +from __future__ import unicode_literals +import os +import re +import locale +import logging +import subprocess +import shlex +from io import open +from time import sleep + +import click +import sqlparse +from configobj import ConfigObj + +from . import export +from .main import special_command, NO_QUERY, PARSED_QUERY +from .favoritequeries import FavoriteQueries +from .utils import handle_cd_command +from litecli.packages.prompt_utils import confirm_destructive_query + +use_expanded_output = False +PAGER_ENABLED = True +tee_file = None +once_file = written_to_once_file = None +favoritequeries = FavoriteQueries(ConfigObj()) + + +@export +def set_favorite_queries(config): + global favoritequeries + favoritequeries = FavoriteQueries(config) + + +@export +def set_pager_enabled(val): + global PAGER_ENABLED + PAGER_ENABLED = val + + +@export +def is_pager_enabled(): + return PAGER_ENABLED + + +@export +@special_command( + "pager", + "\\P [command]", + "Set PAGER. Print the query results via PAGER.", + arg_type=PARSED_QUERY, + aliases=("\\P",), + case_sensitive=True, +) +def set_pager(arg, **_): + if arg: + os.environ["PAGER"] = arg + msg = "PAGER set to %s." % arg + set_pager_enabled(True) + else: + if "PAGER" in os.environ: + msg = "PAGER set to %s." % os.environ["PAGER"] + else: + # This uses click's default per echo_via_pager. + msg = "Pager enabled." + set_pager_enabled(True) + + return [(None, None, None, msg)] + + +@export +@special_command( + "nopager", + "\\n", + "Disable pager, print to stdout.", + arg_type=NO_QUERY, + aliases=("\\n",), + case_sensitive=True, +) +def disable_pager(): + set_pager_enabled(False) + return [(None, None, None, "Pager disabled.")] + + +@export +def set_expanded_output(val): + global use_expanded_output + use_expanded_output = val + + +@export +def is_expanded_output(): + return use_expanded_output + + +_logger = logging.getLogger(__name__) + + +@export +def editor_command(command): + """ + Is this an external editor command? + :param command: string + """ + # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check + # for both conditions. + return command.strip().endswith("\\e") or command.strip().startswith("\\e") + + +@export +def get_filename(sql): + if sql.strip().startswith("\\e"): + command, _, filename = sql.partition(" ") + return filename.strip() or None + + +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\e') is that it strips characters, + # not a substring. So it'll strip "e" in the end of the sql also! + # Ex: "select * from style\e" -> "select * from styl". + pattern = re.compile("(^\\\e|\\\e$)") + while pattern.search(sql): + sql = pattern.sub("", sql) + + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + + message = None + filename = filename.strip().split(" ", 1)[0] if filename else None + + sql = sql or "" + MARKER = "# Type your query above this line.\n" + + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit( + "{sql}\n\n{marker}".format(sql=sql, marker=MARKER), + filename=filename, + extension=".sql", + ) + + if filename: + try: + with open(filename, encoding="utf-8") as f: + query = f.read() + except IOError: + message = "Error reading file: %s." % filename + + if query is not None: + query = query.split(MARKER, 1)[0].rstrip("\n") + else: + # Don't return None for the caller to deal with. + # Empty string is ok. + query = sql + + return (query, message) + + +@special_command( + "\\f", + "\\f [name [args..]]", + "List or execute favorite queries.", + arg_type=PARSED_QUERY, + case_sensitive=True, +) +def execute_favorite_query(cur, arg, verbose=False, **_): + """Returns (title, rows, headers, status)""" + if arg == "": + for result in list_favorite_queries(): + yield result + + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(" ") + args = shlex.split(arg_str) + + query = favoritequeries.get(name) + if query is None: + message = "No favorite query: %s" % (name) + yield (None, None, None, message) + elif "?" in query: + for sql in sqlparse.split(query): + sql = sql.rstrip(";") + title = "> %s" % (sql) if verbose else None + cur.execute(sql, args) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + else: + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(";") + title = "> %s" % (sql) if verbose else None + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + + +def list_favorite_queries(): + """List of all favorite queries. + Returns (title, rows, headers, status)""" + + headers = ["Name", "Query"] + rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()] + + if not rows: + status = "\nNo favorite queries found." + favoritequeries.usage + else: + status = "" + return [("", rows, headers, status)] + + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + shell_subst_var = "$" + str(idx + 1) + question_subst_var = "?" + if shell_subst_var in query: + query = query.replace(shell_subst_var, val) + elif question_subst_var in query: + query = query.replace(question_subst_var, val, 1) + else: + return [ + None, + "Too many arguments.\nQuery does not have enough place holders to substitute.\n" + + query, + ] + + match = re.search("\\?|\\$\d+", query) + if match: + return [ + None, + "missing substitution for " + match.group(0) + " in query:\n " + query, + ] + + return [query, None] + + +@special_command("\\fs", "\\fs name query", "Save a favorite query.") +def save_favorite_query(arg, **_): + """Save a new favorite query. + Returns (title, rows, headers, status)""" + + usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage + if not arg: + return [(None, None, None, usage)] + + name, _, query = arg.partition(" ") + + # If either name or query is missing then print the usage and complain. + if (not name) or (not query): + return [(None, None, None, usage + "Err: Both name and query are required.")] + + favoritequeries.save(name, query) + return [(None, None, None, "Saved.")] + + +@special_command("\\fd", "\\fd [name]", "Delete a favorite query.") +def delete_favorite_query(arg, **_): + """Delete an existing favorite query.""" + usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage + if not arg: + return [(None, None, None, usage)] + + status = favoritequeries.delete(arg) + + return [(None, None, None, status)] + + +@special_command("system", "system [command]", "Execute a system shell commmand.") +def execute_system_command(arg, **_): + """Execute a system shell command.""" + usage = "Syntax: system [command].\n" + + if not arg: + return [(None, None, None, usage)] + + try: + command = arg.strip() + if command.startswith("cd"): + ok, error_message = handle_cd_command(arg) + if not ok: + return [(None, None, None, error_message)] + return [(None, None, None, "")] + + args = arg.split(" ") + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + response = output if not error else error + + # Python 3 returns bytes. This needs to be decoded to a string. + if isinstance(response, bytes): + encoding = locale.getpreferredencoding(False) + response = response.decode(encoding) + + return [(None, None, None, response)] + except OSError as e: + return [(None, None, None, "OSError: %s" % e.strerror)] + + +def parseargfile(arg): + if arg.startswith("-o "): + mode = "w" + filename = arg[3:] + else: + mode = "a" + filename = arg + + if not filename: + raise TypeError("You must provide a filename.") + + return {"file": os.path.expanduser(filename), "mode": mode} + + +@special_command( + "tee", + "tee [-o] filename", + "Append all results to an output file (overwrite using -o).", +) +def set_tee(arg, **_): + global tee_file + + try: + tee_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + + return [(None, None, None, "")] + + +@export +def close_tee(): + global tee_file + if tee_file: + tee_file.close() + tee_file = None + + +@special_command("notee", "notee", "Stop writing results to an output file.") +def no_tee(arg, **_): + close_tee() + return [(None, None, None, "")] + + +@export +def write_tee(output): + global tee_file + if tee_file: + click.echo(output, file=tee_file, nl=False) + click.echo("\n", file=tee_file, nl=False) + tee_file.flush() + + +@special_command( + ".once", + "\\o [-o] filename", + "Append next result to an output file (overwrite using -o).", + aliases=("\\o", "\\once"), +) +def set_once(arg, **_): + global once_file + + once_file = parseargfile(arg) + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + try: + f = open(**once_file) + except (IOError, OSError) as e: + once_file = None + raise OSError( + "Cannot write to file '{}': {}".format(e.filename, e.strerror) + ) + + with f: + click.echo(output, file=f, nl=False) + click.echo("\n", file=f, nl=False) + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file + if written_to_once_file: + once_file = None + + +@special_command( + "watch", + "watch [seconds] [-c] query", + "Executes the query every [seconds] seconds (by default 5).", +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + raise StopIteration + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + raise StopIteration + (current_arg, _, arg) = arg.partition(" ") + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == "-c": + clear_screen = True + continue + statement = "{0!s} {1!s}".format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + raise StopIteration + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs["cur"] + sql_list = [ + (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + raise StopIteration + finally: + set_pager_enabled(old_pager_enabled) diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py new file mode 100644 index 0000000..3dd0e77 --- /dev/null +++ b/litecli/packages/special/main.py @@ -0,0 +1,160 @@ +from __future__ import unicode_literals +import logging +from collections import namedtuple + +from . import export + +log = logging.getLogger(__name__) + +NO_QUERY = 0 +PARSED_QUERY = 1 +RAW_QUERY = 2 + +SpecialCommand = namedtuple( + "SpecialCommand", + [ + "handler", + "command", + "shortcut", + "description", + "arg_type", + "hidden", + "case_sensitive", + ], +) + +COMMANDS = {} + + +@export +class ArgumentMissing(Exception): + pass + + +@export +class CommandNotFound(Exception): + pass + + +@export +def parse_special_command(sql): + command, _, arg = sql.partition(" ") + verbose = "+" in command + command = command.strip().replace("+", "") + return (command, verbose, arg.strip()) + + +@export +def special_command( + command, + shortcut, + description, + arg_type=PARSED_QUERY, + hidden=False, + case_sensitive=False, + aliases=(), +): + def wrapper(wrapped): + register_special_command( + wrapped, + command, + shortcut, + description, + arg_type, + hidden, + case_sensitive, + aliases, + ) + return wrapped + + return wrapper + + +@export +def register_special_command( + handler, + command, + shortcut, + description, + arg_type=PARSED_QUERY, + hidden=False, + case_sensitive=False, + aliases=(), +): + cmd = command.lower() if not case_sensitive else command + COMMANDS[cmd] = SpecialCommand( + handler, command, shortcut, description, arg_type, hidden, case_sensitive + ) + for alias in aliases: + cmd = alias.lower() if not case_sensitive else alias + COMMANDS[cmd] = SpecialCommand( + handler, + command, + shortcut, + description, + arg_type, + case_sensitive=case_sensitive, + hidden=True, + ) + + +@export +def execute(cur, sql): + """Execute a special command and return the results. If the special command + is not supported a KeyError will be raised. + """ + command, verbose, arg = parse_special_command(sql) + + if (command not in COMMANDS) and (command.lower() not in COMMANDS): + raise CommandNotFound + + try: + special_cmd = COMMANDS[command] + except KeyError: + special_cmd = COMMANDS[command.lower()] + if special_cmd.case_sensitive: + raise CommandNotFound("Command not found: %s" % command) + + if special_cmd.arg_type == NO_QUERY: + return special_cmd.handler() + elif special_cmd.arg_type == PARSED_QUERY: + return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + elif special_cmd.arg_type == RAW_QUERY: + return special_cmd.handler(cur=cur, query=sql) + + +@special_command( + "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?") +) +def show_help(): # All the parameters are ignored. + headers = ["Command", "Shortcut", "Description"] + result = [] + + for _, value in sorted(COMMANDS.items()): + if not value.hidden: + result.append((value.command, value.shortcut, value.description)) + return [(None, result, headers, None)] + + +@special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit")) +@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY) +def quit(*_args): + raise EOFError + + +@special_command( + "\\e", + "\\e", + "Edit command with editor (uses $EDITOR).", + arg_type=NO_QUERY, + case_sensitive=True, +) +@special_command( + "\\G", + "\\G", + "Display current query results vertically.", + arg_type=NO_QUERY, + case_sensitive=True, +) +def stub(): + raise NotImplementedError diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py new file mode 100644 index 0000000..eed9306 --- /dev/null +++ b/litecli/packages/special/utils.py @@ -0,0 +1,48 @@ +import os +import subprocess + + +def handle_cd_command(arg): + """Handles a `cd` shell command by calling python's os.chdir.""" + CD_CMD = "cd" + tokens = arg.split(CD_CMD + " ") + directory = tokens[-1] if len(tokens) > 1 else None + if not directory: + return False, "No folder name was provided." + try: + os.chdir(directory) + subprocess.call(["pwd"]) + return True, None + except OSError as e: + return False, e.strerror + + +def format_uptime(uptime_in_seconds): + """Format number of seconds into human-readable string. + + :param uptime_in_seconds: The server uptime in seconds. + :returns: A human-readable string representing the uptime. + + >>> uptime = format_uptime('56892') + >>> print(uptime) + 15 hours 48 min 12 sec + """ + + m, s = divmod(int(uptime_in_seconds), 60) + h, m = divmod(m, 60) + d, h = divmod(h, 24) + + uptime_values = [] + + for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")): + if value == 0 and not uptime_values: + # Don't include a value/unit if the unit isn't applicable to + # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. + continue + elif value == 1 and unit.endswith("s"): + # Remove the "s" if the unit is singular. + unit = unit[:-1] + uptime_values.append("{0} {1}".format(value, unit)) + + uptime = " ".join(uptime_values) + return uptime diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py new file mode 100644 index 0000000..64ca352 --- /dev/null +++ b/litecli/sqlcompleter.py @@ -0,0 +1,612 @@ +from __future__ import print_function +from __future__ import unicode_literals +import logging +from re import compile, escape +from collections import Counter + +from prompt_toolkit.completion import Completer, Completion + +from .packages.completion_engine import suggest_type +from .packages.parseutils import last_word +from .packages.special.iocommands import favoritequeries +from .packages.filepaths import parse_path, complete_path, suggest_path + +_logger = logging.getLogger(__name__) + + +class SQLCompleter(Completer): + keywords = [ + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BIGINT", + "BLOB", + "BOOLEAN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHARACTER", + "CHECK", + "CLOB", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DATE", + "DATETIME", + "DECIMAL", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DETACH", + "DISTINCT", + "DO", + "DOUBLE PRECISION", + "DOUBLE", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FLOAT", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GLOB", + "GROUP", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INT", + "INT2", + "INT8", + "INTEGER", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MEDIUMINT", + "NATIVE CHARACTER", + "NATURAL", + "NCHAR", + "NO", + "NOT", + "NOTHING", + "NULL", + "NULLS FIRST", + "NULLS LAST", + "NUMERIC", + "NVARCHAR", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER BY", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "REAL", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "SMALLINT", + "TABLE", + "TEMP", + "TEMPORARY", + "TEXT", + "THEN", + "TINYINT", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UNSIGNED BIG INT", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VARCHAR", + "VARYING CHARACTER", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT", + ] + + functions = [ + "ABS", + "AVG", + "CHANGES", + "CHAR", + "COALESCE", + "COUNT", + "CUME_DIST", + "DATE", + "DATETIME", + "DENSE_RANK", + "GLOB", + "GROUP_CONCAT", + "HEX", + "IFNULL", + "INSTR", + "JSON", + "JSON_ARRAY", + "JSON_ARRAY_LENGTH", + "JSON_EACH", + "JSON_EXTRACT", + "JSON_GROUP_ARRAY", + "JSON_GROUP_OBJECT", + "JSON_INSERT", + "JSON_OBJECT", + "JSON_PATCH", + "JSON_QUOTE", + "JSON_REMOVE", + "JSON_REPLACE", + "JSON_SET", + "JSON_TREE", + "JSON_TYPE", + "JSON_VALID", + "JULIANDAY", + "LAG", + "LAST_INSERT_ROWID", + "LENGTH", + "LIKELIHOOD", + "LIKELY", + "LOAD_EXTENSION", + "LOWER", + "LTRIM", + "MAX", + "MIN", + "NTILE", + "NULLIF", + "PERCENT_RANK", + "PRINTF", + "QUOTE", + "RANDOM", + "RANDOMBLOB", + "RANK", + "REPLACE", + "ROUND", + "ROW_NUMBER", + "RTRIM", + "SOUNDEX", + "SQLITE_COMPILEOPTION_GET", + "SQLITE_COMPILEOPTION_USED", + "SQLITE_OFFSET", + "SQLITE_SOURCE_ID", + "SQLITE_VERSION", + "STRFTIME", + "SUBSTR", + "SUM", + "TIME", + "TOTAL", + "TOTAL_CHANGES", + "TRIM", + ] + + def __init__(self, supported_formats=(), keyword_casing="auto"): + super(self.__class__, self).__init__() + self.reserved_words = set() + for x in self.keywords: + self.reserved_words.update(x.split()) + self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + + self.special_commands = [] + self.table_formats = supported_formats + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "auto" + self.keyword_casing = keyword_casing + self.reset_completions() + + def escape_name(self, name): + if name and ( + (not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions) + ): + name = "`%s`" % name + + return name + + def unescape_name(self, name): + """Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_special_commands(self, special_commands): + # Special commands are not part of all_completions since they can only + # be at the beginning of a line. + self.special_commands.extend(special_commands) + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_schemata(self, schema): + if schema is None: + return + metadata = self.dbmetadata["tables"] + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + metadata[schema] = {} + self.all_completions.update(schema) + + def extend_relations(self, data, kind): + """Extend metadata for tables or views + + :param data: list of (rel_name, ) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'data' is a generator object. It can throw an exception while being + # consumed. This could happen if the user has launched the app without + # specifying a database name. This exception must be handled to prevent + # crashing. + try: + data = [self.escaped_names(d) for d in data] + except Exception: + data = [] + + # dbmetadata['tables'][$schema_name][$table_name] should be a list of + # column names. Default to an asterisk + metadata = self.dbmetadata[kind] + for relname in data: + try: + metadata[self.dbname][relname[0]] = ["*"] + except KeyError: + _logger.error( + "%r %r listed in unrecognized schema %r", + kind, + relname[0], + self.dbname, + ) + self.all_completions.add(relname[0]) + + def extend_columns(self, column_data, kind): + """Extend column metadata + + :param column_data: list of (rel_name, column_name) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'column_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + column_data = [self.escaped_names(d) for d in column_data] + except Exception: + column_data = [] + + metadata = self.dbmetadata[kind] + for relname, column in column_data: + metadata[self.dbname][relname].append(column) + self.all_completions.add(column) + + def extend_functions(self, func_data): + # 'func_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + func_data = [self.escaped_names(d) for d in func_data] + except Exception: + func_data = [] + + # dbmetadata['functions'][$schema_name][$function_name] should return + # function metadata. + metadata = self.dbmetadata["functions"] + + for func in func_data: + metadata[self.dbname][func[0]] = None + self.all_completions.add(func[0]) + + def set_dbname(self, dbname): + self.dbname = dbname + + def reset_completions(self): + self.databases = [] + self.dbname = "" + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}} + self.all_completions = set(self.keywords + self.functions) + + @staticmethod + def find_matches( + text, + collection, + start_only=False, + fuzzy=True, + casing=None, + punctuations="most_punctuations", + ): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + If `start_only` is True, the text will match an available + completion only at the beginning. Otherwise, a completion is + considered a match if the text appears anywhere within it. + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + """ + last = last_word(text, include=punctuations) + text = last.lower() + + completions = [] + + if fuzzy: + regex = ".*?".join(map(escape, text)) + pat = compile("(%s)" % regex) + for item in sorted(collection): + r = pat.search(item.lower()) + if r: + completions.append((len(r.group()), r.start(), item)) + else: + match_end_limit = len(text) if start_only else None + for item in sorted(collection): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + completions.append((len(text), match_point, item)) + + if casing == "auto": + casing = "lower" if last and last[-1].islower() else "upper" + + def apply_case(kw): + if casing == "upper": + return kw.upper() + return kw.lower() + + return ( + Completion(z if casing is None else apply_case(z), -len(text)) + for x, y, z in sorted(completions) + ) + + def get_completions(self, document, complete_event): + word_before_cursor = document.get_word_before_cursor(WORD=True) + completions = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + + _logger.debug("Suggestion type: %r", suggestion["type"]) + + if suggestion["type"] == "column": + tables = suggestion["tables"] + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables) + if suggestion.get("drop_unique"): + # drop_unique is used for 'tb11 JOIN tbl2 USING (...' + # which should suggest only columns that appear in more than + # one table + scoped_cols = [ + col + for (col, count) in Counter(scoped_cols).items() + if count > 1 and col != "*" + ] + + cols = self.find_matches(word_before_cursor, scoped_cols) + completions.extend(cols) + + elif suggestion["type"] == "function": + # suggest user-defined functions using substring matching + funcs = self.populate_schema_objects(suggestion["schema"], "functions") + user_funcs = self.find_matches(word_before_cursor, funcs) + completions.extend(user_funcs) + + # suggest hardcoded functions using startswith matching only if + # there is no schema qualifier. If a schema qualifier is + # present it probably denotes a table. + # eg: SELECT * FROM users u WHERE u. + if not suggestion["schema"]: + predefined_funcs = self.find_matches( + word_before_cursor, + self.functions, + start_only=True, + fuzzy=False, + casing=self.keyword_casing, + ) + completions.extend(predefined_funcs) + + elif suggestion["type"] == "table": + tables = self.populate_schema_objects(suggestion["schema"], "tables") + tables = self.find_matches(word_before_cursor, tables) + completions.extend(tables) + + elif suggestion["type"] == "view": + views = self.populate_schema_objects(suggestion["schema"], "views") + views = self.find_matches(word_before_cursor, views) + completions.extend(views) + + elif suggestion["type"] == "alias": + aliases = suggestion["aliases"] + aliases = self.find_matches(word_before_cursor, aliases) + completions.extend(aliases) + + elif suggestion["type"] == "database": + dbs = self.find_matches(word_before_cursor, self.databases) + completions.extend(dbs) + + elif suggestion["type"] == "keyword": + keywords = self.find_matches( + word_before_cursor, + self.keywords, + start_only=True, + fuzzy=False, + casing=self.keyword_casing, + punctuations="many_punctuations", + ) + completions.extend(keywords) + + elif suggestion["type"] == "special": + special = self.find_matches( + word_before_cursor, + self.special_commands, + start_only=True, + fuzzy=False, + punctuations="many_punctuations", + ) + completions.extend(special) + elif suggestion["type"] == "favoritequery": + queries = self.find_matches( + word_before_cursor, + favoritequeries.list(), + start_only=False, + fuzzy=True, + ) + completions.extend(queries) + elif suggestion["type"] == "table_format": + formats = self.find_matches( + word_before_cursor, self.table_formats, start_only=True, fuzzy=False + ) + completions.extend(formats) + elif suggestion["type"] == "file_name": + file_names = self.find_files(word_before_cursor) + completions.extend(file_names) + + return completions + + def find_files(self, word): + """Yield matching directory or file names. + + :param word: + :return: iterable + + """ + base_path, last_path, position = parse_path(word) + paths = suggest_path(word) + for name in sorted(paths): + suggestion = complete_path(name, last_path) + if suggestion: + yield Completion(suggestion, position) + + def populate_scoped_cols(self, scoped_tbls): + """Find all columns in a set of scoped_tables + :param scoped_tbls: list of (schema, table, alias) tuples + :return: list of column names + """ + columns = [] + meta = self.dbmetadata + + for tbl in scoped_tbls: + # A fully qualified schema.relname reference or default_schema + # DO NOT escape schema names. + schema = tbl[0] or self.dbname + relname = tbl[1] + escaped_relname = self.escape_name(tbl[1]) + + # We don't know if schema.relname is a table or view. Since + # tables and views cannot share the same name, we can check one + # at a time + try: + columns.extend(meta["tables"][schema][relname]) + + # Table exists, so don't bother checking for a view + continue + except KeyError: + try: + columns.extend(meta["tables"][schema][escaped_relname]) + # Table exists, so don't bother checking for a view + continue + except KeyError: + pass + + try: + columns.extend(meta["views"][schema][relname]) + except KeyError: + pass + + return columns + + def populate_schema_objects(self, schema, obj_type): + """Returns list of tables or functions for a (optional) schema""" + metadata = self.dbmetadata[obj_type] + schema = schema or self.dbname + + try: + objects = metadata[schema].keys() + except KeyError: + # schema doesn't exist + objects = [] + + return objects diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py new file mode 100644 index 0000000..3f78d49 --- /dev/null +++ b/litecli/sqlexecute.py @@ -0,0 +1,220 @@ +import logging +import sqlite3 +import uuid +from contextlib import closing +from sqlite3 import OperationalError + +import sqlparse +import os.path + +from .packages import special + +_logger = logging.getLogger(__name__) + +# FIELD_TYPES = decoders.copy() +# FIELD_TYPES.update({ +# FIELD_TYPE.NULL: type(None) +# }) + + +class SQLExecute(object): + + databases_query = """ + PRAGMA database_list + """ + + tables_query = """ + SELECT name + FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + table_columns_query = """ + SELECT m.name as tableName, p.name as columnName + FROM sqlite_master m + LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name + WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' + ORDER BY tableName, columnName + """ + + indexes_query = """ + SELECT name + FROM sqlite_master + WHERE type = 'index' AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + + def __init__(self, database): + self.dbname = database + self._server_type = None + self.connection_id = None + self.conn = None + if not database: + _logger.debug("Database is not specified. Skip connection.") + return + self.connect() + + def connect(self, database=None): + db = database or self.dbname + _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database) + + db_name = os.path.expanduser(db) + db_dir_name = os.path.dirname(os.path.abspath(db_name)) + if not os.path.exists(db_dir_name): + raise Exception("Path does not exist: {}".format(db_dir_name)) + + conn = sqlite3.connect(database=db_name, isolation_level=None) + conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace") + if self.conn: + self.conn.close() + + self.conn = conn + # Update them after the connection is made to ensure that it was a + # successful connection. + self.dbname = db + # retrieve connection id + self.reset_connection_id() + + def run(self, statement): + """Execute the sql in the database and return the results. The results + are a list of tuples. Each tuple has 4 values + (title, rows, headers, status). + """ + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None) + + # Split the sql into separate queries and run each one. + # Unless it's saving a favorite query, in which case we + # want to save them all together. + if statement.startswith("\\fs"): + components = [statement] + else: + components = sqlparse.split(statement) + + for sql in components: + # Remove spaces, eol and semi-colons. + sql = sql.rstrip(";") + + # \G is treated specially since we have to set the expanded output. + if sql.endswith("\\G"): + special.set_expanded_output(True) + sql = sql[:-2].strip() + + if not self.conn and not ( + sql.startswith(".open") + or sql.lower().startswith("use") + or sql.startswith("\\u") + or sql.startswith("\\?") + or sql.startswith("\\q") + or sql.startswith("help") + or sql.startswith("exit") + or sql.startswith("quit") + ): + _logger.debug( + "Not connected to database. Will not run statement: %s.", sql + ) + raise OperationalError("Not connected to database.") + # yield ('Not connected to database', None, None, None) + # return + + cur = self.conn.cursor() if self.conn else None + try: # Special command + _logger.debug("Trying a dbspecial command. sql: %r", sql) + for result in special.execute(cur, sql): + yield result + except special.CommandNotFound: # Regular SQL + _logger.debug("Regular sql statement. sql: %r", sql) + cur.execute(sql) + yield self.get_result(cur) + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" + title = headers = None + + # cursor.description is not None for queries that return result sets, + # e.g. SELECT. + if cursor.description is not None: + headers = [x[0] for x in cursor.description] + status = "{0} row{1} in set" + cursor = list(cursor) + rowcount = len(cursor) + else: + _logger.debug("No rows in result.") + status = "Query OK, {0} row{1} affected" + rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount + cursor = None + + status = status.format(rowcount, "" if rowcount == 1 else "s") + + return (title, cursor, headers, status) + + def tables(self): + """Yields table names""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Tables Query. sql: %r", self.tables_query) + cur.execute(self.tables_query) + for row in cur: + yield row + + def table_columns(self): + """Yields column names""" + with closing(self.conn.cursor()) as cur: + _logger.debug("Columns Query. sql: %r", self.table_columns_query) + cur.execute(self.table_columns_query) + for row in cur: + yield row + + def databases(self): + if not self.conn: + return + + with closing(self.conn.cursor()) as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + for row in cur.execute(self.databases_query): + yield row[1] + + def functions(self): + """Yields tuples of (schema_name, function_name)""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Functions Query. sql: %r", self.functions_query) + cur.execute(self.functions_query % self.dbname) + for row in cur: + yield row + + def show_candidates(self): + with closing(self.conn.cursor()) as cur: + _logger.debug("Show Query. sql: %r", self.show_candidates_query) + try: + cur.execute(self.show_candidates_query) + except sqlite3.DatabaseError as e: + _logger.error("No show completions due to %r", e) + yield "" + else: + for row in cur: + yield (row[0].split(None, 1)[-1],) + + def server_type(self): + self._server_type = ("sqlite3", "3") + return self._server_type + + def get_connection_id(self): + if not self.connection_id: + self.reset_connection_id() + return self.connection_id + + def reset_connection_id(self): + # Remember current connection id + _logger.debug("Get current connection id") + # res = self.run('select connection_id()') + self.connection_id = uuid.uuid4() + # for title, cur, headers, status in res: + # self.connection_id = cur.fetchone()[0] + _logger.debug("Current connection id: %s", self.connection_id) diff --git a/release.py b/release.py new file mode 100644 index 0000000..f6beb88 --- /dev/null +++ b/release.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +"""A script to publish a release of litecli to PyPI.""" + +from __future__ import print_function +import io +from optparse import OptionParser +import re +import subprocess +import sys + +import click + +DEBUG = False +CONFIRM_STEPS = False +DRY_RUN = False + + +def skip_step(): + """ + Asks for user's response whether to run a step. Default is yes. + :return: boolean + """ + global CONFIRM_STEPS + + if CONFIRM_STEPS: + return not click.confirm("--- Run this step?", default=True) + return False + + +def run_step(*args): + """ + Prints out the command and asks if it should be run. + If yes (default), runs it. + :param args: list of strings (command and args) + """ + global DRY_RUN + + cmd = args + print(" ".join(cmd)) + if skip_step(): + print("--- Skipping...") + elif DRY_RUN: + print("--- Pretending to run...") + else: + subprocess.check_output(cmd) + + +def version(version_file): + _version_re = re.compile( + r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)' + ) + + with io.open(version_file, encoding="utf-8") as f: + ver = _version_re.search(f.read()).group("version") + + return ver + + +def commit_for_release(version_file, ver): + run_step("git", "reset") + run_step("git", "add", version_file) + run_step("git", "commit", "--message", "Releasing version {}".format(ver)) + + +def create_git_tag(tag_name): + run_step("git", "tag", tag_name) + + +def create_distribution_files(): + run_step("python", "setup.py", "sdist", "bdist_wheel") + + +def upload_distribution_files(): + run_step("twine", "upload", "dist/*") + + +def push_to_github(): + run_step("git", "push", "origin", "main") + + +def push_tags_to_github(): + run_step("git", "push", "--tags", "origin") + + +def checklist(questions): + for question in questions: + if not click.confirm("--- {}".format(question), default=False): + sys.exit(1) + + +if __name__ == "__main__": + if DEBUG: + subprocess.check_output = lambda x: x + + ver = version("litecli/__init__.py") + print("Releasing Version:", ver) + + parser = OptionParser() + parser.add_option( + "-c", + "--confirm-steps", + action="store_true", + dest="confirm_steps", + default=False, + help=( + "Confirm every step. If the step is not " "confirmed, it will be skipped." + ), + ) + parser.add_option( + "-d", + "--dry-run", + action="store_true", + dest="dry_run", + default=False, + help="Print out, but not actually run any steps.", + ) + + popts, pargs = parser.parse_args() + CONFIRM_STEPS = popts.confirm_steps + DRY_RUN = popts.dry_run + + if not click.confirm("Are you sure?", default=False): + sys.exit(1) + + commit_for_release("litecli/__init__.py", ver) + create_git_tag("v{}".format(ver)) + create_distribution_files() + push_to_github() + push_tags_to_github() + upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..c517d59 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +mock +pytest>=3.6 +pytest-cov +tox +behave +pexpect +coverage +codecov +click +black
\ No newline at end of file diff --git a/screenshots/litecli.gif b/screenshots/litecli.gif Binary files differnew file mode 100644 index 0000000..9cfd80c --- /dev/null +++ b/screenshots/litecli.gif diff --git a/screenshots/litecli.png b/screenshots/litecli.png Binary files differnew file mode 100644 index 0000000..6ca999e --- /dev/null +++ b/screenshots/litecli.png diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..40eab0a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,18 @@ +[bdist_wheel] +universal = 1 + +[tool:pytest] +addopts = --capture=sys + --showlocals + --doctest-modules + --doctest-ignore-import-errors + --ignore=setup.py + --ignore=litecli/magic.py + --ignore=litecli/packages/parseutils.py + --ignore=test/features + +[pep8] +rev = master +docformatter = True +diff = True +error-status = True diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..0ff4eeb --- /dev/null +++ b/setup.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import ast +from io import open +import re +from setuptools import setup, find_packages + +_version_re = re.compile(r"__version__\s+=\s+(.*)") + +with open("litecli/__init__.py", "rb") as f: + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) + + +def open_file(filename): + """Open and read the file *filename*.""" + with open(filename) as f: + return f.read() + + +readme = open_file("README.md") + +install_requirements = [ + "click >= 4.1", + "Pygments>=1.6", + "prompt_toolkit>=3.0.3,<4.0.0", + "sqlparse", + "configobj >= 5.0.5", + "cli_helpers[styles] >= 2.2.1", +] + + +setup( + name="litecli", + author="dbcli", + author_email="litecli-users@googlegroups.com", + license="BSD", + version=version, + url="https://github.com/dbcli/litecli", + packages=find_packages(), + package_data={"litecli": ["liteclirc", "AUTHORS"]}, + description="CLI for SQLite Databases with auto-completion and syntax " + "highlighting.", + long_description=readme, + long_description_content_type="text/markdown", + install_requires=install_requirements, + # cmdclass={"test": test, "lint": lint}, + entry_points={ + "console_scripts": ["litecli = litecli.main:cli"], + "distutils.commands": ["lint = tasks:lint", "test = tasks:test"], + }, + classifiers=[ + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/tasks.py b/tasks.py new file mode 100644 index 0000000..1cd4b69 --- /dev/null +++ b/tasks.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +"""Common development tasks for setup.py to use.""" + +import re +import subprocess +import sys + +from setuptools import Command +from setuptools.command.test import test as TestCommand + + +class BaseCommand(Command, object): + """The base command for project tasks.""" + + user_options = [] + + default_cmd_options = ("verbose", "quiet", "dry_run") + + def __init__(self, *args, **kwargs): + super(BaseCommand, self).__init__(*args, **kwargs) + self.verbose = False + + def initialize_options(self): + """Override the distutils abstract method.""" + pass + + def finalize_options(self): + """Override the distutils abstract method.""" + # Distutils uses incrementing integers for verbosity. + self.verbose = bool(self.verbose) + + def call_and_exit(self, cmd, shell=True): + """Run the *cmd* and exit with the proper exit code.""" + sys.exit(subprocess.call(cmd, shell=shell)) + + def call_in_sequence(self, cmds, shell=True): + """Run multiple commmands in a row, exiting if one fails.""" + for cmd in cmds: + if subprocess.call(cmd, shell=shell) == 1: + sys.exit(1) + + def apply_options(self, cmd, options=()): + """Apply command-line options.""" + for option in self.default_cmd_options + options: + cmd = self.apply_option(cmd, option, active=getattr(self, option, False)) + return cmd + + def apply_option(self, cmd, option, active=True): + """Apply a command-line option.""" + return re.sub( + r"{{{}\:(?P<option>[^}}]*)}}".format(option), + r"\g<option>" if active else "", + cmd, + ) + + +class lint(BaseCommand): + description = "check code using black (and fix violations)" + + user_options = [("fix", "f", "fix the violations in place")] + + def initialize_options(self): + """Set the default options.""" + self.fix = False + + def finalize_options(self): + pass + + def run(self): + cmd = "black" + if not self.fix: + cmd += " --check" + cmd += " ." + sys.exit(subprocess.call(cmd, shell=True)) + + +class test(TestCommand): + + user_options = [("pytest-args=", "a", "Arguments to pass to pytest")] + + def initialize_options(self): + TestCommand.initialize_options(self) + self.pytest_args = "" + + def run_tests(self): + unit_test_errno = subprocess.call( + "pytest tests " + self.pytest_args, shell=True + ) + # cli_errno = subprocess.call('behave test/features', shell=True) + # sys.exit(unit_test_errno or cli_errno) + sys.exit(unit_test_errno) + + +# class test(BaseCommand): +# """Run the test suites for this project.""" + +# description = "run the test suite" + +# user_options = [ +# ("all", "a", "test against all supported versions of Python"), +# ("coverage", "c", "measure test coverage"), +# ] + +# unit_test_cmd = ( +# "py.test{quiet: -q}{verbose: -v}{dry_run: --setup-only}" +# "{coverage: --cov-report= --cov=litecli}" +# ) +# # cli_test_cmd = 'behave{quiet: -q}{verbose: -v}{dry_run: -d} test/features' +# test_all_cmd = "tox{verbose: -v}{dry_run: --notest}" +# coverage_cmd = "coverage combine && coverage report" + +# def initialize_options(self): +# """Set the default options.""" +# self.all = False +# self.coverage = False +# super(test, self).initialize_options() + +# def run(self): +# """Run the test suites.""" +# if self.all: +# cmd = self.apply_options(self.test_all_cmd) +# self.call_and_exit(cmd) +# else: +# cmds = ( +# self.apply_options(self.unit_test_cmd, ("coverage",)), +# # self.apply_options(self.cli_test_cmd) +# ) +# if self.coverage: +# cmds += (self.apply_options(self.coverage_cmd),) +# self.call_in_sequence(cmds) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..dce0d7e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +from __future__ import print_function + +import os +import pytest +from utils import create_db, db_connection, drop_tables +import litecli.sqlexecute + + +@pytest.yield_fixture(scope="function") +def connection(): + create_db("_test_db") + connection = db_connection("_test_db") + yield connection + + drop_tables(connection) + connection.close() + os.remove("_test_db") + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return litecli.sqlexecute.SQLExecute(database="_test_db") + + +@pytest.fixture +def exception_formatter(): + return lambda e: str(e) + + +@pytest.fixture(scope="session", autouse=True) +def temp_config(tmpdir_factory): + # this function runs on start of test session. + # use temporary directory for config home so user config will not be used + os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data")) diff --git a/tests/data/import_data.csv b/tests/data/import_data.csv new file mode 100644 index 0000000..d68d655 --- /dev/null +++ b/tests/data/import_data.csv @@ -0,0 +1,2 @@ +t1,11 +t2,22 diff --git a/tests/liteclirc b/tests/liteclirc new file mode 100644 index 0000000..da9b061 --- /dev/null +++ b/tests/liteclirc @@ -0,0 +1,137 @@ +[main] + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# Destructive warning mode will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database" +# or "shutdown". +destructive_warning = True + +# log_file location. +# In Unix/Linux: ~/.config/litecli/log +# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log +# %USERPROFILE% is typically C:\Users\{username} +log_file = default + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Log every query and its results to a file. Enable this by uncommenting the +# line below. +# audit_log = ~/.litecli-audit.log + +# Default pager. +# By default '$PAGER' environment variable is used +# pager = less -SRXF + +# Table format. Possible values: +# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl, +# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira, +# vertical, tsv, csv. +# Recommended: ascii +table_format = ascii + +# Syntax coloring style. Possible values (many support the "-dark" suffix): +# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, +# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, +# fruity. +# Screenshots at http://mycli.net/syntax +syntax_style = default + +# Keybindings: Possible values: emacs, vi. +# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL. +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +key_bindings = emacs + +# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested. +wider_completion_menu = False + +# Autocompletion is on by default. This can be truned off by setting this +# option to False. Pressing tab will still trigger completion. +autocompletion = True + +# litecli prompt +# \D - The full current date +# \d - Database name +# \f - File basename of the "main" database +# \m - Minutes of the current time +# \n - Newline +# \P - AM/PM +# \R - The current time, in 24-hour military time (0-23) +# \r - The current time, standard 12-hour time (1-12) +# \s - Seconds of the current time +# \x1b[...m - insert ANSI escape sequence +prompt = "\t :\d> " +prompt_continuation = "-> " + +# Show/hide the informational toolbar with function keymap at the footer. +show_bottom_toolbar = True + +# Skip intro info on startup and outro info on exit +less_chatty = False + +# Use alias from --login-path instead of host name in prompt +login_path_as_host = False + +# Cause result sets to be displayed vertically if they are too wide for the current window, +# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.) +auto_vertical_output = False + +# keyword casing preference. Possible values "lower", "upper", "auto" +keyword_casing = auto + +# disabled pager on startup +enable_pager = True +[colors] +completion-menu.completion.current = "bg:#ffffff #000000" +completion-menu.completion = "bg:#008888 #ffffff" +completion-menu.meta.completion.current = "bg:#44aaaa #000000" +completion-menu.meta.completion = "bg:#448888 #ffffff" +completion-menu.multi-column-meta = "bg:#aaffff #000000" +scrollbar.arrow = "bg:#003333" +scrollbar = "bg:#00aaaa" +selected = "#ffffff bg:#6666aa" +search = "#ffffff bg:#4444aa" +search.current = "#ffffff bg:#44aa44" +bottom-toolbar = "bg:#222222 #aaaaaa" +bottom-toolbar.off = "bg:#222222 #888888" +bottom-toolbar.on = "bg:#222222 #ffffff" +search-toolbar = noinherit bold +search-toolbar.text = nobold +system-toolbar = noinherit bold +arg-toolbar = noinherit bold +arg-toolbar.text = nobold +bottom-toolbar.transaction.valid = "bg:#222222 #00ff5f bold" +bottom-toolbar.transaction.failed = "bg:#222222 #ff005f bold" + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +Token.Menu.Completions.Completion.Current = "bg:#00aaaa #000000" +Token.Menu.Completions.Completion = "bg:#008888 #ffffff" +Token.Menu.Completions.MultiColumnMeta = "bg:#aaffff #000000" +Token.Menu.Completions.ProgressButton = "bg:#003333" +Token.Menu.Completions.ProgressBar = "bg:#00aaaa" +Token.Output.Header = bold +Token.Output.OddRow = "" +Token.Output.EvenRow = "" +Token.SelectedText = "#ffffff bg:#6666aa" +Token.SearchMatch = "#ffffff bg:#4444aa" +Token.SearchMatch.Current = "#ffffff bg:#44aa44" +Token.Toolbar = "bg:#222222 #aaaaaa" +Token.Toolbar.Off = "bg:#222222 #888888" +Token.Toolbar.On = "bg:#222222 #ffffff" +Token.Toolbar.Search = noinherit bold +Token.Toolbar.Search.Text = nobold +Token.Toolbar.System = noinherit bold +Token.Toolbar.Arg = noinherit bold +Token.Toolbar.Arg.Text = nobold +[favorite_queries] +q_param = select * from test where name=? +sh_param = select * from test where id=$1 diff --git a/tests/test.txt b/tests/test.txt new file mode 100644 index 0000000..1fa4cf0 --- /dev/null +++ b/tests/test.txt @@ -0,0 +1 @@ +litecli is awesome! diff --git a/tests/test_clistyle.py b/tests/test_clistyle.py new file mode 100644 index 0000000..c1177de --- /dev/null +++ b/tests/test_clistyle.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +"""Test the litecli.clistyle module.""" +import pytest + +from pygments.style import Style +from pygments.token import Token + +from litecli.clistyle import style_factory + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory(): + """Test that a Pygments Style class is created.""" + header = "bold underline #ansired" + cli_style = {"Token.Output.Header": header} + style = style_factory("default", cli_style) + + assert isinstance(style(), Style) + assert Token.Output.Header in style.styles + assert header == style.styles[Token.Output.Header] + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory_unknown_name(): + """Test that an unrecognized name will not throw an error.""" + style = style_factory("foobar", {}) + + assert isinstance(style(), Style) diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py new file mode 100644 index 0000000..760f275 --- /dev/null +++ b/tests/test_completion_engine.py @@ -0,0 +1,657 @@ +from litecli.packages.completion_engine import suggest_type +import pytest + + +def sorted_dicts(dicts): + """input is a list of dicts.""" + return sorted(tuple(x.items()) for x in dicts) + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [("sch", "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_order_by_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type( + "SELECT * FROM sch.tabl ORDER BY ", "SELECT * FROM sch.tabl ORDER BY " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [("sch", "tabl", None)]}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "], +) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = "SELECT * FROM tabl WHERE foo = ANY(" + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_lparen_suggests_cols(): + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols1(): + suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols2(): + suggestion = suggest_type( + "SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + " + ) + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type("SELECT ", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": []}, + {"type": "column", "tables": []}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM ", + "INSERT INTO ", + "COPY ", + "UPDATE ", + "DESCRIBE ", + "DESC ", + "EXPLAIN ", + "SELECT * FROM foo JOIN ", + ], +) +def test_expression_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + "INSERT INTO sch.", + "COPY sch.", + "UPDATE sch.", + "DESCRIBE sch.", + "DESC sch.", + "EXPLAIN sch.", + "SELECT * FROM foo JOIN sch.", + ], +) +def test_expression_suggests_qualified_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}] + ) + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "table", "schema": []}, {"type": "schema"}] + ) + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "table", "schema": "sch"}] + ) + + +def test_distinct_suggests_cols(): + suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ") + assert suggestions == [{"type": "column", "tables": []}] + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tbl"]}, + {"type": "column", "tables": [(None, "tbl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert sorted_dicts(suggestion) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_insert_into_lparen_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type( + "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "table", "schema": "tabl"}, + {"type": "view", "schema": "tabl"}, + {"type": "function", "schema": "tabl"}, + ] + ) + + +def test_dot_suggests_cols_of_an_alias(): + suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": "t1"}, + {"type": "view", "schema": "t1"}, + {"type": "column", "tables": [(None, "tabl1", "t1")]}, + {"type": "function", "schema": "t1"}, + ] + ) + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type( + "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl2", "t2")]}, + {"type": "table", "schema": "t2"}, + {"type": "view", "schema": "t2"}, + {"type": "function", "schema": "t2"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + "SELECT 1 AS", + ], +) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{"type": "keyword"}] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{"type": "keyword"}] + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." + suggestions = suggest_type(q, q) + assert suggestions == [ + {"type": "column", "tables": [(None, "foo", "f")]}, + {"type": "table", "schema": "f"}, + {"type": "view", "schema": "f"}, + {"type": "function", "schema": "f"}, + ] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (SELECT * FROM ", + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert sorted_dicts(suggestion) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["abc"]}, + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + ] + ) + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", "t")]}, + {"type": "table", "schema": "t"}, + {"type": "view", "schema": "t"}, + {"type": "function", "schema": "t"}, + ] + ) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +@pytest.mark.parametrize("tbl_alias", ["", "foo"]) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type) + suggestion = suggest_type(text, text) + assert sorted_dicts(suggestion) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "abc", "a")]}, + {"type": "table", "schema": "a"}, + {"type": "view", "schema": "a"}, + {"type": "function", "schema": "a"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) +def test_join_alias_dot_suggests_cols2(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "def", "d")]}, + {"type": "table", "schema": "d"}, + {"type": "view", "schema": "d"}, + {"type": "function", "schema": "d"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) +def test_on_suggests_aliases(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + ], +) +def test_on_suggests_tables(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + ], +) +def test_on_suggests_tables_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] + + +@pytest.mark.parametrize("col_list", ["", "col1, "]) +def test_join_using_suggests_common_columns(col_list): + text = "select * from abc inner join def using (" + col_list + assert suggest_type(text, text) == [ + { + "type": "column", + "tables": [(None, "abc", None), (None, "def", None)], + "drop_unique": True, + } + ] + + +def test_2_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ", "select * from a; select * from " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + suggestions = suggest_type( + "select * from a; select from b", "select * from a; select " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + # Should work even if first statement is invalid + suggestions = suggest_type( + "select * from; select * from ", "select * from; select * from " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_2_statements_1st_current(): + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + suggestions = suggest_type("select from a; select * from b", "select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["a"]}, + {"type": "column", "tables": [(None, "a", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_3_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ; select * from c", + "select * from a; select * from ", + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + suggestions = suggest_type( + "select * from a; select from b; select * from c", "select * from a; select " + ) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +def test_create_db_with_template(): + suggestions = suggest_type( + "create database foo with template ", "create database foo with template " + ) + + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) + + +@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"]) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "keyword"}, {"type": "special"}] + ) + + +def test_specials_not_included_after_initial_token(): + suggestions = suggest_type("create table foo (dt d", "create table foo (dt d") + + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}]) + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = "DROP TABLE schema_name.table_name" + suggestions = suggest_type(text, text) + assert suggestions == [{"type": "table", "schema": "schema_name"}] + + +@pytest.mark.parametrize("text", [",", " ,", "sel ,"]) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_cross_join(): + text = "select * from v1 cross join v2 JOIN v1.id, " + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +@pytest.mark.parametrize( + "expression", + [ + "\\. ", + "select 1; \\. ", + "select 1;\\. ", + "select 1 ; \\. ", + "source ", + "truncate table test; source ", + "truncate table test ; source ", + "truncate table test;source ", + ], +) +def test_source_is_file(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{"type": "file_name"}] diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py new file mode 100644 index 0000000..620a364 --- /dev/null +++ b/tests/test_completion_refresher.py @@ -0,0 +1,94 @@ +import time +import pytest +from mock import Mock, patch + + +@pytest.fixture +def refresher(): + from litecli.completion_refresher import CompletionRefresher + + return CompletionRefresher() + + +def test_ctor(refresher): + """Refresher object should contain a few handlers. + + :param refresher: + :return: + + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = [ + "databases", + "schemata", + "tables", + "functions", + "special_commands", + ] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + sqlexecute = Mock() + + with patch.object(refresher, "_bg_refresh") as bg_refresh: + actual = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(sqlexecute, callbacks, {}) + + +def test_refresh_called_twice(refresher): + """If refresh is called a second time, it should be restarted. + + :param refresher: + :return: + + """ + callbacks = Mock() + + sqlexecute = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == "Auto-completion refresh started in the background." + + actual2 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == "Auto-completion refresh restarted." + + +def test_refresh_with_callbacks(refresher): + """Callbacks must be called. + + :param refresher: + + """ + callbacks = [Mock()] + sqlexecute_class = Mock() + sqlexecute = Mock() + + with patch("litecli.completion_refresher.SQLExecute", sqlexecute_class): + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_dbspecial.py b/tests/test_dbspecial.py new file mode 100644 index 0000000..5128b5b --- /dev/null +++ b/tests/test_dbspecial.py @@ -0,0 +1,76 @@ +from litecli.packages.completion_engine import suggest_type +from test_completion_engine import sorted_dicts +from litecli.packages.special.utils import format_uptime + + +def test_import_first_argument(): + test_cases = [ + # text, expecting_arg_idx + [".import ", 1], + [".import ./da", 1], + [".import ./data.csv ", 2], + [".import ./data.csv t", 2], + [".import ./data.csv `t", 2], + ['.import ./data.csv "t', 2], + ] + for text, expecting_arg_idx in test_cases: + suggestions = suggest_type(text, text) + if expecting_arg_idx == 1: + assert suggestions == [{"type": "file_name"}] + else: + assert suggestions == [{"type": "table", "schema": []}] + + +def test_u_suggests_databases(): + suggestions = suggest_type("\\u ", "\\u ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) + + +def test_describe_table(): + suggestions = suggest_type("\\dt", "\\dt ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_list_or_show_create_tables(): + suggestions = suggest_type("\\dt+", "\\dt+ ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) + + +def test_format_uptime(): + seconds = 59 + assert "59 sec" == format_uptime(seconds) + + seconds = 120 + assert "2 min 0 sec" == format_uptime(seconds) + + seconds = 54890 + assert "15 hours 14 min 50 sec" == format_uptime(seconds) + + seconds = 598244 + assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds) + + seconds = 522600 + assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds) + + +def test_indexes(): + suggestions = suggest_type(".indexes", ".indexes ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, + ] + ) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..d4d52af --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,262 @@ +import os +from collections import namedtuple +from textwrap import dedent +from tempfile import NamedTemporaryFile +import shutil + +import click +from click.testing import CliRunner + +from litecli.main import cli, LiteCli +from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from utils import dbtest, run + +test_dir = os.path.abspath(os.path.dirname(__file__)) +project_dir = os.path.dirname(test_dir) +default_config_file = os.path.join(project_dir, "tests", "liteclirc") + +CLI_ARGS = ["--liteclirc", default_config_file, "_test_db"] + + +@dbtest +def test_execute_arg(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql]) + + assert result.exit_code == 0 + assert "abc" in result.output + + result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql]) + + assert result.exit_code == 0 + assert "abc" in result.output + + expected = "a\nabc\n" + + assert expected in result.output + + +@dbtest +def test_execute_arg_with_table(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"]) + expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n" + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_arg_with_csv(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"]) + expected = '"a"\n"abc"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +@dbtest +def test_batch_mode(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") + + sql = "select count(*) from test;\n" "select * from test limit 1;" + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert "count(*)\n3\na\nabc\n" in "".join(result.output) + + +@dbtest +def test_batch_mode_table(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") + + sql = "select count(*) from test;\n" "select * from test limit 1;" + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql) + + expected = dedent( + """\ + +----------+ + | count(*) | + +----------+ + | 3 | + +----------+ + +-----+ + | a | + +-----+ + | abc | + +-----+""" + ) + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_batch_mode_csv(executor): + run(executor, """create table test(a text, b text)""") + run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""") + + sql = "select * from test;" + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql) + + expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +def test_help_strings_end_with_periods(): + """Make sure click options have help text that end with a period.""" + for param in cli.params: + if isinstance(param, click.core.Option): + assert hasattr(param, "help") + assert param.help.endswith(".") + + +def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): + global clickoutput + clickoutput = "" + m = LiteCli(liteclirc=default_config_file) + + class TestOutput: + def get_size(self): + size = namedtuple("Size", "rows columns") + size.columns, size.rows = terminal_size + return size + + class TestExecute: + host = "test" + user = "test" + dbname = "test" + port = 0 + + def server_type(self): + return ["test"] + + class PromptBuffer: + output = TestOutput() + + m.prompt_app = PromptBuffer() + m.sqlexecute = TestExecute() + m.explicit_pager = explicit_pager + + def echo_via_pager(s): + assert expect_pager + global clickoutput + clickoutput += s + + def secho(s): + assert not expect_pager + global clickoutput + clickoutput += s + "\n" + + monkeypatch.setattr(click, "echo_via_pager", echo_via_pager) + monkeypatch.setattr(click, "secho", secho) + m.output(testdata) + if clickoutput.endswith("\n"): + clickoutput = clickoutput[:-1] + assert clickoutput == "\n".join(testdata) + + +def test_conditional_pager(monkeypatch): + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( + " " + ) + # User didn't set pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=True, + ) + # User didn't set pager, output fits screen -> no pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=False, + expect_pager=False, + ) + # User manually configured pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=True, + expect_pager=True, + ) + # User manually configured pager, output fit screen -> pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=True, + expect_pager=True, + ) + + SPECIAL_COMMANDS["nopager"].handler() + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=False, + ) + SPECIAL_COMMANDS["pager"].handler("") + + +def test_reserved_space_is_integer(): + """Make sure that reserved space is returned as an integer.""" + + def stub_terminal_size(): + return (5, 5) + + old_func = shutil.get_terminal_size + + shutil.get_terminal_size = stub_terminal_size + lc = LiteCli() + assert isinstance(lc.get_reserved_space(), int) + shutil.get_terminal_size = old_func + + +@dbtest +def test_import_command(executor): + data_file = os.path.join(project_dir, "tests", "data", "import_data.csv") + run(executor, """create table tbl1(one varchar(10), two smallint)""") + + # execute + run(executor, """.import %s tbl1""" % data_file) + + # verify + sql = "select * from tbl1;" + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql) + + expected = """one","two" +"t1","11" +"t2","22" +""" + assert result.exit_code == 0 + assert expected in "".join(result.output) diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py new file mode 100644 index 0000000..cad7a8c --- /dev/null +++ b/tests/test_parseutils.py @@ -0,0 +1,131 @@ +import pytest +from litecli.packages.parseutils import ( + extract_tables, + query_starts_with, + queries_start_with, + is_destructive, +) + + +def test_empty_string(): + tables = extract_tables("") + assert tables == [] + + +def test_simple_select_single_table(): + tables = extract_tables("select * from abc") + assert tables == [(None, "abc", None)] + + +def test_simple_select_single_table_schema_qualified(): + tables = extract_tables("select * from abc.def") + assert tables == [("abc", "def", None)] + + +def test_simple_select_multiple_tables(): + tables = extract_tables("select * from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables("select * from abc.def, ghi.jkl") + assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)] + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables("select a,b from abc") + assert tables == [(None, "abc", None)] + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables("select a,b from abc.def") + assert tables == [("abc", "def", None)] + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables("select a,b from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + + +def test_simple_select_with_cols_multiple_tables_with_schema(): + tables = extract_tables("select a,b from abc.def, def.ghi") + assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)] + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables("select a, from abc") + assert tables == [(None, "abc", None)] + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables("select a, from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")] + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # assert tables == [(None, 'abc', None)] + assert tables == [(None, "abc", "abc")] + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == [("abc", "def", None)] + + +def test_simple_update_table(): + tables = extract_tables("update abc set id = 1") + assert tables == [(None, "abc", None)] + + +def test_simple_update_table_with_schema(): + tables = extract_tables("update abc.def set id = 1") + assert tables == [("abc", "def", None)] + + +def test_join_table(): + tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num") + assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")] + + +def test_join_table_schema_qualified(): + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")] + + +def test_join_as_table(): + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == [(None, "my_table", "m")] + + +def test_query_starts_with(): + query = "USE test;" + assert query_starts_with(query, ("use",)) is True + + query = "DROP DATABASE test;" + assert query_starts_with(query, ("use",)) is False + + +def test_query_starts_with_comment(): + query = "# comment\nUSE test;" + assert query_starts_with(query, ("use",)) is True + + +def test_queries_start_with(): + sql = "# comment\n" "show databases;" "use foo;" + assert queries_start_with(sql, ("show", "select")) is True + assert queries_start_with(sql, ("use", "drop")) is True + assert queries_start_with(sql, ("delete", "update")) is False + + +def test_is_destructive(): + sql = "use test;\n" "show databases;\n" "drop database foo;" + assert is_destructive(sql) is True diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py new file mode 100644 index 0000000..2de74ce --- /dev/null +++ b/tests/test_prompt_utils.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + + +import click + +from litecli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream("stdin") + assert stdin.isatty() is False + + sql = "drop database foo;" + assert confirm_destructive_query(sql) is None diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..e532118 --- /dev/null +++ b/tests/test_smart_completion_public_schema_only.py @@ -0,0 +1,432 @@ +# coding: utf-8 +from __future__ import unicode_literals +import pytest +from mock import patch +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document + +metadata = { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status"], + "select": ["id", "insert", "ABC"], + "réveillé": ["id", "insert", "ABC"], +} + + +@pytest.fixture +def completer(): + + import litecli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter() + + tables, columns = [], [] + + for table, cols in metadata.items(): + tables.append((table,)) + columns.extend([(table, col) for col in cols]) + + comp.set_dbname("test") + comp.extend_schemata("test") + comp.extend_relations(tables, kind="tables") + comp.extend_columns(columns, kind="tables") + + return comp + + +@pytest.fixture +def complete_event(): + from mock import Mock + + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert list(map(Completion, sorted(completer.keywords))) == result + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + assert list(result) == list([Completion(text="SELECT", start_position=-3)]) + + +def test_table_completion(completer, complete_event): + text = "SELECT * FROM " + position = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + assert list(result) == list( + [ + Completion(text="`réveillé`", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + ] + ) + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + assert list(result) == list( + [ + Completion(text="MAX", start_position=-2), + Completion(text="MATCH", start_position=-2), + ] + ) + + +def test_suggested_column_names(completer, complete_event): + """Suggest column and function names when selecting from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT from users" + position = len("SELECT ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="users", start_position=0)] + + list(map(Completion, sorted(completer.keywords))) + ) + + +def test_suggested_column_names_in_function(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT MAX( from users" + position = len("SELECT MAX(") + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + assert list(result) == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + ) + + +def test_suggested_column_names_with_table_dot(completer, complete_event): + """Suggest column names on table name and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT users. from users" + position = len("SELECT users.") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + ) + + +def test_suggested_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT u. from users u" + position = len("SELECT u.") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + ) + + +def test_suggested_multiple_column_names(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT id, from users u" + position = len("SELECT id, ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="u", start_position=0)] + + list(map(Completion, sorted(completer.keywords))) + ) + + +def test_suggested_multiple_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT u.id, u. from users u" + position = len("SELECT u.id, u.") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + ) + + +def test_suggested_multiple_column_names_with_dot(completer, complete_event): + """Suggest column names on table names and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT users.id, users. from users u" + position = len("SELECT users.id, users.") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="id", start_position=0), + Completion(text="last_name", start_position=0), + ] + ) + + +def test_suggested_aliases_after_on(completer, complete_event): + text = "SELECT u.name, o.id FROM users u JOIN orders o ON " + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [Completion(text="o", start_position=0), Completion(text="u", start_position=0)] + ) + + +def test_suggested_aliases_after_on_right_side(completer, complete_event): + text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = " + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [Completion(text="o", start_position=0), Completion(text="u", start_position=0)] + ) + + +def test_suggested_tables_after_on(completer, complete_event): + text = "SELECT users.name, orders.id FROM users JOIN orders ON " + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + ] + ) + + +def test_suggested_tables_after_on_right_side(completer, complete_event): + text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + position = len( + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + ) + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert list(result) == list( + [ + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + ] + ) + + +def test_table_names_after_from(completer, complete_event): + text = "SELECT * FROM " + position = len("SELECT * FROM ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert list(result) == list( + [ + Completion(text="`réveillé`", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + ] + ) + + +def test_auto_escaped_col_names(completer, complete_event): + text = "SELECT from `select`" + position = len("SELECT ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert ( + result + == [ + Completion(text="*", start_position=0), + Completion(text="`ABC`", start_position=0), + Completion(text="`insert`", start_position=0), + Completion(text="id", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="`select`", start_position=0)] + + list(map(Completion, sorted(completer.keywords))) + ) + + +def test_un_escaped_table_names(completer, complete_event): + text = "SELECT from réveillé" + position = len("SELECT ") + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="`ABC`", start_position=0), + Completion(text="`insert`", start_position=0), + Completion(text="id", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="réveillé", start_position=0)] + + list(map(Completion, sorted(completer.keywords))) + ) + + +def dummy_list_path(dir_name): + dirs = { + "/": ["dir1", "file1.sql", "file2.sql"], + "/dir1": ["subdir1", "subfile1.sql", "subfile2.sql"], + "/dir1/subdir1": ["lastfile.sql"], + } + return dirs.get(dir_name, []) + + +@patch("litecli.packages.filepaths.list_path", new=dummy_list_path) +@pytest.mark.parametrize( + "text,expected", + [ + ("source ", [(".", 0), ("..", 0), ("/", 0), ("~", 0)]), + ("source /", [("dir1", 0), ("file1.sql", 0), ("file2.sql", 0)]), + ("source /dir1/", [("subdir1", 0), ("subfile1.sql", 0), ("subfile2.sql", 0)]), + ("source /dir1/subdir1/", [("lastfile.sql", 0)]), + ], +) +def test_file_name_completion(completer, complete_event, text, expected): + position = len(text) + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + expected = list([Completion(txt, pos) for txt, pos in expected]) + assert result == expected diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py new file mode 100644 index 0000000..e559bc6 --- /dev/null +++ b/tests/test_sqlexecute.py @@ -0,0 +1,405 @@ +# coding=UTF-8 + +import os + +import pytest + +from utils import run, dbtest, set_expanded_output, is_expanded_output +from sqlite3 import OperationalError, ProgrammingError + + +def assert_result_equal( + result, + title=None, + rows=None, + headers=None, + status=None, + auto_status=True, + assert_contains=False, +): + """Assert that an sqlexecute.run() result matches the expected values.""" + if status is None and auto_status and rows: + status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "") + fields = {"title": title, "rows": rows, "headers": headers, "status": status} + + if assert_contains: + # Do a loose match on the results using the *in* operator. + for key, field in fields.items(): + if field: + assert field in result[0][key] + else: + # Do an exact match on the fields. + assert result == [fields] + + +@dbtest +def test_conn(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + results = run(executor, """select * from test""") + + assert_result_equal(results, headers=["a"], rows=[("abc",)]) + + +@dbtest +def test_bools(executor): + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(1)""") + results = run(executor, """select * from test""") + + assert_result_equal(results, headers=["a"], rows=[(1,)]) + + +@dbtest +def test_binary(executor): + run(executor, """create table foo(blb BLOB NOT NULL)""") + run(executor, """INSERT INTO foo VALUES ('\x01\x01\x01\n')""") + results = run(executor, """select * from foo""") + + expected = "\x01\x01\x01\n" + + assert_result_equal(results, headers=["blb"], rows=[(expected,)]) + + +## Failing in Travis for some unknown reason. +# @dbtest +# def test_table_and_columns_query(executor): +# run(executor, "create table a(x text, y text)") +# run(executor, "create table b(z text)") + +# assert set(executor.tables()) == set([("a",), ("b",)]) +# assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")]) + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert "main" in list(databases) + + +@dbtest +def test_invalid_syntax(executor): + with pytest.raises(OperationalError) as excinfo: + run(executor, "invalid syntax!") + assert "syntax error" in str(excinfo.value) + + +@dbtest +def test_invalid_column_name(executor): + with pytest.raises(OperationalError) as excinfo: + run(executor, "select invalid command") + assert "no such column: invalid" in str(excinfo.value) + + +@dbtest +def test_unicode_support_in_output(executor): + run(executor, "create table unicodechars(t text)") + run(executor, u"insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + results = run(executor, u"select * from unicodechars") + assert_result_equal(results, headers=["t"], rows=[(u"é",)]) + + +@dbtest +def test_invalid_unicode_values_dont_choke(executor): + run(executor, "create table unicodechars(t text)") + # \xc3 is not a valid utf-8 char. But we can insert it into the database + # which can break querying if not handled correctly. + run(executor, u"insert into unicodechars (t) values (cast(x'c3' as text))") + + results = run(executor, u"select * from unicodechars") + assert_result_equal(results, headers=["t"], rows=[("\\xc3",)]) + + +@dbtest +def test_multiple_queries_same_line(executor): + results = run(executor, "select 'foo'; select 'bar'") + + expected = [ + { + "title": None, + "headers": ["'foo'"], + "rows": [(u"foo",)], + "status": "1 row in set", + }, + { + "title": None, + "headers": ["'bar'"], + "rows": [(u"bar",)], + "status": "1 row in set", + }, + ] + assert expected == results + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor): + with pytest.raises(OperationalError) as excinfo: + run(executor, "select 'foo'; invalid syntax") + assert "syntax error" in str(excinfo.value) + + +@dbtest +def test_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-a select * from test where a like 'a%'") + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f+ test-a") + assert_result_equal( + results, + title="> select * from test where a like 'a%'", + headers=["a"], + rows=[("abc",)], + auto_status=False, + ) + + results = run(executor, "\\fd test-a") + assert_result_equal(results, status="test-a: Deleted") + + +@dbtest +def test_bind_parameterized_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(name text, id integer)") + run(executor, "insert into test values('def', 2)") + run(executor, "insert into test values('two words', 3)") + + results = run(executor, "\\fs q_param select * from test where name=?") + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f+ q_param def") + assert_result_equal( + results, + title="> select * from test where name=?", + headers=["name", "id"], + rows=[("def", 2)], + auto_status=False, + ) + + results = run(executor, "\\f+ q_param 'two words'") + assert_result_equal( + results, + title="> select * from test where name=?", + headers=["name", "id"], + rows=[("two words", 3)], + auto_status=False, + ) + + with pytest.raises(ProgrammingError): + results = run(executor, "\\f+ q_param") + + with pytest.raises(ProgrammingError): + results = run(executor, "\\f+ q_param 1 2") + + +@dbtest +def test_verbose_feature_of_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text, id integer)") + run(executor, "insert into test values('abc', 1)") + run(executor, "insert into test values('def', 2)") + + results = run(executor, "\\fs sh_param select * from test where id=$1") + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f sh_param 1") + assert_result_equal( + results, + title=None, + headers=["a", "id"], + rows=[("abc", 1)], + auto_status=False, + ) + + results = run(executor, "\\f+ sh_param 1") + assert_result_equal( + results, + title="> select * from test where id=1", + headers=["a", "id"], + rows=[("abc", 1)], + auto_status=False, + ) + + +@dbtest +def test_shell_parameterized_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text, id integer)") + run(executor, "insert into test values('abc', 1)") + run(executor, "insert into test values('def', 2)") + + results = run(executor, "\\fs sh_param select * from test where id=$1") + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f+ sh_param 1") + assert_result_equal( + results, + title="> select * from test where id=1", + headers=["a", "id"], + rows=[("abc", 1)], + auto_status=False, + ) + + results = run(executor, "\\f+ sh_param") + assert_result_equal( + results, + title=None, + headers=None, + rows=None, + status="missing substitution for $1 in query:\n select * from test where id=$1", + ) + + results = run(executor, "\\f+ sh_param 1 2") + assert_result_equal( + results, + title=None, + headers=None, + rows=None, + status="Too many arguments.\nQuery does not have enough place holders to substitute.\nselect * from test where id=1", + ) + + +@dbtest +def test_favorite_query_multiple_statement(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run( + executor, + "\\fs test-ad select * from test where a like 'a%'; " + "select * from test where a like 'd%'", + ) + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f+ test-ad") + expected = [ + { + "title": "> select * from test where a like 'a%'", + "headers": ["a"], + "rows": [("abc",)], + "status": None, + }, + { + "title": "> select * from test where a like 'd%'", + "headers": ["a"], + "rows": [("def",)], + "status": None, + }, + ] + assert expected == results + + results = run(executor, "\\fd test-ad") + assert_result_equal(results, status="test-ad: Deleted") + + +@dbtest +def test_favorite_query_expanded_output(executor): + set_expanded_output(False) + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + + results = run(executor, "\\fs test-ae select * from test") + assert_result_equal(results, status="Saved.") + + results = run(executor, "\\f+ test-ae \G") + assert is_expanded_output() is True + assert_result_equal( + results, + title="> select * from test", + headers=["a"], + rows=[("abc",)], + auto_status=False, + ) + + set_expanded_output(False) + + results = run(executor, "\\fd test-ae") + assert_result_equal(results, status="test-ae: Deleted") + + +@dbtest +def test_special_command(executor): + results = run(executor, "\\?") + assert_result_equal( + results, + rows=("quit", "\\q", "Quit."), + headers="Command", + assert_contains=True, + auto_status=False, + ) + + +@dbtest +def test_cd_command_without_a_folder_name(executor): + results = run(executor, "system cd") + assert_result_equal(results, status="No folder name was provided.") + + +@dbtest +def test_system_command_not_found(executor): + results = run(executor, "system xyz") + assert_result_equal( + results, status="OSError: No such file or directory", assert_contains=True + ) + + +@dbtest +def test_system_command_output(executor): + test_dir = os.path.abspath(os.path.dirname(__file__)) + test_file_path = os.path.join(test_dir, "test.txt") + results = run(executor, "system cat {0}".format(test_file_path)) + assert_result_equal(results, status="litecli is awesome!\n") + + +@dbtest +def test_cd_command_current_dir(executor): + test_path = os.path.abspath(os.path.dirname(__file__)) + run(executor, "system cd {0}".format(test_path)) + assert os.getcwd() == test_path + run(executor, "system cd ..") + + +@dbtest +def test_unicode_support(executor): + results = run(executor, u"SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=["japanese"], rows=[(u"日本語",)]) + + +@dbtest +def test_timestamp_null(executor): + run(executor, """create table ts_null(a timestamp null)""") + run(executor, """insert into ts_null values(null)""") + results = run(executor, """select * from ts_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) + + +@dbtest +def test_datetime_null(executor): + run(executor, """create table dt_null(a datetime null)""") + run(executor, """insert into dt_null values(null)""") + results = run(executor, """select * from dt_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) + + +@dbtest +def test_date_null(executor): + run(executor, """create table date_null(a date null)""") + run(executor, """insert into date_null values(null)""") + results = run(executor, """select * from date_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) + + +@dbtest +def test_time_null(executor): + run(executor, """create table time_null(a time null)""") + run(executor, """insert into time_null values(null)""") + results = run(executor, """select * from time_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..41bac9b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +import os +import time +import signal +import platform +import multiprocessing +from contextlib import closing + +import sqlite3 +import pytest + +from litecli.main import special + +DATABASE = os.getenv("PYTEST_DATABASE", "test.sqlite3") + + +def db_connection(dbname=":memory:"): + conn = sqlite3.connect(database=dbname, isolation_level=None) + return conn + + +try: + db_connection() + CAN_CONNECT_TO_DB = True +except Exception as ex: + CAN_CONNECT_TO_DB = False + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, reason="Error creating sqlite connection" +) + + +def create_db(dbname): + with closing(db_connection().cursor()) as cur: + try: + cur.execute("""DROP DATABASE IF EXISTS _test_db""") + cur.execute("""CREATE DATABASE _test_db""") + except: + pass + + +def drop_tables(dbname): + with closing(db_connection().cursor()) as cur: + try: + cur.execute("""DROP DATABASE IF EXISTS _test_db""") + except: + pass + + +def run(executor, sql, rows_as_list=True): + """Return string output for the sql to be run.""" + result = [] + + for title, rows, headers, status in executor.run(sql): + rows = list(rows) if (rows_as_list and rows) else rows + result.append( + {"title": title, "rows": rows, "headers": headers, "status": status} + ) + + return result + + +def set_expanded_output(is_expanded): + """Pass-through for the tests.""" + return special.set_expanded_output(is_expanded) + + +def is_expanded_output(): + """Pass-through for the tests.""" + return special.is_expanded_output() + + +def send_ctrl_c_to_pid(pid, wait_seconds): + """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds` + seconds.""" + time.sleep(wait_seconds) + system_name = platform.system() + if system_name == "Windows": + os.kill(pid, signal.CTRL_C_EVENT) + else: + os.kill(pid, signal.SIGINT) + + +def send_ctrl_c(wait_seconds): + """Create a process that sends a Ctrl-C like signal to the current process + after `wait_seconds` seconds. + + Returns the `multiprocessing.Process` created. + + """ + ctrl_c_process = multiprocessing.Process( + target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) + ) + ctrl_c_process.start() + return ctrl_c_process @@ -0,0 +1,11 @@ +[tox] +envlist = py37, py38, py39, py310 + +[testenv] +deps = pytest + mock + pexpect + behave + coverage +commands = python setup.py test +passenv = PYTEST_DATABASE |