summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--.coveragerc3
-rw-r--r--.git-blame-ignore-revs2
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md8
-rw-r--r--.github/workflows/ci.yml44
-rw-r--r--.gitignore18
-rw-r--r--.pre-commit-config.yaml5
-rw-r--r--CHANGELOG.md110
-rw-r--r--CONTRIBUTING.md119
-rw-r--r--LICENSE27
-rw-r--r--MANIFEST.in8
-rw-r--r--README.md54
-rw-r--r--TODO3
-rw-r--r--litecli/AUTHORS21
-rw-r--r--litecli/__init__.py1
-rw-r--r--litecli/clibuffer.py40
-rw-r--r--litecli/clistyle.py114
-rw-r--r--litecli/clitoolbar.py52
-rw-r--r--litecli/compat.py9
-rw-r--r--litecli/completion_refresher.py131
-rw-r--r--litecli/config.py62
-rw-r--r--litecli/encodingutils.py38
-rw-r--r--litecli/key_bindings.py84
-rw-r--r--litecli/lexer.py9
-rw-r--r--litecli/liteclirc122
-rw-r--r--litecli/main.py1017
-rw-r--r--litecli/packages/__init__.py0
-rw-r--r--litecli/packages/completion_engine.py331
-rw-r--r--litecli/packages/filepaths.py88
-rw-r--r--litecli/packages/parseutils.py227
-rw-r--r--litecli/packages/prompt_utils.py39
-rw-r--r--litecli/packages/special/__init__.py12
-rw-r--r--litecli/packages/special/dbcommands.py290
-rw-r--r--litecli/packages/special/favoritequeries.py59
-rw-r--r--litecli/packages/special/iocommands.py478
-rw-r--r--litecli/packages/special/main.py160
-rw-r--r--litecli/packages/special/utils.py48
-rw-r--r--litecli/sqlcompleter.py612
-rw-r--r--litecli/sqlexecute.py220
-rw-r--r--release.py130
-rw-r--r--requirements-dev.txt10
-rw-r--r--screenshots/litecli.gifbin0 -> 742270 bytes
-rw-r--r--screenshots/litecli.pngbin0 -> 109549 bytes
-rw-r--r--setup.cfg18
-rwxr-xr-xsetup.py70
-rw-r--r--tasks.py130
-rw-r--r--tests/conftest.py40
-rw-r--r--tests/data/import_data.csv2
-rw-r--r--tests/liteclirc137
-rw-r--r--tests/test.txt1
-rw-r--r--tests/test_clistyle.py28
-rw-r--r--tests/test_completion_engine.py657
-rw-r--r--tests/test_completion_refresher.py94
-rw-r--r--tests/test_dbspecial.py76
-rw-r--r--tests/test_main.py262
-rw-r--r--tests/test_parseutils.py131
-rw-r--r--tests/test_prompt_utils.py14
-rw-r--r--tests/test_smart_completion_public_schema_only.py432
-rw-r--r--tests/test_sqlexecute.py405
-rw-r--r--tests/utils.py96
-rw-r--r--tox.ini11
60 files changed, 7409 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..aea4994
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[run]
+parallel = True
+source = litecli
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..c433202
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Black
+f767afc80bd5bcc8f1b1cc1a134babc2dec4d239 \ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..3e14cc7
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,8 @@
+## Description
+<!--- Describe your changes in detail. -->
+
+
+
+## Checklist
+<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. -->
+- [ ] I've added this contribution to the `CHANGELOG.md` file.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..9ee36cf
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,44 @@
+name: litecli
+
+on:
+ pull_request:
+ paths-ignore:
+ - '**.md'
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: ["3.7", "3.8", "3.9", "3.10"]
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install requirements
+ run: |
+ python -m pip install -U pip setuptools
+ pip install --no-cache-dir -e .
+ pip install -r requirements-dev.txt -U --upgrade-strategy=only-if-needed
+
+ - name: Run unit tests
+ env:
+ PYTEST_PASSWORD: root
+ run: |
+ ./setup.py test --pytest-args="--cov-report= --cov=litecli"
+
+ - name: Run Black
+ run: |
+ ./setup.py lint
+ if: matrix.python-version == '3.7'
+
+ - name: Coverage
+ run: |
+ coverage report
+ codecov
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..63c3eb6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+.idea/
+/build
+/dist
+/litecli.egg-info
+/htmlcov
+/src
+/test/behave.ini
+/litecli_env
+/.venv
+/.eggs
+
+.vagrant
+*.pyc
+*.deb
+*.swp
+.cache/
+.coverage
+.coverage.*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..67ba03d
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,5 @@
+repos:
+- repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..afeebd1
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,110 @@
+## 1.9.0 - 2022-06-06
+
+### Features
+
+* Add support for ANSI escape sequences for coloring the prompt.
+* Add support for `.indexes` command.
+* Add an option to turn off the auto-completion menu. Completion menu can be
+ triggered by pressed the `<tab>` key when this option is set to False. Fixes
+ [#105](https://github.com/dbcli/litecli/issues/105).
+
+### Bug Fixes
+
+* Fix [#120](https://github.com/dbcli/litecli/issues/120). Make the `.read` command actually read and execute the commands from a file.
+* Fix [#96](https://github.com/dbcli/litecli/issues/96) the crash in VI mode when pressing `r`.
+
+## 1.8.0 - 2022-03-29
+
+### Features
+
+* Update compatible Python versions. (Thanks: [blazewicz])
+* Add support for Python 3.10. (Thanks: [blazewicz])
+* Drop support for Python 3.6. (Thanks: [blazewicz])
+
+### Bug Fixes
+
+* Upgrade cli_helpers to workaround Pygments regression.
+* Use get_terminal_size from shutil instead of click.
+
+## 1.7.0 - 2022-01-11
+
+### Features
+
+* Add config option show_bottom_toolbar.
+
+### Bug Fixes
+
+* Pin pygments version to prevent breaking change.
+
+## 1.6.0 - 2021-03-15
+
+### Features
+
+- Add verbose feature to `favorite_query` command. (Thanks: [Zhaolong Zhu])
+ - `\f query` does not show the full SQL.
+ - `\f+ query` shows the full SQL.
+- Add prompt format of file's basename. (Thanks: [elig0n])
+
+### Bug Fixes
+
+- Fix compatibility with sqlparse >= 0.4.0. (Thanks: [chocolateboy])
+- Fix invalid utf-8 exception. (Thanks: [Amjith])
+
+## 1.4.1 - 2020-07-27
+
+### Bug Fixes
+
+- Fix setup.py to set `long_description_content_type` as markdown.
+
+## 1.4.0 - 2020-07-27
+
+### Features
+
+- Add NULLS FIRST and NULLS LAST to keywords. (Thanks: [Amjith])
+
+## 1.3.2 - 2020-03-11
+
+- Fix the completion engine to work with newer sqlparse.
+
+## 1.3.1 - 2020-03-11
+
+- Remove the version pinning of sqlparse package.
+
+## 1.3.0 - 2020-02-11
+
+### Features
+
+- Added `.import` command for importing data from file into table. (Thanks: [Zhaolong Zhu])
+- Upgraded to prompt-toolkit 3.x.
+
+## 1.2.0 - 2019-10-26
+
+### Features
+
+- Enhance the `describe` command. (Thanks: [Amjith])
+- Autocomplete table names for special commands. (Thanks: [Amjith])
+
+## 1.1.0 - 2019-07-14
+
+### Features
+
+- Added `.read` command for reading scripts.
+- Added `.load` command for loading extension libraries. (Thanks: [Zhiming Wang])
+- Add support for using `?` as a placeholder in the favorite queries. (Thanks: [Amjith])
+- Added shift-tab to select the previous entry in the completion menu. [Amjith]
+- Added `describe` and `desc` keywords.
+
+### Bug Fixes
+
+- Clear error message when directory does not exist. (Thanks: [Irina Truong])
+
+## 1.0.0 - 2019-01-04
+
+- To new beginnings. :tada:
+
+[Amjith]: https://blog.amjith.com
+[chocolateboy]: https://github.com/chocolateboy
+[Irina Truong]: https://github.com/j-bennet
+[Shawn Chapla]: https://github.com/shwnchpl
+[Zhaolong Zhu]: https://github.com/zzl0
+[Zhiming Wang]: https://github.com/zmwangx
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..9cd868a
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,119 @@
+# Development Guide
+
+This is a guide for developers who would like to contribute to this project. It is recommended to use Python 3.7 and above for development.
+
+If you're interested in contributing to litecli, thank you. We'd love your help!
+You'll always get credit for your work.
+
+## GitHub Workflow
+
+1. [Fork the repository](https://github.com/dbcli/litecli) on GitHub.
+
+2. Clone your fork locally:
+ ```bash
+ $ git clone <url-for-your-fork>
+ ```
+
+3. Add the official repository (`upstream`) as a remote repository:
+ ```bash
+ $ git remote add upstream git@github.com:dbcli/litecli.git
+ ```
+
+4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs)
+ for development:
+
+ ```bash
+ $ cd litecli
+ $ pip install virtualenv
+ $ virtualenv litecli_dev
+ ```
+
+ We've just created a virtual environment that we'll use to install all the dependencies
+ and tools we need to work on litecli. Whenever you want to work on litecli, you
+ need to activate the virtual environment:
+
+ ```bash
+ $ source litecli_dev/bin/activate
+ ```
+
+ When you're done working, you can deactivate the virtual environment:
+
+ ```bash
+ $ deactivate
+ ```
+
+5. Install the dependencies and development tools:
+
+ ```bash
+ $ pip install -r requirements-dev.txt
+ $ pip install --editable .
+ ```
+
+6. Create a branch for your bugfix or feature based off the `master` branch:
+
+ ```bash
+ $ git checkout -b <name-of-bugfix-or-feature>
+ ```
+
+7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date:
+
+ ```bash
+ $ git pull upstream master
+ ```
+
+8. When your work is ready for the litecli team to review it, push your branch to your fork:
+
+ ```bash
+ $ git push origin <name-of-bugfix-or-feature>
+ ```
+
+9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) on GitHub.
+
+
+## Running the Tests
+
+While you work on litecli, it's important to run the tests to make sure your code
+hasn't broken any existing functionality. To run the tests, just type in:
+
+```bash
+$ ./setup.py test
+```
+
+litecli supports Python 3.7+. You can test against multiple versions of
+Python by running tox:
+
+```bash
+$ tox
+```
+
+
+### CLI Tests
+
+Some CLI tests expect the program `ex` to be a symbolic link to `vim`.
+
+In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will
+change the output and therefore make some tests fail.
+
+You can check this by running:
+```bash
+$ readlink -f $(which ex)
+```
+
+
+## Coding Style
+
+litecli uses [black](https://github.com/ambv/black) to format the source code. Make sure to install black.
+
+It's easy to check the style of your code, just run:
+
+```bash
+$ ./setup.py lint
+```
+
+If you see any style issues, you can automatically fix them by running:
+
+```bash
+$ ./setup.py lint --fix
+```
+
+Be sure to commit and push any stylistic fixes.
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..c3bbe96
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+Copyright (c) 2018, dbcli
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of dbcli nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..f1ff0f6
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,8 @@
+include *.txt *.py
+include LICENSE CHANGELOG.md
+include tox.ini
+recursive-include tests *.py
+recursive-include tests *.txt
+recursive-include tests *.csv
+recursive-include tests liteclirc
+recursive-include litecli AUTHORS
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..649481d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+# litecli
+
+[![Build Status](https://travis-ci.org/dbcli/litecli.svg?branch=master)](https://travis-ci.org/dbcli/litecli)
+
+[Docs](https://litecli.com)
+
+A command-line client for SQLite databases that has auto-completion and syntax highlighting.
+
+![Completion](screenshots/litecli.png)
+![CompletionGif](screenshots/litecli.gif)
+
+## Installation
+
+If you already know how to install python packages, then you can install it via pip:
+
+You might need sudo on linux.
+
+```
+$ pip install -U litecli
+```
+
+The package is also available on Arch Linux through AUR in two versions: [litecli](https://aur.archlinux.org/packages/litecli/) is based the latest release (git tag) and [litecli-git](https://aur.archlinux.org/packages/litecli-git/) is based on the master branch of the git repo. You can install them manually or with an AUR helper such as `yay`:
+
+```
+$ yay -S litecli
+```
+or
+
+```
+$ yay -S litecli-git
+```
+
+For MacOS users, you can also use Homebrew to install it:
+
+```
+$ brew install litecli
+```
+
+## Usage
+
+```
+$ litecli --help
+
+Usage: litecli [OPTIONS] [DATABASE]
+
+Examples:
+ - litecli sqlite_db_name
+```
+
+A config file is automatically created at `~/.config/litecli/config` at first launch. For Windows machines a config file is created at `~\AppData\Local\dbcli\litecli\config` at first launch. See the file itself for a description of all available options.
+
+## Docs
+
+Visit: [litecli.com/features](https://litecli.com/features)
diff --git a/TODO b/TODO
new file mode 100644
index 0000000..7c854dc
--- /dev/null
+++ b/TODO
@@ -0,0 +1,3 @@
+* [] Sort by frecency.
+* [] Add completions when an attach database command is run.
+* [] Add behave tests.
diff --git a/litecli/AUTHORS b/litecli/AUTHORS
new file mode 100644
index 0000000..194cdc7
--- /dev/null
+++ b/litecli/AUTHORS
@@ -0,0 +1,21 @@
+Project Lead:
+-------------
+
+ * Delgermurun Purevkhu
+
+
+Core Developers:
+----------------
+
+ * Amjith Ramanujam
+ * Irina Truong
+ * Dick Marinus
+
+Contributors:
+-------------
+
+ * Thomas Roten
+ * Zhaolong Zhu
+ * Zhiming Wang
+ * Shawn M. Chapla
+ * Paweł Sacawa
diff --git a/litecli/__init__.py b/litecli/__init__.py
new file mode 100644
index 0000000..0a0a43a
--- /dev/null
+++ b/litecli/__init__.py
@@ -0,0 +1 @@
+__version__ = "1.9.0"
diff --git a/litecli/clibuffer.py b/litecli/clibuffer.py
new file mode 100644
index 0000000..a57192a
--- /dev/null
+++ b/litecli/clibuffer.py
@@ -0,0 +1,40 @@
+from __future__ import unicode_literals
+
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+
+
+def cli_is_multiline(cli):
+ @Condition
+ def cond():
+ doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
+
+ if not cli.multi_line:
+ return False
+ else:
+ return not _multiline_exception(doc.text)
+
+ return cond
+
+
+def _multiline_exception(text):
+ orig = text
+ text = text.strip()
+
+ # Multi-statement favorite query is a special case. Because there will
+ # be a semicolon separating statements, we can't consider semicolon an
+ # EOL. Let's consider an empty line an EOL instead.
+ if text.startswith("\\fs"):
+ return orig.endswith("\n")
+
+ return (
+ text.startswith("\\") # Special Command
+ or text.endswith(";") # Ended with a semi-colon
+ or text.endswith("\\g") # Ended with \g
+ or text.endswith("\\G") # Ended with \G
+ or (text == "exit") # Exit doesn't need semi-colon
+ or (text == "quit") # Quit doesn't need semi-colon
+ or (text == ":q") # To all the vim fans out there
+ or (text == "") # Just a plain enter without any text
+ )
diff --git a/litecli/clistyle.py b/litecli/clistyle.py
new file mode 100644
index 0000000..7527315
--- /dev/null
+++ b/litecli/clistyle.py
@@ -0,0 +1,114 @@
+from __future__ import unicode_literals
+
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+ Token.Menu.Completions.Completion.Current: "completion-menu.completion.current",
+ Token.Menu.Completions.Completion: "completion-menu.completion",
+ Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current",
+ Token.Menu.Completions.Meta: "completion-menu.meta.completion",
+ Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta",
+ Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess
+ Token.Menu.Completions.ProgressBar: "scrollbar", # best guess
+ Token.SelectedText: "selected",
+ Token.SearchMatch: "search",
+ Token.SearchMatch.Current: "search.current",
+ Token.Toolbar: "bottom-toolbar",
+ Token.Toolbar.Off: "bottom-toolbar.off",
+ Token.Toolbar.On: "bottom-toolbar.on",
+ Token.Toolbar.Search: "search-toolbar",
+ Token.Toolbar.Search.Text: "search-toolbar.text",
+ Token.Toolbar.System: "system-toolbar",
+ Token.Toolbar.Arg: "arg-toolbar",
+ Token.Toolbar.Arg.Text: "arg-toolbar.text",
+ Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid",
+ Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed",
+ Token.Output.Header: "output.header",
+ Token.Output.OddRow: "output.odd-row",
+ Token.Output.EvenRow: "output.even-row",
+ Token.Prompt: "prompt",
+ Token.Continuation: "continuation",
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()}
+
+
+def parse_pygments_style(token_name, style_object, style_dict):
+ """Parse token type and style string.
+
+ :param token_name: str name of Pygments token. Example: "Token.String"
+ :param style_object: pygments.style.Style instance to use as base
+ :param style_dict: dict of token names and their styles, customized to this cli
+
+ """
+ token_type = string_to_tokentype(token_name)
+ try:
+ other_token_type = string_to_tokentype(style_dict[token_name])
+ return token_type, style_object.styles[other_token_type]
+ except AttributeError as err:
+ return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name)
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native")
+
+ prompt_styles = []
+ # prompt-toolkit used pygments tokens for styling before, switched to style
+ # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+ for token in cli_style:
+ if token.startswith("Token."):
+ # treat as pygments token (1.0)
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ if token_type in TOKEN_TO_PROMPT_STYLE:
+ prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+ prompt_styles.append((prompt_style, style_value))
+ else:
+ # we don't want to support tokens anymore
+ logger.error("Unhandled style / class name: %s", token)
+ else:
+ # treat as prompt style name (2.0). See default style names here:
+ # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+ prompt_styles.append((token, cli_style[token]))
+
+ override_style = Style([("bottom-toolbar", "noreverse")])
+ return merge_styles(
+ [style_from_pygments_cls(style), override_style, Style(prompt_styles)]
+ )
+
+
+def style_factory_output(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name).styles
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name("native").styles
+
+ for token in cli_style:
+ if token.startswith("Token."):
+ token_type, style_value = parse_pygments_style(token, style, cli_style)
+ style.update({token_type: style_value})
+ elif token in PROMPT_STYLE_TO_TOKEN:
+ token_type = PROMPT_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ else:
+ # TODO: cli helpers will have to switch to ptk.Style
+ logger.error("Unhandled style / class name: %s", token)
+
+ class OutputStyle(PygmentsStyle):
+ default_style = ""
+ styles = style
+
+ return OutputStyle
diff --git a/litecli/clitoolbar.py b/litecli/clitoolbar.py
new file mode 100644
index 0000000..1e28784
--- /dev/null
+++ b/litecli/clitoolbar.py
@@ -0,0 +1,52 @@
+from __future__ import unicode_literals
+
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.application import get_app
+
+
+def create_toolbar_tokens_func(cli, show_fish_help):
+ """
+ Return a function that generates the toolbar tokens.
+ """
+
+ def get_toolbar_tokens():
+ result = []
+ result.append(("class:bottom-toolbar", " "))
+
+ if cli.multi_line:
+ result.append(
+ ("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
+ )
+
+ if cli.multi_line:
+ result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON "))
+ else:
+ result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF "))
+ if cli.prompt_app.editing_mode == EditingMode.VI:
+ result.append(
+ ("class:botton-toolbar.on", "Vi-mode ({})".format(_get_vi_mode()))
+ )
+
+ if show_fish_help():
+ result.append(
+ ("class:bottom-toolbar", " Right-arrow to complete suggestion")
+ )
+
+ if cli.completion_refresher.is_refreshing():
+ result.append(("class:bottom-toolbar", " Refreshing completions..."))
+
+ return result
+
+ return get_toolbar_tokens
+
+
+def _get_vi_mode():
+ """Get the current vi mode for display."""
+ return {
+ InputMode.INSERT: "I",
+ InputMode.NAVIGATION: "N",
+ InputMode.REPLACE: "R",
+ InputMode.INSERT_MULTIPLE: "M",
+ InputMode.REPLACE_SINGLE: "R",
+ }[get_app().vi_state.input_mode]
diff --git a/litecli/compat.py b/litecli/compat.py
new file mode 100644
index 0000000..7316261
--- /dev/null
+++ b/litecli/compat.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""Platform and Python version compatibility support."""
+
+import sys
+
+
+PY2 = sys.version_info[0] == 2
+PY3 = sys.version_info[0] == 3
+WIN = sys.platform in ("win32", "cygwin")
diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py
new file mode 100644
index 0000000..9602070
--- /dev/null
+++ b/litecli/completion_refresher.py
@@ -0,0 +1,131 @@
+import threading
+from .packages.special.main import COMMANDS
+from collections import OrderedDict
+
+from .sqlcompleter import SQLCompleter
+from .sqlexecute import SQLExecute
+
+
+class CompletionRefresher(object):
+
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, callbacks, completer_options=None):
+ """Creates a SQLCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - SQLExecute object, used to extract the credentials to connect
+ to the database.
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ completer_options - dict of options to pass to SQLCompleter.
+
+ """
+ if completer_options is None:
+ completer_options = {}
+
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, "Auto-completion refresh restarted.")]
+ else:
+ if executor.dbname == ":memory:":
+ # if DB is memory, needed to use same connection
+ # So can't use same connection with different thread
+ self._bg_refresh(executor, callbacks, completer_options)
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, callbacks, completer_options),
+ name="completion_refresh",
+ )
+ self._completer_thread.setDaemon(True)
+ self._completer_thread.start()
+ return [
+ (
+ None,
+ None,
+ None,
+ "Auto-completion refresh started in the background.",
+ )
+ ]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
+ completer = SQLCompleter(**completer_options)
+
+ e = sqlexecute
+ if e.dbname == ":memory:":
+ # if DB is memory, needed to use same connection
+ executor = sqlexecute
+ else:
+ # Create a new sqlexecute method to popoulate the completions.
+ executor = SQLExecute(e.dbname)
+
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes natually
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ for callback in callbacks:
+ callback(completer)
+
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+ """Decorator to add the decorated function to the dictionary of
+ refreshers. Any function decorated with a @refresher will be executed as
+ part of the completion refresh routine."""
+
+ def wrapper(wrapped):
+ refreshers[name] = wrapped
+ return wrapped
+
+ return wrapper
+
+
+@refresher("databases")
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+
+@refresher("schemata")
+def refresh_schemata(completer, executor):
+ # name of the current database.
+ completer.extend_schemata(executor.dbname)
+ completer.set_dbname(executor.dbname)
+
+
+@refresher("tables")
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind="tables")
+ completer.extend_columns(executor.table_columns(), kind="tables")
+
+
+@refresher("functions")
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
+
+
+@refresher("special_commands")
+def refresh_special(completer, executor):
+ completer.extend_special_commands(COMMANDS.keys())
diff --git a/litecli/config.py b/litecli/config.py
new file mode 100644
index 0000000..1c7fb25
--- /dev/null
+++ b/litecli/config.py
@@ -0,0 +1,62 @@
+import errno
+import shutil
+import os
+import platform
+from os.path import expanduser, exists, dirname
+from configobj import ConfigObj
+
+
+def config_location():
+ if "XDG_CONFIG_HOME" in os.environ:
+ return "%s/litecli/" % expanduser(os.environ["XDG_CONFIG_HOME"])
+ elif platform.system() == "Windows":
+ return os.getenv("USERPROFILE") + "\\AppData\\Local\\dbcli\\litecli\\"
+ else:
+ return expanduser("~/.config/litecli/")
+
+
+def load_config(usr_cfg, def_cfg=None):
+ cfg = ConfigObj()
+ cfg.merge(ConfigObj(def_cfg, interpolation=False))
+ cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ cfg.filename = expanduser(usr_cfg)
+
+ return cfg
+
+
+def ensure_dir_exists(path):
+ parent_dir = expanduser(dirname(path))
+ try:
+ os.makedirs(parent_dir)
+ except OSError as exc:
+ # ignore existing destination (py2 has no exist_ok arg to makedirs)
+ if exc.errno != errno.EEXIST:
+ raise
+
+
+def write_default_config(source, destination, overwrite=False):
+ destination = expanduser(destination)
+ if not overwrite and exists(destination):
+ return
+
+ ensure_dir_exists(destination)
+
+ shutil.copyfile(source, destination)
+
+
+def upgrade_config(config, def_config):
+ cfg = load_config(config, def_config)
+ cfg.write()
+
+
+def get_config(liteclirc_file=None):
+ from litecli import __file__ as package_root
+
+ package_root = os.path.dirname(package_root)
+
+ liteclirc_file = liteclirc_file or "%sconfig" % config_location()
+
+ default_config = os.path.join(package_root, "liteclirc")
+ write_default_config(default_config, liteclirc_file)
+
+ return load_config(liteclirc_file, default_config)
diff --git a/litecli/encodingutils.py b/litecli/encodingutils.py
new file mode 100644
index 0000000..6caf14d
--- /dev/null
+++ b/litecli/encodingutils.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+from litecli.compat import PY2
+
+
+if PY2:
+ binary_type = str
+ string_types = basestring
+ text_type = unicode
+else:
+ binary_type = bytes
+ string_types = str
+ text_type = str
+
+
+def unicode2utf8(arg):
+ """Convert strings to UTF8-encoded bytes.
+
+ Only in Python 2. In Python 3 the args are expected as unicode.
+
+ """
+
+ if PY2 and isinstance(arg, text_type):
+ return arg.encode("utf-8")
+ return arg
+
+
+def utf8tounicode(arg):
+ """Convert UTF8-encoded bytes to strings.
+
+ Only in Python 2. In Python 3 the errors are returned as strings.
+
+ """
+
+ if PY2 and isinstance(arg, binary_type):
+ return arg.decode("utf-8")
+ return arg
diff --git a/litecli/key_bindings.py b/litecli/key_bindings.py
new file mode 100644
index 0000000..44d59d2
--- /dev/null
+++ b/litecli/key_bindings.py
@@ -0,0 +1,84 @@
+from __future__ import unicode_literals
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.filters import completion_is_selected
+from prompt_toolkit.key_binding import KeyBindings
+
+_logger = logging.getLogger(__name__)
+
+
+def cli_bindings(cli):
+ """Custom key bindings for cli."""
+ kb = KeyBindings()
+
+ @kb.add("f3")
+ def _(event):
+ """Enable/Disable Multiline Mode."""
+ _logger.debug("Detected F3 key.")
+ cli.multi_line = not cli.multi_line
+
+ @kb.add("f4")
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug("Detected F4 key.")
+ if cli.key_bindings == "vi":
+ event.app.editing_mode = EditingMode.EMACS
+ cli.key_bindings = "emacs"
+ else:
+ event.app.editing_mode = EditingMode.VI
+ cli.key_bindings = "vi"
+
+ @kb.add("tab")
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug("Detected <Tab> key.")
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=True)
+
+ @kb.add("s-tab")
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug("Detected <Tab> key.")
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_previous()
+ else:
+ b.start_completion(select_last=True)
+
+ @kb.add("c-space")
+ def _(event):
+ """
+ Initialize autocompletion at cursor.
+
+ If the autocompletion menu is not showing, display it with the
+ appropriate completions for the context.
+
+ If the menu is showing, select the next completion.
+ """
+ _logger.debug("Detected <C-Space> key.")
+
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=False)
+
+ @kb.add("enter", filter=completion_is_selected)
+ def _(event):
+ """Makes the enter key work as the tab key only when showing the menu.
+
+ In other words, don't execute query when enter is pressed in
+ the completion dropdown menu, instead close the dropdown menu
+ (accept current selection).
+
+ """
+ _logger.debug("Detected enter key.")
+
+ event.current_buffer.complete_state = None
+ b = event.app.current_buffer
+ b.complete_state = None
+
+ return kb
diff --git a/litecli/lexer.py b/litecli/lexer.py
new file mode 100644
index 0000000..678eb3f
--- /dev/null
+++ b/litecli/lexer.py
@@ -0,0 +1,9 @@
+from pygments.lexer import inherit
+from pygments.lexers.sql import MySqlLexer
+from pygments.token import Keyword
+
+
+class LiteCliLexer(MySqlLexer):
+ """Extends SQLite lexer to add keywords."""
+
+ tokens = {"root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit]}
diff --git a/litecli/liteclirc b/litecli/liteclirc
new file mode 100644
index 0000000..4db6f3a
--- /dev/null
+++ b/litecli/liteclirc
@@ -0,0 +1,122 @@
+# vi: ft=dosini
+[main]
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+# In Unix/Linux: ~/.config/litecli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.litecli-audit.log
+
+# Default pager.
+# By default '$PAGER' environment variable is used
+# pager = less -SRXF
+
+# Table format. Possible values:
+# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
+# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
+# vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# Autocompletion is on by default. This can be truned off by setting this
+# option to False. Pressing tab will still trigger completion.
+autocompletion = True
+
+# litecli prompt
+# \D - The full current date
+# \d - Database name
+# \f - File basename of the "main" database
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \R - The current time, in 24-hour military time (0-23)
+# \r - The current time, standard 12-hour time (1-12)
+# \s - Seconds of the current time
+# \x1b[...m - insert ANSI escape sequence
+prompt = '\d> '
+prompt_continuation = '-> '
+
+# Show/hide the informational toolbar with function keymap at the footer.
+show_bottom_toolbar = True
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
+# disabled pager on startup
+enable_pager = True
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+
+
+# Favorite queries.
+[favorite_queries]
diff --git a/litecli/main.py b/litecli/main.py
new file mode 100644
index 0000000..de279f6
--- /dev/null
+++ b/litecli/main.py
@@ -0,0 +1,1017 @@
+from __future__ import unicode_literals
+from __future__ import print_function
+
+import os
+import sys
+import traceback
+import logging
+import threading
+from time import time
+from datetime import datetime
+from io import open
+from collections import namedtuple
+from sqlite3 import OperationalError
+import shutil
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output import preprocessors
+import click
+import sqlparse
+from prompt_toolkit.completion import DynamicCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
+from prompt_toolkit.layout.processors import (
+ HighlightMatchingBracketProcessor,
+ ConditionalProcessor,
+)
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+
+from .packages.special.main import NO_QUERY
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages import special
+from .sqlcompleter import SQLCompleter
+from .clitoolbar import create_toolbar_tokens_func
+from .clistyle import style_factory, style_factory_output
+from .sqlexecute import SQLExecute
+from .clibuffer import cli_is_multiline
+from .completion_refresher import CompletionRefresher
+from .config import config_location, ensure_dir_exists, get_config
+from .key_bindings import cli_bindings
+from .encodingutils import utf8tounicode, text_type
+from .lexer import LiteCliLexer
+from .__init__ import __version__
+from .packages.filepaths import dir_path_exists
+
+import itertools
+
+click.disable_unicode_literals_warning = True
+
+# Query tuples are used for maintaining history
+Query = namedtuple("Query", ["query", "successful", "mutating"])
+
+PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__))
+
+
+class LiteCli(object):
+
+ default_prompt = "\\d> "
+ max_len_prompt = 45
+
+ def __init__(
+ self,
+ sqlexecute=None,
+ prompt=None,
+ logfile=None,
+ auto_vertical_output=False,
+ warn=None,
+ liteclirc=None,
+ ):
+ self.sqlexecute = sqlexecute
+ self.logfile = logfile
+
+ # Load config.
+ c = self.config = get_config(liteclirc)
+
+ self.multi_line = c["main"].as_bool("multi_line")
+ self.key_bindings = c["main"]["key_bindings"]
+ special.set_favorite_queries(self.config)
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ self.formatter.litecli = self
+ self.syntax_style = c["main"]["syntax_style"]
+ self.less_chatty = c["main"].as_bool("less_chatty")
+ self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
+ self.cli_style = c["colors"]
+ self.output_style = style_factory_output(self.syntax_style, self.cli_style)
+ self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
+ self.autocompletion = c["main"].as_bool("autocompletion")
+ c_dest_warning = c["main"].as_bool("destructive_warning")
+ self.destructive_warning = c_dest_warning if warn is None else warn
+ self.login_path_as_host = c["main"].as_bool("login_path_as_host")
+
+ # read from cli argument or user config file
+ self.auto_vertical_output = auto_vertical_output or c["main"].as_bool(
+ "auto_vertical_output"
+ )
+
+ # audit log
+ if self.logfile is None and "audit_log" in c["main"]:
+ try:
+ self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a")
+ except (IOError, OSError):
+ self.echo(
+ "Error: Unable to open the audit log file. Your queries will not be logged.",
+ err=True,
+ fg="red",
+ )
+ self.logfile = False
+
+ self.completion_refresher = CompletionRefresher()
+
+ self.logger = logging.getLogger(__name__)
+ self.initialize_logging()
+
+ prompt_cnf = self.read_my_cnf_files(["prompt"])["prompt"]
+ self.prompt_format = (
+ prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt
+ )
+ self.prompt_continuation_format = c["main"]["prompt_continuation"]
+ keyword_casing = c["main"].get("keyword_casing", "auto")
+
+ self.query_history = []
+
+ # Initialize completer.
+ self.completer = SQLCompleter(
+ supported_formats=self.formatter.supported_formats,
+ keyword_casing=keyword_casing,
+ )
+ self._completer_lock = threading.Lock()
+
+ # Register custom special commands.
+ self.register_special_commands()
+
+ self.prompt_app = None
+
+ def register_special_commands(self):
+ special.register_special_command(
+ self.change_db,
+ ".open",
+ ".open",
+ "Change to a new database.",
+ aliases=("use", "\\u"),
+ )
+ special.register_special_command(
+ self.refresh_completions,
+ "rehash",
+ "\\#",
+ "Refresh auto-completions.",
+ arg_type=NO_QUERY,
+ aliases=("\\#",),
+ )
+ special.register_special_command(
+ self.change_table_format,
+ ".mode",
+ "\\T",
+ "Change the table format used to output results.",
+ aliases=("tableformat", "\\T"),
+ case_sensitive=True,
+ )
+ special.register_special_command(
+ self.execute_from_file,
+ ".read",
+ "\\. filename",
+ "Execute commands from file.",
+ case_sensitive=True,
+ aliases=("\\.", "source"),
+ )
+ special.register_special_command(
+ self.change_prompt_format,
+ "prompt",
+ "\\R",
+ "Change prompt format.",
+ aliases=("\\R",),
+ case_sensitive=True,
+ )
+
+ def change_table_format(self, arg, **_):
+ try:
+ self.formatter.format_name = arg
+ yield (None, None, None, "Changed table format to {}".format(arg))
+ except ValueError:
+ msg = "Table format {} not recognized. Allowed formats:".format(arg)
+ for table_type in self.formatter.supported_formats:
+ msg += "\n\t{}".format(table_type)
+ yield (None, None, None, msg)
+
+ def change_db(self, arg, **_):
+ if arg is None:
+ self.sqlexecute.connect()
+ else:
+ self.sqlexecute.connect(database=arg)
+
+ self.refresh_completions()
+ yield (
+ None,
+ None,
+ None,
+ 'You are now connected to database "%s"' % (self.sqlexecute.dbname),
+ )
+
+ def execute_from_file(self, arg, **_):
+ if not arg:
+ message = "Missing required argument, filename."
+ return [(None, None, None, message)]
+ try:
+ with open(os.path.expanduser(arg), encoding="utf-8") as f:
+ query = f.read()
+ except IOError as e:
+ return [(None, None, None, str(e))]
+
+ if self.destructive_warning and confirm_destructive_query(query) is False:
+ message = "Wise choice. Command execution stopped."
+ return [(None, None, None, message)]
+
+ return self.sqlexecute.run(query)
+
+ def change_prompt_format(self, arg, **_):
+ """
+ Change the prompt format.
+ """
+ if not arg:
+ message = "Missing required argument, format."
+ return [(None, None, None, message)]
+
+ self.prompt_format = self.get_prompt(arg)
+ return [(None, None, None, "Changed prompt format to %s" % arg)]
+
+ def initialize_logging(self):
+
+ log_file = self.config["main"]["log_file"]
+ if log_file == "default":
+ log_file = config_location() + "log"
+ ensure_dir_exists(log_file)
+
+ log_level = self.config["main"]["log_level"]
+
+ level_map = {
+ "CRITICAL": logging.CRITICAL,
+ "ERROR": logging.ERROR,
+ "WARNING": logging.WARNING,
+ "INFO": logging.INFO,
+ "DEBUG": logging.DEBUG,
+ }
+
+ # Disable logging if value is NONE by switching to a no-op handler
+ # Set log level to a high value so it doesn't even waste cycles getting called.
+ if log_level.upper() == "NONE":
+ handler = logging.NullHandler()
+ log_level = "CRITICAL"
+ elif dir_path_exists(log_file):
+ handler = logging.FileHandler(log_file)
+ else:
+ self.echo(
+ 'Error: Unable to open the log file "{}".'.format(log_file),
+ err=True,
+ fg="red",
+ )
+ return
+
+ formatter = logging.Formatter(
+ "%(asctime)s (%(process)d/%(threadName)s) "
+ "%(name)s %(levelname)s - %(message)s"
+ )
+
+ handler.setFormatter(formatter)
+
+ root_logger = logging.getLogger("litecli")
+ root_logger.addHandler(handler)
+ root_logger.setLevel(level_map[log_level.upper()])
+
+ logging.captureWarnings(True)
+
+ root_logger.debug("Initializing litecli logging.")
+ root_logger.debug("Log file %r.", log_file)
+
+ def read_my_cnf_files(self, keys):
+ """
+ Reads a list of config files and merges them. The last one will win.
+ :param files: list of files to read
+ :param keys: list of keys to retrieve
+ :returns: tuple, with None for missing keys.
+ """
+ cnf = self.config
+
+ sections = ["main"]
+
+ def get(key):
+ result = None
+ for sect in cnf:
+ if sect in sections and key in cnf[sect]:
+ result = cnf[sect][key]
+ return result
+
+ return {x: get(x) for x in keys}
+
+ def connect(self, database=""):
+
+ cnf = {"database": None}
+
+ cnf = self.read_my_cnf_files(cnf.keys())
+
+ # Fall back to config values only if user did not specify a value.
+
+ database = database or cnf["database"]
+
+ # Connect to the database.
+
+ def _connect():
+ self.sqlexecute = SQLExecute(database)
+
+ try:
+ _connect()
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug("Database connection failed: %r.", e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ exit(1)
+
+ def handle_editor_command(self, text):
+ """Editor command is any query that is prefixed or suffixed by a '\e'.
+ The reason for a while loop is because a user might edit a query
+ multiple times. For eg:
+
+ "select * from \e"<enter> to edit it in vim, then come
+ back to the prompt with the edited query "select * from
+ blah where q = 'abc'\e" to edit it again.
+ :param text: Document
+ :return: Document
+
+ """
+
+ while special.editor_command(text):
+ filename = special.get_filename(text)
+ query = special.get_editor_query(text) or self.get_last_query()
+ sql, message = special.open_external_editor(filename, sql=query)
+ if message:
+ # Something went wrong. Raise an exception and bail.
+ raise RuntimeError(message)
+ while True:
+ try:
+ text = self.prompt_app.prompt(default=sql)
+ break
+ except KeyboardInterrupt:
+ sql = ""
+
+ continue
+ return text
+
+ def run_cli(self):
+ iterations = 0
+ sqlexecute = self.sqlexecute
+ logger = self.logger
+ self.configure_pager()
+ self.refresh_completions()
+
+ history_file = config_location() + "history"
+ if dir_path_exists(history_file):
+ history = FileHistory(history_file)
+ else:
+ history = None
+ self.echo(
+ 'Error: Unable to open the history file "{}". '
+ "Your query history will not be saved.".format(history_file),
+ err=True,
+ fg="red",
+ )
+
+ key_bindings = cli_bindings(self)
+
+ if not self.less_chatty:
+ print("Version:", __version__)
+ print("Mail: https://groups.google.com/forum/#!forum/litecli-users")
+ print("GitHub: https://github.com/dbcli/litecli")
+ # print("Home: https://litecli.com")
+
+ def get_message():
+ prompt = self.get_prompt(self.prompt_format)
+ if (
+ self.prompt_format == self.default_prompt
+ and len(prompt) > self.max_len_prompt
+ ):
+ prompt = self.get_prompt("\\d> ")
+ prompt = prompt.replace("\\x1b", "\x1b")
+ return ANSI(prompt)
+
+ def get_continuation(width, line_number, is_soft_wrap):
+ continuation = " " * (width - 1) + " "
+ return [("class:continuation", continuation)]
+
+ def show_suggestion_tip():
+ return iterations < 2
+
+ def one_iteration(text=None):
+ if text is None:
+ try:
+ text = self.prompt_app.prompt()
+ except KeyboardInterrupt:
+ return
+
+ special.set_expanded_output(False)
+
+ try:
+ text = self.handle_editor_command(text)
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ return
+
+ if not text.strip():
+ return
+
+ if self.destructive_warning:
+ destroy = confirm_destructive_query(text)
+ if destroy is None:
+ pass # Query was not destructive. Nothing to do here.
+ elif destroy is True:
+ self.echo("Your call!")
+ else:
+ self.echo("Wise choice!")
+ return
+
+ # Keep track of whether or not the query is mutating. In case
+ # of a multi-statement query, the overall query is considered
+ # mutating if any one of the component statements is mutating
+ mutating = False
+
+ try:
+ logger.debug("sql: %r", text)
+
+ special.write_tee(self.get_prompt(self.prompt_format) + text)
+ if self.logfile:
+ self.logfile.write("\n# %s\n" % datetime.now())
+ self.logfile.write(text)
+ self.logfile.write("\n")
+
+ successful = False
+ start = time()
+ res = sqlexecute.run(text)
+ self.formatter.query = text
+ successful = True
+ result_count = 0
+ for title, cur, headers, status in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+ threshold = 1000
+ if is_select(status) and cur and cur.rowcount > threshold:
+ self.echo(
+ "The result set has more than {} rows.".format(threshold),
+ fg="red",
+ )
+ if not confirm("Do you want to continue?"):
+ self.echo("Aborted!", err=True, fg="red")
+ break
+
+ if self.auto_vertical_output:
+ max_width = self.prompt_app.output.get_size().columns
+ else:
+ max_width = None
+
+ formatted = self.format_output(
+ title, cur, headers, special.is_expanded_output(), max_width
+ )
+
+ t = time() - start
+ try:
+ if result_count > 0:
+ self.echo("")
+ try:
+ self.output(formatted, status)
+ except KeyboardInterrupt:
+ pass
+ self.echo("Time: %0.03fs" % t)
+ except KeyboardInterrupt:
+ pass
+
+ start = time()
+ result_count += 1
+ mutating = mutating or is_mutating(status)
+ special.unset_once_if_written()
+ except EOFError as e:
+ raise e
+ except KeyboardInterrupt:
+ # get last connection id
+ connection_id_to_kill = sqlexecute.connection_id
+ logger.debug("connection id to kill: %r", connection_id_to_kill)
+ # Restart connection to the database
+ sqlexecute.connect()
+ try:
+ for title, cur, headers, status in sqlexecute.run(
+ "kill %s" % connection_id_to_kill
+ ):
+ status_str = str(status).lower()
+ if status_str.find("ok") > -1:
+ logger.debug(
+ "cancelled query, connection id: %r, sql: %r",
+ connection_id_to_kill,
+ text,
+ )
+ self.echo("cancelled query", err=True, fg="red")
+ except Exception as e:
+ self.echo(
+ "Encountered error while cancelling query: {}".format(e),
+ err=True,
+ fg="red",
+ )
+ except NotImplementedError:
+ self.echo("Not Yet Implemented.", fg="yellow")
+ except OperationalError as e:
+ logger.debug("Exception: %r", e)
+ if e.args[0] in (2003, 2006, 2013):
+ logger.debug("Attempting to reconnect.")
+ self.echo("Reconnecting...", fg="yellow")
+ try:
+ sqlexecute.connect()
+ logger.debug("Reconnected successfully.")
+ one_iteration(text)
+ return # OK to just return, cuz the recursion call runs to the end.
+ except OperationalError as e:
+ logger.debug("Reconnect failed. e: %r", e)
+ self.echo(str(e), err=True, fg="red")
+ # If reconnection failed, don't proceed further.
+ return
+ else:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ except Exception as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg="red")
+ else:
+ # Refresh the table names and column names if necessary.
+ if need_completion_refresh(text):
+ self.refresh_completions(reset=need_completion_reset(text))
+ finally:
+ if self.logfile is False:
+ self.echo("Warning: This query was not logged.", err=True, fg="red")
+ query = Query(text, successful, mutating)
+ self.query_history.append(query)
+
+ get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip)
+
+ if self.wider_completion_menu:
+ complete_style = CompleteStyle.MULTI_COLUMN
+ else:
+ complete_style = CompleteStyle.COLUMN
+
+ if not self.autocompletion:
+ complete_style = CompleteStyle.READLINE_LIKE
+
+ with self._completer_lock:
+
+ if self.key_bindings == "vi":
+ editing_mode = EditingMode.VI
+ else:
+ editing_mode = EditingMode.EMACS
+
+ self.prompt_app = PromptSession(
+ lexer=PygmentsLexer(LiteCliLexer),
+ reserve_space_for_menu=self.get_reserved_space(),
+ message=get_message,
+ prompt_continuation=get_continuation,
+ bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
+ complete_style=complete_style,
+ input_processors=[
+ ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(chars="[](){}"),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
+ )
+ ],
+ tempfile_suffix=".sql",
+ completer=DynamicCompleter(lambda: self.completer),
+ history=history,
+ auto_suggest=AutoSuggestFromHistory(),
+ complete_while_typing=True,
+ multiline=cli_is_multiline(self),
+ style=style_factory(self.syntax_style, self.cli_style),
+ include_default_pygments_style=False,
+ key_bindings=key_bindings,
+ enable_open_in_editor=True,
+ enable_system_prompt=True,
+ enable_suspend=True,
+ editing_mode=editing_mode,
+ search_ignore_case=True,
+ )
+
+ try:
+ while True:
+ one_iteration()
+ iterations += 1
+ except EOFError:
+ special.close_tee()
+ if not self.less_chatty:
+ self.echo("Goodbye!")
+
+ def log_output(self, output):
+ """Log the output in the audit log, if it's enabled."""
+ if self.logfile:
+ click.echo(utf8tounicode(output), file=self.logfile)
+
+ def echo(self, s, **kwargs):
+ """Print a message to stdout.
+
+ The message will be logged in the audit log, if enabled.
+
+ All keyword arguments are passed to click.echo().
+
+ """
+ self.log_output(s)
+ click.secho(s, **kwargs)
+
+ def get_output_margin(self, status=None):
+ """Get the output margin (number of rows for the prompt, footer and
+ timing message."""
+ margin = (
+ self.get_reserved_space()
+ + self.get_prompt(self.prompt_format).count("\n")
+ + 2
+ )
+ if status:
+ margin += 1 + status.count("\n")
+
+ return margin
+
+ def output(self, output, status=None):
+ """Output text to stdout or a pager command.
+
+ The status text is not outputted to pager or files.
+
+ The message will be logged in the audit log, if enabled. The
+ message will be written to the tee file, if enabled. The
+ message will be written to the output file, if enabled.
+
+ """
+ if output:
+ size = self.prompt_app.output.get_size()
+
+ margin = self.get_output_margin(status)
+
+ fits = True
+ buf = []
+ output_via_pager = self.explicit_pager and special.is_pager_enabled()
+ for i, line in enumerate(output, 1):
+ self.log_output(line)
+ special.write_tee(line)
+ special.write_once(line)
+
+ if fits or output_via_pager:
+ # buffering
+ buf.append(line)
+ if len(line) > size.columns or i > (size.rows - margin):
+ fits = False
+ if not self.explicit_pager and special.is_pager_enabled():
+ # doesn't fit, use pager
+ output_via_pager = True
+
+ if not output_via_pager:
+ # doesn't fit, flush buffer
+ for line in buf:
+ click.secho(line)
+ buf = []
+ else:
+ click.secho(line)
+
+ if buf:
+ if output_via_pager:
+ # sadly click.echo_via_pager doesn't accept generators
+ click.echo_via_pager("\n".join(buf))
+ else:
+ for line in buf:
+ click.secho(line)
+
+ if status:
+ self.log_output(status)
+ click.secho(status)
+
+ def configure_pager(self):
+ # Provide sane defaults for less if they are empty.
+ if not os.environ.get("LESS"):
+ os.environ["LESS"] = "-RXF"
+
+ cnf = self.read_my_cnf_files(["pager", "skip-pager"])
+ if cnf["pager"]:
+ special.set_pager(cnf["pager"])
+ self.explicit_pager = True
+ else:
+ self.explicit_pager = False
+
+ if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"):
+ special.disable_pager()
+
+ def refresh_completions(self, reset=False):
+ if reset:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.completion_refresher.refresh(
+ self.sqlexecute,
+ self._on_completions_refreshed,
+ {
+ "supported_formats": self.formatter.supported_formats,
+ "keyword_casing": self.completer.keyword_casing,
+ },
+ )
+
+ return [
+ (None, None, None, "Auto-completion refresh started in the background.")
+ ]
+
+ def _on_completions_refreshed(self, new_completer):
+ """Swap the completer object in cli with the newly created completer."""
+ with self._completer_lock:
+ self.completer = new_completer
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None
+ )
+
+ def get_prompt(self, string):
+ self.logger.debug("Getting prompt")
+ sqlexecute = self.sqlexecute
+ now = datetime.now()
+ string = string.replace("\\d", sqlexecute.dbname or "(none)")
+ string = string.replace("\\f", os.path.basename(sqlexecute.dbname or "(none)"))
+ string = string.replace("\\n", "\n")
+ string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
+ string = string.replace("\\m", now.strftime("%M"))
+ string = string.replace("\\P", now.strftime("%p"))
+ string = string.replace("\\R", now.strftime("%H"))
+ string = string.replace("\\r", now.strftime("%I"))
+ string = string.replace("\\s", now.strftime("%S"))
+ string = string.replace("\\_", " ")
+ return string
+
+ def run_query(self, query, new_line=True):
+ """Runs *query*."""
+ results = self.sqlexecute.run(query)
+ for result in results:
+ title, cur, headers, status = result
+ self.formatter.query = query
+ output = self.format_output(title, cur, headers)
+ for line in output:
+ click.echo(line, nl=new_line)
+
+ def format_output(self, title, cur, headers, expanded=False, max_width=None):
+ expanded = expanded or self.formatter.format_name == "vertical"
+ output = []
+
+ output_kwargs = {
+ "dialect": "unix",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "preprocessors": (preprocessors.align_decimals,),
+ "style": self.output_style,
+ }
+
+ if title: # Only print the title if it's not None.
+ output = itertools.chain(output, [title])
+
+ if cur:
+ column_types = None
+ if hasattr(cur, "description"):
+
+ def get_col_type(col):
+ # col_type = FIELD_TYPES.get(col[1], text_type)
+ # return col_type if type(col_type) is type else text_type
+ return text_type
+
+ column_types = [get_col_type(col) for col in cur.description]
+
+ if max_width is not None:
+ cur = list(cur)
+
+ formatted = self.formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical" if expanded else None,
+ column_types=column_types,
+ **output_kwargs
+ )
+
+ if isinstance(formatted, (text_type)):
+ formatted = formatted.splitlines()
+ formatted = iter(formatted)
+
+ first_line = next(formatted)
+ formatted = itertools.chain([first_line], formatted)
+
+ if (
+ not expanded
+ and max_width
+ and headers
+ and cur
+ and len(first_line) > max_width
+ ):
+ formatted = self.formatter.format_output(
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs
+ )
+ if isinstance(formatted, (text_type)):
+ formatted = iter(formatted.splitlines())
+
+ output = itertools.chain(output, formatted)
+
+ return output
+
+ def get_reserved_space(self):
+ """Get the number of lines to reserve for the completion menu."""
+ reserved_space_ratio = 0.45
+ max_reserved_space = 8
+ _, height = shutil.get_terminal_size()
+ return min(int(round(height * reserved_space_ratio)), max_reserved_space)
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+
+@click.command()
+@click.option("-V", "--version", is_flag=True, help="Output litecli's version.")
+@click.option("-D", "--database", "dbname", help="Database to use.")
+@click.option(
+ "-R",
+ "--prompt",
+ "prompt",
+ help='Prompt format (Default: "{0}").'.format(LiteCli.default_prompt),
+)
+@click.option(
+ "-l",
+ "--logfile",
+ type=click.File(mode="a", encoding="utf-8"),
+ help="Log every query and its results to a file.",
+)
+@click.option(
+ "--liteclirc",
+ default=config_location() + "config",
+ help="Location of liteclirc file.",
+ type=click.Path(dir_okay=False),
+)
+@click.option(
+ "--auto-vertical-output",
+ is_flag=True,
+ help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
+)
+@click.option(
+ "-t", "--table", is_flag=True, help="Display batch output in table format."
+)
+@click.option("--csv", is_flag=True, help="Display batch output in CSV format.")
+@click.option(
+ "--warn/--no-warn", default=None, help="Warn before running a destructive query."
+)
+@click.option("-e", "--execute", type=str, help="Execute command and quit.")
+@click.argument("database", default="", nargs=1)
+def cli(
+ database,
+ dbname,
+ version,
+ prompt,
+ logfile,
+ auto_vertical_output,
+ table,
+ csv,
+ warn,
+ execute,
+ liteclirc,
+):
+ """A SQLite terminal client with auto-completion and syntax highlighting.
+
+ \b
+ Examples:
+ - litecli lite_database
+
+ """
+
+ if version:
+ print("Version:", __version__)
+ sys.exit(0)
+
+ litecli = LiteCli(
+ prompt=prompt,
+ logfile=logfile,
+ auto_vertical_output=auto_vertical_output,
+ warn=warn,
+ liteclirc=liteclirc,
+ )
+
+ # Choose which ever one has a valid value.
+ database = database or dbname
+
+ litecli.connect(database)
+
+ litecli.logger.debug("Launch Params: \n" "\tdatabase: %r", database)
+
+ # --execute argument
+ if execute:
+ try:
+ if csv:
+ litecli.formatter.format_name = "csv"
+ elif not table:
+ litecli.formatter.format_name = "tsv"
+
+ litecli.run_query(execute)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+ if sys.stdin.isatty():
+ litecli.run_cli()
+ else:
+ stdin = click.get_text_stream("stdin")
+ stdin_text = stdin.read()
+
+ try:
+ sys.stdin = open("/dev/tty")
+ except (FileNotFoundError, OSError):
+ litecli.logger.warning("Unable to open TTY as stdin.")
+
+ if (
+ litecli.destructive_warning
+ and confirm_destructive_query(stdin_text) is False
+ ):
+ exit(0)
+ try:
+ new_line = True
+
+ if csv:
+ litecli.formatter.format_name = "csv"
+ elif not table:
+ litecli.formatter.format_name = "tsv"
+
+ litecli.run_query(stdin_text, new_line=new_line)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+
+
+def need_completion_refresh(queries):
+ """Determines if the completion needs a refresh by checking if the sql
+ statement is an alter, create, drop or change db."""
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in (
+ "alter",
+ "create",
+ "use",
+ "\\r",
+ "\\u",
+ "connect",
+ "drop",
+ ):
+ return True
+ except Exception:
+ return False
+
+
+def need_completion_reset(queries):
+ """Determines if the statement is a database switch such as 'use' or '\\u'.
+ When a database is changed the existing completions must be reset before we
+ start the completion refresh for the new database.
+ """
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ("use", "\\u"):
+ return True
+ except Exception:
+ return False
+
+
+def is_mutating(status):
+ """Determines if the statement is mutating based on the status."""
+ if not status:
+ return False
+
+ mutating = set(
+ [
+ "insert",
+ "update",
+ "delete",
+ "alter",
+ "create",
+ "drop",
+ "replace",
+ "truncate",
+ "load",
+ ]
+ )
+ return status.split(None, 1)[0].lower() in mutating
+
+
+def is_select(status):
+ """Returns true if the first word in status is 'select'."""
+ if not status:
+ return False
+ return status.split(None, 1)[0].lower() == "select"
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/litecli/packages/__init__.py b/litecli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/litecli/packages/__init__.py
diff --git a/litecli/packages/completion_engine.py b/litecli/packages/completion_engine.py
new file mode 100644
index 0000000..0e2a30f
--- /dev/null
+++ b/litecli/packages/completion_engine.py
@@ -0,0 +1,331 @@
+from __future__ import print_function
+import sys
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from litecli.encodingutils import string_types, text_type
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
+
+ identifier = None
+
+ # here should be removed once sqlparse has been fixed
+ try:
+ # If we've partially typed a word then word_before_cursor won't be an empty
+ # string. In that case we want to remove the partially typed string before
+ # sending it to the sqlparser. Otherwise the last token will always be the
+ # partially typed string which renders the smart completion useless because
+ # it will always return the list of keywords as completion.
+ if word_before_cursor:
+ if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"):
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)])
+
+ # word_before_cursor may include a schema qualification, like
+ # "schema_name.partial_name" or "schema_name.", so parse it
+ # separately
+ p = sqlparse.parse(word_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[0], Identifier):
+ identifier = p.tokens[0]
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+ except (TypeError, AttributeError):
+ return [{"type": "keyword"}]
+
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds the
+ # current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(text_type(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ statement = None
+
+ # Check for special commands and handle those separately
+ if statement:
+ # Be careful here because trivial whitespace is parsed as a statement,
+ # but the statement won't have a first token
+ tok1 = statement.token_first()
+ if tok1 and tok1.value.startswith("."):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("\\"):
+ return suggest_special(text_before_cursor)
+ elif tok1 and tok1.value.startswith("source"):
+ return suggest_special(text_before_cursor)
+ elif text_before_cursor and text_before_cursor.startswith(".open "):
+ return suggest_special(text_before_cursor)
+
+ last_token = statement and statement.token_prev(len(statement.tokens))[1] or ""
+
+ return suggest_based_on_last_token(
+ last_token, text_before_cursor, full_text, identifier
+ )
+
+
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return [{"type": "special"}]
+
+ if cmd in ("\\u", "\\r"):
+ return [{"type": "database"}]
+
+ if cmd in ("\\T"):
+ return [{"type": "table_format"}]
+
+ if cmd in ["\\f", "\\fs", "\\fd"]:
+ return [{"type": "favoritequery"}]
+
+ if cmd in ["\\d", "\\dt", "\\dt+", ".schema", ".indexes"]:
+ return [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+
+ if cmd in ["\\.", "source", ".open", ".read"]:
+ return [{"type": "file_name"}]
+
+ if cmd in [".import"]:
+ # Usage: .import filename table
+ if _expecting_arg_idx(arg, text) == 1:
+ return [{"type": "file_name"}]
+ else:
+ return [{"type": "table", "schema": []}]
+
+ return [{"type": "keyword"}, {"type": "special"}]
+
+
+def _expecting_arg_idx(arg, text):
+ """Return the index of expecting argument.
+
+ >>> _expecting_arg_idx("./da", ".import ./da")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv")
+ 1
+ >>> _expecting_arg_idx("./data.csv", ".import ./data.csv ")
+ 2
+ >>> _expecting_arg_idx("./data.csv t", ".import ./data.csv t")
+ 2
+ """
+ args = arg.split()
+ return len(args) + int(text[-1].isspace())
+
+
+def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
+ if isinstance(token, string_types):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ token_v = token.value.lower()
+
+ is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]])
+
+ if not token:
+ return [{"type": "keyword"}, {"type": "special"}]
+ elif token_v.endswith("("):
+ p = sqlparse.parse(text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token(
+ "where", text_before_cursor, full_text, identifier
+ )
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ idx, prev_tok = where.token_prev(len(where.tokens) - 1)
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == "exists":
+ return [{"type": "keyword"}]
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ idx, prev_tok = p.token_prev(len(p.tokens) - 1)
+ if prev_tok and prev_tok.value and prev_tok.value.lower() == "using":
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = extract_tables(full_text)
+
+ # suggest columns that are present in more than one table
+ return [{"type": "column", "tables": tables, "drop_unique": True}]
+ elif p.token_first().value.lower() == "select":
+ # If the lparen is preceeded by a space chances are we're about to
+ # do a sub-select.
+ if last_word(text_before_cursor, "all_punctuations").startswith("("):
+ return [{"type": "keyword"}]
+ elif p.token_first().value.lower() == "show":
+ return [{"type": "show"}]
+
+ # We're probably in a function argument list
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v in ("set", "order by", "distinct"):
+ return [{"type": "column", "tables": extract_tables(full_text)}]
+ elif token_v == "as":
+ # Don't suggest anything for an alias
+ return []
+ elif token_v in ("show"):
+ return [{"type": "show"}]
+ elif token_v in ("to",):
+ p = sqlparse.parse(text_before_cursor)[0]
+ if p.token_first().value.lower() == "change":
+ return [{"type": "change"}]
+ else:
+ return [{"type": "user"}]
+ elif token_v in ("user", "for"):
+ return [{"type": "user"}]
+ elif token_v in ("select", "where", "having"):
+ # Check for a table alias or schema qualification
+ parent = (identifier and identifier.get_parent_name()) or []
+
+ tables = extract_tables(full_text)
+ if parent:
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ aliases = [alias or table for (schema, table, alias) in tables]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "function", "schema": []},
+ {"type": "alias", "aliases": aliases},
+ {"type": "keyword"},
+ ]
+ elif (token_v.endswith("join") and token.is_keyword) or (
+ token_v
+ in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain")
+ ):
+ schema = (identifier and identifier.get_parent_name()) or []
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = [{"type": "table", "schema": schema}]
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, {"type": "schema"})
+
+ # Only tables can be TRUNCATED, otherwise suggest views
+ if token_v != "truncate":
+ suggest.append({"type": "view", "schema": schema})
+
+ return suggest
+
+ elif token_v in ("table", "view", "function"):
+ # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
+ rel_type = token_v
+ schema = (identifier and identifier.get_parent_name()) or []
+ if schema:
+ return [{"type": rel_type, "schema": schema}]
+ else:
+ return [{"type": "schema"}, {"type": rel_type, "schema": []}]
+ elif token_v == "on":
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+ parent = (identifier and identifier.get_parent_name()) or []
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [
+ {"type": "column", "tables": tables},
+ {"type": "table", "schema": parent},
+ {"type": "view", "schema": parent},
+ {"type": "function", "schema": parent},
+ ]
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = [alias or table for (schema, table, alias) in tables]
+ suggest = [{"type": "alias", "aliases": aliases}]
+
+ # The lists of 'aliases' could be empty if we're trying to complete
+ # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
+ # In that case we just suggest all tables.
+ if not aliases:
+ suggest.append({"type": "table", "schema": parent})
+ return suggest
+
+ elif token_v in ("use", "database", "template", "connect"):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return [{"type": "database"}]
+ elif token_v == "tableformat":
+ return [{"type": "table_format"}]
+ elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]:
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ if prev_keyword:
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier
+ )
+ else:
+ return []
+ else:
+ return [{"type": "keyword"}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (schema and (id == schema + "." + table))
diff --git a/litecli/packages/filepaths.py b/litecli/packages/filepaths.py
new file mode 100644
index 0000000..2f01046
--- /dev/null
+++ b/litecli/packages/filepaths.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8
+
+from __future__ import unicode_literals
+
+from litecli.encodingutils import text_type
+import os
+
+
+def list_path(root_dir):
+ """List directory if exists.
+
+ :param dir: str
+ :return: list
+
+ """
+ res = []
+ if os.path.isdir(root_dir):
+ for name in os.listdir(root_dir):
+ res.append(name)
+ return res
+
+
+def complete_path(curr_dir, last_dir):
+ """Return the path to complete that matches the last entered component.
+
+ If the last entered component is ~, expanded path would not
+ match, so return all of the available paths.
+
+ :param curr_dir: str
+ :param last_dir: str
+ :return: str
+
+ """
+ if not last_dir or curr_dir.startswith(last_dir):
+ return curr_dir
+ elif last_dir == "~":
+ return os.path.join(last_dir, curr_dir)
+
+
+def parse_path(root_dir):
+ """Split path into head and last component for the completer.
+
+ Also return position where last component starts.
+
+ :param root_dir: str path
+ :return: tuple of (string, string, int)
+
+ """
+ base_dir, last_dir, position = "", "", 0
+ if root_dir:
+ base_dir, last_dir = os.path.split(root_dir)
+ position = -len(last_dir) if last_dir else 0
+ return base_dir, last_dir, position
+
+
+def suggest_path(root_dir):
+ """List all files and subdirectories in a directory.
+
+ If the directory is not specified, suggest root directory,
+ user directory, current and parent directory.
+
+ :param root_dir: string: directory to list
+ :return: list
+
+ """
+ if not root_dir:
+ return map(text_type, [os.path.abspath(os.sep), "~", os.curdir, os.pardir])
+
+ if "~" in root_dir:
+ root_dir = text_type(os.path.expanduser(root_dir))
+
+ if not os.path.exists(root_dir):
+ root_dir, _ = os.path.split(root_dir)
+
+ return list_path(root_dir)
+
+
+def dir_path_exists(path):
+ """Check if the directory path exists for a given file.
+
+ For example, for a file /home/user/.cache/litecli/log, check if
+ /home/user/.cache/litecli exists.
+
+ :param str path: The file path.
+ :return: Whether or not the directory path exists.
+
+ """
+ return os.path.exists(os.path.dirname(path))
diff --git a/litecli/packages/parseutils.py b/litecli/packages/parseutils.py
new file mode 100644
index 0000000..3f5ca61
--- /dev/null
+++ b/litecli/packages/parseutils.py
@@ -0,0 +1,227 @@
+from __future__ import print_function
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ "alphanum_underscore": re.compile(r"(\w+)$"),
+ # This matches everything except spaces, parens, colon, and comma
+ "many_punctuations": re.compile(r"([^():,\s]+)$"),
+ # This matches everything except spaces, parens, colon, comma, and period
+ "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+ # This matches everything except a space.
+ "all_punctuations": re.compile("([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+ """
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ """
+
+ if not text: # Empty string
+ return ""
+
+ if text[-1].isspace():
+ return ""
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ""
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in (
+ "SELECT",
+ "INSERT",
+ "UPDATE",
+ "CREATE",
+ "DELETE",
+ ):
+ return True
+ return False
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item, stop_at_punctuation):
+ yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # `return`. So we need to ignore the keyword if the keyword
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif (
+ item.ttype is Keyword
+ and (not item.value.upper() == "FROM")
+ and (not item.value.upper().endswith("JOIN"))
+ ):
+ return
+ else:
+ yield item
+ elif (
+ item.ttype is Keyword or item.ttype is Keyword.DML
+ ) and item.value.upper() in ("COPY", "FROM", "INTO", "UPDATE", "TABLE", "JOIN"):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+ tbl_prefix_seen = True
+ break
+
+
+def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM ) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ except AttributeError:
+ continue
+ if real_name:
+ yield (schema_name, real_name, identifier.get_alias())
+ elif isinstance(item, Identifier):
+ real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
+
+ if real_name:
+ yield (schema_name, real_name, item.get_alias())
+ else:
+ name = item.get_name()
+ yield (None, name, item.get_alias() or name)
+ elif isinstance(item, Function):
+ yield (None, item.get_name(), item.get_name())
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statment.
+
+ Returns a list of (schema, table, alias) tuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return []
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == "insert"
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+ return list(extract_table_identifiers(stream))
+
+
+def find_prev_keyword(sql):
+ """Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ""
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ("AND", "OR", "NOT", "BETWEEN")
+
+ for t in reversed(flattened):
+ if t.value == "(" or (
+ t.is_keyword and (t.value.upper() not in logical_operators)
+ ):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index thows an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = "".join(tok.value for tok in flattened[: idx + 1])
+ return t, text
+
+ return None, ""
+
+
+def query_starts_with(query, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+ """Check if any queries start with any item from *prefixes*."""
+ for query in sqlparse.split(queries):
+ if query and query_starts_with(query, prefixes) is True:
+ return True
+ return False
+
+
+def is_destructive(queries):
+ """Returns if any of the queries in *queries* is destructive."""
+ keywords = ("drop", "shutdown", "delete", "truncate", "alter")
+ return queries_start_with(queries, keywords)
+
+
+if __name__ == "__main__":
+ sql = "select * from (select t. from tabl t"
+ print(extract_tables(sql))
diff --git a/litecli/packages/prompt_utils.py b/litecli/packages/prompt_utils.py
new file mode 100644
index 0000000..d9ad2b6
--- /dev/null
+++ b/litecli/packages/prompt_utils.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+import sys
+import click
+from .parseutils import is_destructive
+
+
+def confirm_destructive_query(queries):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ prompt_text = (
+ "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
+ )
+ if is_destructive(queries) and sys.stdin.isatty():
+ return prompt(prompt_text, type=bool)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/litecli/packages/special/__init__.py b/litecli/packages/special/__init__.py
new file mode 100644
index 0000000..fd2b18c
--- /dev/null
+++ b/litecli/packages/special/__init__.py
@@ -0,0 +1,12 @@
+__all__ = []
+
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+
+from . import dbcommands
+from . import iocommands
diff --git a/litecli/packages/special/dbcommands.py b/litecli/packages/special/dbcommands.py
new file mode 100644
index 0000000..203e1a8
--- /dev/null
+++ b/litecli/packages/special/dbcommands.py
@@ -0,0 +1,290 @@
+from __future__ import unicode_literals, print_function
+import csv
+import logging
+import os
+import sys
+import platform
+import shlex
+from sqlite3 import ProgrammingError
+
+from litecli import __version__
+from litecli.packages.special import iocommands
+from litecli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY, ArgumentMissing
+
+log = logging.getLogger(__name__)
+
+
+@special_command(
+ ".tables",
+ "\\dt",
+ "List tables.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\dt",),
+)
+def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ args = ("{0}%".format(arg),)
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name LIKE ? AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ # if verbose and arg:
+ # query = "SELECT sql FROM sqlite_master WHERE name LIKE ?"
+ # log.debug(query)
+ # cur.execute(query)
+ # status = cur.fetchone()[1]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".schema",
+ ".schema[+] [table]",
+ "The complete schema for the database or a single table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def show_schema(cur, arg=None, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ SELECT sql FROM sqlite_master
+ WHERE name==?
+ ORDER BY tbl_name, type DESC, name
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT sql FROM sqlite_master
+ ORDER BY tbl_name, type DESC, name
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".databases",
+ ".databases",
+ "List databases.",
+ arg_type=RAW_QUERY,
+ case_sensitive=True,
+ aliases=("\\l",),
+)
+def list_databases(cur, **_):
+ query = "PRAGMA database_list"
+ log.debug(query)
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, "")]
+ else:
+ return [(None, None, None, "")]
+
+
+@special_command(
+ ".indexes",
+ ".indexes [tablename]",
+ "List indexes.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\di",),
+)
+def list_indexes(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ args = ("{0}%".format(arg),)
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type = 'index' AND tbl_name LIKE ? AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+ else:
+ args = tuple()
+ query = """
+ SELECT name FROM sqlite_master
+ WHERE type = 'index' AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ log.debug(query)
+ cur.execute(query, args)
+ indexes = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+ return [(None, indexes, headers, status)]
+
+
+@special_command(
+ ".status",
+ "\\s",
+ "Show current settings.",
+ arg_type=RAW_QUERY,
+ aliases=("\\s",),
+ case_sensitive=True,
+)
+def status(cur, **_):
+ # Create output buffers.
+ footer = []
+ footer.append("--------------")
+
+ # Output the litecli client information.
+ implementation = platform.python_implementation()
+ version = platform.python_version()
+ client_info = []
+ client_info.append("litecli {0},".format(__version__))
+ client_info.append("running on {0} {1}".format(implementation, version))
+ footer.append(" ".join(client_info))
+
+ # Build the output that will be displayed as a table.
+ query = "SELECT file from pragma_database_list() where name = 'main';"
+ log.debug(query)
+ cur.execute(query)
+ db = cur.fetchone()[0]
+ if db is None:
+ db = ""
+
+ footer.append("Current database: " + db)
+ if iocommands.is_pager_enabled():
+ if "PAGER" in os.environ:
+ pager = os.environ["PAGER"]
+ else:
+ pager = "System default"
+ else:
+ pager = "stdout"
+ footer.append("Current pager:" + pager)
+
+ footer.append("--------------")
+ return [(None, None, "", "\n".join(footer))]
+
+
+@special_command(
+ ".load",
+ ".load path",
+ "Load an extension library.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def load_extension(cur, arg, **_):
+ args = shlex.split(arg)
+ if len(args) != 1:
+ raise TypeError(".load accepts exactly one path")
+ path = args[0]
+ conn = cur.connection
+ conn.enable_load_extension(True)
+ conn.load_extension(path)
+ return [(None, None, None, "")]
+
+
+@special_command(
+ "describe",
+ "\\d [table]",
+ "Description of a table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+ aliases=("\\d", "describe", "desc"),
+)
+def describe(cur, arg, **_):
+ if arg:
+ args = (arg,)
+ query = """
+ PRAGMA table_info({})
+ """.format(
+ arg
+ )
+ else:
+ raise ArgumentMissing("Table name required.")
+
+ log.debug(query)
+ cur.execute(query)
+ tables = cur.fetchall()
+ status = ""
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, "")]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command(
+ ".import",
+ ".import filename table",
+ "Import data from filename into an existing table",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def import_file(cur, arg=None, **_):
+ def split(s):
+ # this is a modification of shlex.split function, just to make it support '`',
+ # because table name might contain '`' character.
+ lex = shlex.shlex(s, posix=True)
+ lex.whitespace_split = True
+ lex.commenters = ""
+ lex.quotes += "`"
+ return list(lex)
+
+ args = split(arg)
+ log.debug("[arg = %r], [args = %r]", arg, args)
+ if len(args) != 2:
+ raise TypeError("Usage: .import filename table")
+
+ filename, table = args
+ cur.execute('PRAGMA table_info("%s")' % table)
+ ncols = len(cur.fetchall())
+ insert_tmpl = 'INSERT INTO "%s" VALUES (?%s)' % (table, ",?" * (ncols - 1))
+
+ with open(filename, "r") as csvfile:
+ dialect = csv.Sniffer().sniff(csvfile.read(1024))
+ csvfile.seek(0)
+ reader = csv.reader(csvfile, dialect)
+
+ cur.execute("BEGIN")
+ ninserted, nignored = 0, 0
+ for i, row in enumerate(reader):
+ if len(row) != ncols:
+ print(
+ "%s:%d expected %d columns but found %d - ignored"
+ % (filename, i, ncols, len(row)),
+ file=sys.stderr,
+ )
+ nignored += 1
+ continue
+ cur.execute(insert_tmpl, row)
+ ninserted += 1
+ cur.execute("COMMIT")
+
+ status = "Inserted %d rows into %s" % (ninserted, table)
+ if nignored > 0:
+ status += " (%d rows are ignored)" % nignored
+ return [(None, None, None, status)]
diff --git a/litecli/packages/special/favoritequeries.py b/litecli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..7da6fbf
--- /dev/null
+++ b/litecli/packages/special/favoritequeries.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+
+
+class FavoriteQueries(object):
+
+ section_name = "favorite_queries"
+
+ usage = """
+Favorite Queries are a way to save frequently used queries
+with a short name.
+Examples:
+
+ # Save a new favorite query.
+ > \\fs simple select * from abc where a is not Null;
+
+ # List all favorite queries.
+ > \\f
+ ╒════════╤═══════════════════════════════════════╕
+ │ Name │ Query │
+ ╞════════╪═══════════════════════════════════════╡
+ │ simple │ SELECT * FROM abc where a is not NULL │
+ ╘════════╧═══════════════════════════════════════╛
+
+ # Run a favorite query.
+ > \\f simple
+ ╒════════╤════════╕
+ │ a │ b │
+ ╞════════╪════════╡
+ │ 日本語 │ 日本語 │
+ ╘════════╧════════╛
+
+ # Delete a favorite query.
+ > \\fd simple
+ simple: Deleted
+"""
+
+ def __init__(self, config):
+ self.config = config
+
+ def list(self):
+ return self.config.get(self.section_name, [])
+
+ def get(self, name):
+ return self.config.get(self.section_name, {}).get(name, None)
+
+ def save(self, name, query):
+ if self.section_name not in self.config:
+ self.config[self.section_name] = {}
+ self.config[self.section_name][name] = query
+ self.config.write()
+
+ def delete(self, name):
+ try:
+ del self.config[self.section_name][name]
+ except KeyError:
+ return "%s: Not Found." % name
+ self.config.write()
+ return "%s: Deleted" % name
diff --git a/litecli/packages/special/iocommands.py b/litecli/packages/special/iocommands.py
new file mode 100644
index 0000000..43c3577
--- /dev/null
+++ b/litecli/packages/special/iocommands.py
@@ -0,0 +1,478 @@
+from __future__ import unicode_literals
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import sqlparse
+from configobj import ConfigObj
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .utils import handle_cd_command
+from litecli.packages.prompt_utils import confirm_destructive_query
+
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = written_to_once_file = None
+favoritequeries = FavoriteQueries(ConfigObj())
+
+
+@export
+def set_favorite_queries(config):
+ global favoritequeries
+ favoritequeries = FavoriteQueries(config)
+
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+
+@export
+@special_command(
+ "pager",
+ "\\P [command]",
+ "Set PAGER. Print the query results via PAGER.",
+ arg_type=PARSED_QUERY,
+ aliases=("\\P",),
+ case_sensitive=True,
+)
+def set_pager(arg, **_):
+ if arg:
+ os.environ["PAGER"] = arg
+ msg = "PAGER set to %s." % arg
+ set_pager_enabled(True)
+ else:
+ if "PAGER" in os.environ:
+ msg = "PAGER set to %s." % os.environ["PAGER"]
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = "Pager enabled."
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+
+@export
+@special_command(
+ "nopager",
+ "\\n",
+ "Disable pager, print to stdout.",
+ arg_type=NO_QUERY,
+ aliases=("\\n",),
+ case_sensitive=True,
+)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, "Pager disabled.")]
+
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+
+_logger = logging.getLogger(__name__)
+
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith("\\e") or command.strip().startswith("\\e")
+
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith("\\e"):
+ command, _, filename = sql.partition(" ")
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile("(^\\\e|\\\e$)")
+ while pattern.search(sql):
+ sql = pattern.sub("", sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(" ", 1)[0] if filename else None
+
+ sql = sql or ""
+ MARKER = "# Type your query above this line.\n"
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(
+ "{sql}\n\n{marker}".format(sql=sql, marker=MARKER),
+ filename=filename,
+ extension=".sql",
+ )
+
+ if filename:
+ try:
+ with open(filename, encoding="utf-8") as f:
+ query = f.read()
+ except IOError:
+ message = "Error reading file: %s." % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip("\n")
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@special_command(
+ "\\f",
+ "\\f [name [args..]]",
+ "List or execute favorite queries.",
+ arg_type=PARSED_QUERY,
+ case_sensitive=True,
+)
+def execute_favorite_query(cur, arg, verbose=False, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == "":
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(" ")
+ args = shlex.split(arg_str)
+
+ query = favoritequeries.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ elif "?" in query:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql, args)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(";")
+ title = "> %s" % (sql) if verbose else None
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, favoritequeries.get(r)) for r in favoritequeries.list()]
+
+ if not rows:
+ status = "\nNo favorite queries found." + favoritequeries.usage
+ else:
+ status = ""
+ return [("", rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ shell_subst_var = "$" + str(idx + 1)
+ question_subst_var = "?"
+ if shell_subst_var in query:
+ query = query.replace(shell_subst_var, val)
+ elif question_subst_var in query:
+ query = query.replace(question_subst_var, val, 1)
+ else:
+ return [
+ None,
+ "Too many arguments.\nQuery does not have enough place holders to substitute.\n"
+ + query,
+ ]
+
+ match = re.search("\\?|\\$\d+", query)
+ if match:
+ return [
+ None,
+ "missing substitution for " + match.group(0) + " in query:\n " + query,
+ ]
+
+ return [query, None]
+
+
+@special_command("\\fs", "\\fs name query", "Save a favorite query.")
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = "Syntax: \\fs name query.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(" ")
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None, usage + "Err: Both name and query are required.")]
+
+ favoritequeries.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command("\\fd", "\\fd [name]", "Delete a favorite query.")
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query."""
+ usage = "Syntax: \\fd name.\n\n" + favoritequeries.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = favoritequeries.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command("system", "system [command]", "Execute a system shell commmand.")
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith("cd"):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, "")]
+
+ args = arg.split(" ")
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, "OSError: %s" % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith("-o "):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = "a"
+ filename = arg
+
+ if not filename:
+ raise TypeError("You must provide a filename.")
+
+ return {"file": os.path.expanduser(filename), "mode": mode}
+
+
+@special_command(
+ "tee",
+ "tee [-o] filename",
+ "Append all results to an output file (overwrite using -o).",
+)
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command("notee", "notee", "Stop writing results to an output file.")
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo("\n", file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command(
+ ".once",
+ "\\o [-o] filename",
+ "Append next result to an output file (overwrite using -o).",
+ aliases=("\\o", "\\once"),
+)
+def set_once(arg, **_):
+ global once_file
+
+ once_file = parseargfile(arg)
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ try:
+ f = open(**once_file)
+ except (IOError, OSError) as e:
+ once_file = None
+ raise OSError(
+ "Cannot write to file '{}': {}".format(e.filename, e.strerror)
+ )
+
+ with f:
+ click.echo(output, file=f, nl=False)
+ click.echo("\n", file=f, nl=False)
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file
+ if written_to_once_file:
+ once_file = None
+
+
+@special_command(
+ "watch",
+ "watch [seconds] [-c] query",
+ "Executes the query every [seconds] seconds (by default 5).",
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ raise StopIteration
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ raise StopIteration
+ (current_arg, _, arg) = arg.partition(" ")
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == "-c":
+ clear_screen = True
+ continue
+ statement = "{0!s} {1!s}".format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ raise StopIteration
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs["cur"]
+ sql_list = [
+ (sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager its activated after every yield,
+ # so we disable it in every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character in its own line, which prevents
+ # to print a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ raise StopIteration
+ finally:
+ set_pager_enabled(old_pager_enabled)
diff --git a/litecli/packages/special/main.py b/litecli/packages/special/main.py
new file mode 100644
index 0000000..3dd0e77
--- /dev/null
+++ b/litecli/packages/special/main.py
@@ -0,0 +1,160 @@
+from __future__ import unicode_literals
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple(
+ "SpecialCommand",
+ [
+ "handler",
+ "command",
+ "shortcut",
+ "description",
+ "arg_type",
+ "hidden",
+ "case_sensitive",
+ ],
+)
+
+COMMANDS = {}
+
+
+@export
+class ArgumentMissing(Exception):
+ pass
+
+
+@export
+class CommandNotFound(Exception):
+ pass
+
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(" ")
+ verbose = "+" in command
+ command = command.strip().replace("+", "")
+ return (command, verbose, arg.strip())
+
+
+@export
+def special_command(
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ def wrapper(wrapped):
+ register_special_command(
+ wrapped,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ hidden,
+ case_sensitive,
+ aliases,
+ )
+ return wrapped
+
+ return wrapper
+
+
+@export
+def register_special_command(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type=PARSED_QUERY,
+ hidden=False,
+ case_sensitive=False,
+ aliases=(),
+):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(
+ handler, command, shortcut, description, arg_type, hidden, case_sensitive
+ )
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(
+ handler,
+ command,
+ shortcut,
+ description,
+ arg_type,
+ case_sensitive=case_sensitive,
+ hidden=True,
+ )
+
+
+@export
+def execute(cur, sql):
+ """Execute a special command and return the results. If the special command
+ is not supported a KeyError will be raised.
+ """
+ command, verbose, arg = parse_special_command(sql)
+
+ if (command not in COMMANDS) and (command.lower() not in COMMANDS):
+ raise CommandNotFound
+
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise CommandNotFound("Command not found: %s" % command)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+
+@special_command(
+ "help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?")
+)
+def show_help(): # All the parameters are ignored.
+ headers = ["Command", "Shortcut", "Description"]
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.command, value.shortcut, value.description))
+ return [(None, result, headers, None)]
+
+
+@special_command(".exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q", "exit"))
+@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY)
+def quit(*_args):
+ raise EOFError
+
+
+@special_command(
+ "\\e",
+ "\\e",
+ "Edit command with editor (uses $EDITOR).",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+@special_command(
+ "\\G",
+ "\\G",
+ "Display current query results vertically.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+)
+def stub():
+ raise NotImplementedError
diff --git a/litecli/packages/special/utils.py b/litecli/packages/special/utils.py
new file mode 100644
index 0000000..eed9306
--- /dev/null
+++ b/litecli/packages/special/utils.py
@@ -0,0 +1,48 @@
+import os
+import subprocess
+
+
+def handle_cd_command(arg):
+ """Handles a `cd` shell command by calling python's os.chdir."""
+ CD_CMD = "cd"
+ tokens = arg.split(CD_CMD + " ")
+ directory = tokens[-1] if len(tokens) > 1 else None
+ if not directory:
+ return False, "No folder name was provided."
+ try:
+ os.chdir(directory)
+ subprocess.call(["pwd"])
+ return True, None
+ except OSError as e:
+ return False, e.strerror
+
+
+def format_uptime(uptime_in_seconds):
+ """Format number of seconds into human-readable string.
+
+ :param uptime_in_seconds: The server uptime in seconds.
+ :returns: A human-readable string representing the uptime.
+
+ >>> uptime = format_uptime('56892')
+ >>> print(uptime)
+ 15 hours 48 min 12 sec
+ """
+
+ m, s = divmod(int(uptime_in_seconds), 60)
+ h, m = divmod(m, 60)
+ d, h = divmod(h, 24)
+
+ uptime_values = []
+
+ for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")):
+ if value == 0 and not uptime_values:
+ # Don't include a value/unit if the unit isn't applicable to
+ # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
+ continue
+ elif value == 1 and unit.endswith("s"):
+ # Remove the "s" if the unit is singular.
+ unit = unit[:-1]
+ uptime_values.append("{0} {1}".format(value, unit))
+
+ uptime = " ".join(uptime_values)
+ return uptime
diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py
new file mode 100644
index 0000000..64ca352
--- /dev/null
+++ b/litecli/sqlcompleter.py
@@ -0,0 +1,612 @@
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+from re import compile, escape
+from collections import Counter
+
+from prompt_toolkit.completion import Completer, Completion
+
+from .packages.completion_engine import suggest_type
+from .packages.parseutils import last_word
+from .packages.special.iocommands import favoritequeries
+from .packages.filepaths import parse_path, complete_path, suggest_path
+
+_logger = logging.getLogger(__name__)
+
+
+class SQLCompleter(Completer):
+ keywords = [
+ "ABORT",
+ "ACTION",
+ "ADD",
+ "AFTER",
+ "ALL",
+ "ALTER",
+ "ANALYZE",
+ "AND",
+ "AS",
+ "ASC",
+ "ATTACH",
+ "AUTOINCREMENT",
+ "BEFORE",
+ "BEGIN",
+ "BETWEEN",
+ "BIGINT",
+ "BLOB",
+ "BOOLEAN",
+ "BY",
+ "CASCADE",
+ "CASE",
+ "CAST",
+ "CHARACTER",
+ "CHECK",
+ "CLOB",
+ "COLLATE",
+ "COLUMN",
+ "COMMIT",
+ "CONFLICT",
+ "CONSTRAINT",
+ "CREATE",
+ "CROSS",
+ "CURRENT",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATABASE",
+ "DATE",
+ "DATETIME",
+ "DECIMAL",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DEFERRED",
+ "DELETE",
+ "DETACH",
+ "DISTINCT",
+ "DO",
+ "DOUBLE PRECISION",
+ "DOUBLE",
+ "DROP",
+ "EACH",
+ "ELSE",
+ "END",
+ "ESCAPE",
+ "EXCEPT",
+ "EXCLUSIVE",
+ "EXISTS",
+ "EXPLAIN",
+ "FAIL",
+ "FILTER",
+ "FLOAT",
+ "FOLLOWING",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "FULL",
+ "GLOB",
+ "GROUP",
+ "HAVING",
+ "IF",
+ "IGNORE",
+ "IMMEDIATE",
+ "IN",
+ "INDEX",
+ "INDEXED",
+ "INITIALLY",
+ "INNER",
+ "INSERT",
+ "INSTEAD",
+ "INT",
+ "INT2",
+ "INT8",
+ "INTEGER",
+ "INTERSECT",
+ "INTO",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "KEY",
+ "LEFT",
+ "LIKE",
+ "LIMIT",
+ "MATCH",
+ "MEDIUMINT",
+ "NATIVE CHARACTER",
+ "NATURAL",
+ "NCHAR",
+ "NO",
+ "NOT",
+ "NOTHING",
+ "NULL",
+ "NULLS FIRST",
+ "NULLS LAST",
+ "NUMERIC",
+ "NVARCHAR",
+ "OF",
+ "OFFSET",
+ "ON",
+ "OR",
+ "ORDER BY",
+ "OUTER",
+ "OVER",
+ "PARTITION",
+ "PLAN",
+ "PRAGMA",
+ "PRECEDING",
+ "PRIMARY",
+ "QUERY",
+ "RAISE",
+ "RANGE",
+ "REAL",
+ "RECURSIVE",
+ "REFERENCES",
+ "REGEXP",
+ "REINDEX",
+ "RELEASE",
+ "RENAME",
+ "REPLACE",
+ "RESTRICT",
+ "RIGHT",
+ "ROLLBACK",
+ "ROW",
+ "ROWS",
+ "SAVEPOINT",
+ "SELECT",
+ "SET",
+ "SMALLINT",
+ "TABLE",
+ "TEMP",
+ "TEMPORARY",
+ "TEXT",
+ "THEN",
+ "TINYINT",
+ "TO",
+ "TRANSACTION",
+ "TRIGGER",
+ "UNBOUNDED",
+ "UNION",
+ "UNIQUE",
+ "UNSIGNED BIG INT",
+ "UPDATE",
+ "USING",
+ "VACUUM",
+ "VALUES",
+ "VARCHAR",
+ "VARYING CHARACTER",
+ "VIEW",
+ "VIRTUAL",
+ "WHEN",
+ "WHERE",
+ "WINDOW",
+ "WITH",
+ "WITHOUT",
+ ]
+
+ functions = [
+ "ABS",
+ "AVG",
+ "CHANGES",
+ "CHAR",
+ "COALESCE",
+ "COUNT",
+ "CUME_DIST",
+ "DATE",
+ "DATETIME",
+ "DENSE_RANK",
+ "GLOB",
+ "GROUP_CONCAT",
+ "HEX",
+ "IFNULL",
+ "INSTR",
+ "JSON",
+ "JSON_ARRAY",
+ "JSON_ARRAY_LENGTH",
+ "JSON_EACH",
+ "JSON_EXTRACT",
+ "JSON_GROUP_ARRAY",
+ "JSON_GROUP_OBJECT",
+ "JSON_INSERT",
+ "JSON_OBJECT",
+ "JSON_PATCH",
+ "JSON_QUOTE",
+ "JSON_REMOVE",
+ "JSON_REPLACE",
+ "JSON_SET",
+ "JSON_TREE",
+ "JSON_TYPE",
+ "JSON_VALID",
+ "JULIANDAY",
+ "LAG",
+ "LAST_INSERT_ROWID",
+ "LENGTH",
+ "LIKELIHOOD",
+ "LIKELY",
+ "LOAD_EXTENSION",
+ "LOWER",
+ "LTRIM",
+ "MAX",
+ "MIN",
+ "NTILE",
+ "NULLIF",
+ "PERCENT_RANK",
+ "PRINTF",
+ "QUOTE",
+ "RANDOM",
+ "RANDOMBLOB",
+ "RANK",
+ "REPLACE",
+ "ROUND",
+ "ROW_NUMBER",
+ "RTRIM",
+ "SOUNDEX",
+ "SQLITE_COMPILEOPTION_GET",
+ "SQLITE_COMPILEOPTION_USED",
+ "SQLITE_OFFSET",
+ "SQLITE_SOURCE_ID",
+ "SQLITE_VERSION",
+ "STRFTIME",
+ "SUBSTR",
+ "SUM",
+ "TIME",
+ "TOTAL",
+ "TOTAL_CHANGES",
+ "TRIM",
+ ]
+
+ def __init__(self, supported_formats=(), keyword_casing="auto"):
+ super(self.__class__, self).__init__()
+ self.reserved_words = set()
+ for x in self.keywords:
+ self.reserved_words.update(x.split())
+ self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
+
+ self.special_commands = []
+ self.table_formats = supported_formats
+ if keyword_casing not in ("upper", "lower", "auto"):
+ keyword_casing = "auto"
+ self.keyword_casing = keyword_casing
+ self.reset_completions()
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = "`%s`" % name
+
+ return name
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_special_commands(self, special_commands):
+ # Special commands are not part of all_completions since they can only
+ # be at the beginning of a line.
+ self.special_commands.extend(special_commands)
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schema):
+ if schema is None:
+ return
+ metadata = self.dbmetadata["tables"]
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ metadata[schema] = {}
+ self.all_completions.update(schema)
+
+ def extend_relations(self, data, kind):
+ """Extend metadata for tables or views
+
+ :param data: list of (rel_name, ) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'data' is a generator object. It can throw an exception while being
+ # consumed. This could happen if the user has launched the app without
+ # specifying a database name. This exception must be handled to prevent
+ # crashing.
+ try:
+ data = [self.escaped_names(d) for d in data]
+ except Exception:
+ data = []
+
+ # dbmetadata['tables'][$schema_name][$table_name] should be a list of
+ # column names. Default to an asterisk
+ metadata = self.dbmetadata[kind]
+ for relname in data:
+ try:
+ metadata[self.dbname][relname[0]] = ["*"]
+ except KeyError:
+ _logger.error(
+ "%r %r listed in unrecognized schema %r",
+ kind,
+ relname[0],
+ self.dbname,
+ )
+ self.all_completions.add(relname[0])
+
+ def extend_columns(self, column_data, kind):
+ """Extend column metadata
+
+ :param column_data: list of (rel_name, column_name) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'column_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ column_data = [self.escaped_names(d) for d in column_data]
+ except Exception:
+ column_data = []
+
+ metadata = self.dbmetadata[kind]
+ for relname, column in column_data:
+ metadata[self.dbname][relname].append(column)
+ self.all_completions.add(column)
+
+ def extend_functions(self, func_data):
+ # 'func_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ func_data = [self.escaped_names(d) for d in func_data]
+ except Exception:
+ func_data = []
+
+ # dbmetadata['functions'][$schema_name][$function_name] should return
+ # function metadata.
+ metadata = self.dbmetadata["functions"]
+
+ for func in func_data:
+ metadata[self.dbname][func[0]] = None
+ self.all_completions.add(func[0])
+
+ def set_dbname(self, dbname):
+ self.dbname = dbname
+
+ def reset_completions(self):
+ self.databases = []
+ self.dbname = ""
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ @staticmethod
+ def find_matches(
+ text,
+ collection,
+ start_only=False,
+ fuzzy=True,
+ casing=None,
+ punctuations="most_punctuations",
+ ):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ If `start_only` is True, the text will match an available
+ completion only at the beginning. Otherwise, a completion is
+ considered a match if the text appears anywhere within it.
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+ """
+ last = last_word(text, include=punctuations)
+ text = last.lower()
+
+ completions = []
+
+ if fuzzy:
+ regex = ".*?".join(map(escape, text))
+ pat = compile("(%s)" % regex)
+ for item in sorted(collection):
+ r = pat.search(item.lower())
+ if r:
+ completions.append((len(r.group()), r.start(), item))
+ else:
+ match_end_limit = len(text) if start_only else None
+ for item in sorted(collection):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ completions.append((len(text), match_point, item))
+
+ if casing == "auto":
+ casing = "lower" if last and last[-1].islower() else "upper"
+
+ def apply_case(kw):
+ if casing == "upper":
+ return kw.upper()
+ return kw.lower()
+
+ return (
+ Completion(z if casing is None else apply_case(z), -len(text))
+ for x, y, z in sorted(completions)
+ )
+
+ def get_completions(self, document, complete_event):
+ word_before_cursor = document.get_word_before_cursor(WORD=True)
+ completions = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+
+ _logger.debug("Suggestion type: %r", suggestion["type"])
+
+ if suggestion["type"] == "column":
+ tables = suggestion["tables"]
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables)
+ if suggestion.get("drop_unique"):
+ # drop_unique is used for 'tb11 JOIN tbl2 USING (...'
+ # which should suggest only columns that appear in more than
+ # one table
+ scoped_cols = [
+ col
+ for (col, count) in Counter(scoped_cols).items()
+ if count > 1 and col != "*"
+ ]
+
+ cols = self.find_matches(word_before_cursor, scoped_cols)
+ completions.extend(cols)
+
+ elif suggestion["type"] == "function":
+ # suggest user-defined functions using substring matching
+ funcs = self.populate_schema_objects(suggestion["schema"], "functions")
+ user_funcs = self.find_matches(word_before_cursor, funcs)
+ completions.extend(user_funcs)
+
+ # suggest hardcoded functions using startswith matching only if
+ # there is no schema qualifier. If a schema qualifier is
+ # present it probably denotes a table.
+ # eg: SELECT * FROM users u WHERE u.
+ if not suggestion["schema"]:
+ predefined_funcs = self.find_matches(
+ word_before_cursor,
+ self.functions,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing,
+ )
+ completions.extend(predefined_funcs)
+
+ elif suggestion["type"] == "table":
+ tables = self.populate_schema_objects(suggestion["schema"], "tables")
+ tables = self.find_matches(word_before_cursor, tables)
+ completions.extend(tables)
+
+ elif suggestion["type"] == "view":
+ views = self.populate_schema_objects(suggestion["schema"], "views")
+ views = self.find_matches(word_before_cursor, views)
+ completions.extend(views)
+
+ elif suggestion["type"] == "alias":
+ aliases = suggestion["aliases"]
+ aliases = self.find_matches(word_before_cursor, aliases)
+ completions.extend(aliases)
+
+ elif suggestion["type"] == "database":
+ dbs = self.find_matches(word_before_cursor, self.databases)
+ completions.extend(dbs)
+
+ elif suggestion["type"] == "keyword":
+ keywords = self.find_matches(
+ word_before_cursor,
+ self.keywords,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing,
+ punctuations="many_punctuations",
+ )
+ completions.extend(keywords)
+
+ elif suggestion["type"] == "special":
+ special = self.find_matches(
+ word_before_cursor,
+ self.special_commands,
+ start_only=True,
+ fuzzy=False,
+ punctuations="many_punctuations",
+ )
+ completions.extend(special)
+ elif suggestion["type"] == "favoritequery":
+ queries = self.find_matches(
+ word_before_cursor,
+ favoritequeries.list(),
+ start_only=False,
+ fuzzy=True,
+ )
+ completions.extend(queries)
+ elif suggestion["type"] == "table_format":
+ formats = self.find_matches(
+ word_before_cursor, self.table_formats, start_only=True, fuzzy=False
+ )
+ completions.extend(formats)
+ elif suggestion["type"] == "file_name":
+ file_names = self.find_files(word_before_cursor)
+ completions.extend(file_names)
+
+ return completions
+
+ def find_files(self, word):
+ """Yield matching directory or file names.
+
+ :param word:
+ :return: iterable
+
+ """
+ base_path, last_path, position = parse_path(word)
+ paths = suggest_path(word)
+ for name in sorted(paths):
+ suggestion = complete_path(name, last_path)
+ if suggestion:
+ yield Completion(suggestion, position)
+
+ def populate_scoped_cols(self, scoped_tbls):
+ """Find all columns in a set of scoped_tables
+ :param scoped_tbls: list of (schema, table, alias) tuples
+ :return: list of column names
+ """
+ columns = []
+ meta = self.dbmetadata
+
+ for tbl in scoped_tbls:
+ # A fully qualified schema.relname reference or default_schema
+ # DO NOT escape schema names.
+ schema = tbl[0] or self.dbname
+ relname = tbl[1]
+ escaped_relname = self.escape_name(tbl[1])
+
+ # We don't know if schema.relname is a table or view. Since
+ # tables and views cannot share the same name, we can check one
+ # at a time
+ try:
+ columns.extend(meta["tables"][schema][relname])
+
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ try:
+ columns.extend(meta["tables"][schema][escaped_relname])
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ pass
+
+ try:
+ columns.extend(meta["views"][schema][relname])
+ except KeyError:
+ pass
+
+ return columns
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns list of tables or functions for a (optional) schema"""
+ metadata = self.dbmetadata[obj_type]
+ schema = schema or self.dbname
+
+ try:
+ objects = metadata[schema].keys()
+ except KeyError:
+ # schema doesn't exist
+ objects = []
+
+ return objects
diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py
new file mode 100644
index 0000000..3f78d49
--- /dev/null
+++ b/litecli/sqlexecute.py
@@ -0,0 +1,220 @@
+import logging
+import sqlite3
+import uuid
+from contextlib import closing
+from sqlite3 import OperationalError
+
+import sqlparse
+import os.path
+
+from .packages import special
+
+_logger = logging.getLogger(__name__)
+
+# FIELD_TYPES = decoders.copy()
+# FIELD_TYPES.update({
+# FIELD_TYPE.NULL: type(None)
+# })
+
+
+class SQLExecute(object):
+
+ databases_query = """
+ PRAGMA database_list
+ """
+
+ tables_query = """
+ SELECT name
+ FROM sqlite_master
+ WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ table_columns_query = """
+ SELECT m.name as tableName, p.name as columnName
+ FROM sqlite_master m
+ LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name
+ WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%'
+ ORDER BY tableName, columnName
+ """
+
+ indexes_query = """
+ SELECT name
+ FROM sqlite_master
+ WHERE type = 'index' AND name NOT LIKE 'sqlite_%'
+ ORDER BY 1
+ """
+
+ functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
+ WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
+
+ def __init__(self, database):
+ self.dbname = database
+ self._server_type = None
+ self.connection_id = None
+ self.conn = None
+ if not database:
+ _logger.debug("Database is not specified. Skip connection.")
+ return
+ self.connect()
+
+ def connect(self, database=None):
+ db = database or self.dbname
+ _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database)
+
+ db_name = os.path.expanduser(db)
+ db_dir_name = os.path.dirname(os.path.abspath(db_name))
+ if not os.path.exists(db_dir_name):
+ raise Exception("Path does not exist: {}".format(db_dir_name))
+
+ conn = sqlite3.connect(database=db_name, isolation_level=None)
+ conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace")
+ if self.conn:
+ self.conn.close()
+
+ self.conn = conn
+ # Update them after the connection is made to ensure that it was a
+ # successful connection.
+ self.dbname = db
+ # retrieve connection id
+ self.reset_connection_id()
+
+ def run(self, statement):
+ """Execute the sql in the database and return the results. The results
+ are a list of tuples. Each tuple has 4 values
+ (title, rows, headers, status).
+ """
+ # Remove spaces and EOL
+ statement = statement.strip()
+ if not statement: # Empty string
+ yield (None, None, None, None)
+
+ # Split the sql into separate queries and run each one.
+ # Unless it's saving a favorite query, in which case we
+ # want to save them all together.
+ if statement.startswith("\\fs"):
+ components = [statement]
+ else:
+ components = sqlparse.split(statement)
+
+ for sql in components:
+ # Remove spaces, eol and semi-colons.
+ sql = sql.rstrip(";")
+
+ # \G is treated specially since we have to set the expanded output.
+ if sql.endswith("\\G"):
+ special.set_expanded_output(True)
+ sql = sql[:-2].strip()
+
+ if not self.conn and not (
+ sql.startswith(".open")
+ or sql.lower().startswith("use")
+ or sql.startswith("\\u")
+ or sql.startswith("\\?")
+ or sql.startswith("\\q")
+ or sql.startswith("help")
+ or sql.startswith("exit")
+ or sql.startswith("quit")
+ ):
+ _logger.debug(
+ "Not connected to database. Will not run statement: %s.", sql
+ )
+ raise OperationalError("Not connected to database.")
+ # yield ('Not connected to database', None, None, None)
+ # return
+
+ cur = self.conn.cursor() if self.conn else None
+ try: # Special command
+ _logger.debug("Trying a dbspecial command. sql: %r", sql)
+ for result in special.execute(cur, sql):
+ yield result
+ except special.CommandNotFound: # Regular SQL
+ _logger.debug("Regular sql statement. sql: %r", sql)
+ cur.execute(sql)
+ yield self.get_result(cur)
+
+ def get_result(self, cursor):
+ """Get the current result's data from the cursor."""
+ title = headers = None
+
+ # cursor.description is not None for queries that return result sets,
+ # e.g. SELECT.
+ if cursor.description is not None:
+ headers = [x[0] for x in cursor.description]
+ status = "{0} row{1} in set"
+ cursor = list(cursor)
+ rowcount = len(cursor)
+ else:
+ _logger.debug("No rows in result.")
+ status = "Query OK, {0} row{1} affected"
+ rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount
+ cursor = None
+
+ status = status.format(rowcount, "" if rowcount == 1 else "s")
+
+ return (title, cursor, headers, status)
+
+ def tables(self):
+ """Yields table names"""
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Tables Query. sql: %r", self.tables_query)
+ cur.execute(self.tables_query)
+ for row in cur:
+ yield row
+
+ def table_columns(self):
+ """Yields column names"""
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Columns Query. sql: %r", self.table_columns_query)
+ cur.execute(self.table_columns_query)
+ for row in cur:
+ yield row
+
+ def databases(self):
+ if not self.conn:
+ return
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Databases Query. sql: %r", self.databases_query)
+ for row in cur.execute(self.databases_query):
+ yield row[1]
+
+ def functions(self):
+ """Yields tuples of (schema_name, function_name)"""
+
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Functions Query. sql: %r", self.functions_query)
+ cur.execute(self.functions_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def show_candidates(self):
+ with closing(self.conn.cursor()) as cur:
+ _logger.debug("Show Query. sql: %r", self.show_candidates_query)
+ try:
+ cur.execute(self.show_candidates_query)
+ except sqlite3.DatabaseError as e:
+ _logger.error("No show completions due to %r", e)
+ yield ""
+ else:
+ for row in cur:
+ yield (row[0].split(None, 1)[-1],)
+
+ def server_type(self):
+ self._server_type = ("sqlite3", "3")
+ return self._server_type
+
+ def get_connection_id(self):
+ if not self.connection_id:
+ self.reset_connection_id()
+ return self.connection_id
+
+ def reset_connection_id(self):
+ # Remember current connection id
+ _logger.debug("Get current connection id")
+ # res = self.run('select connection_id()')
+ self.connection_id = uuid.uuid4()
+ # for title, cur, headers, status in res:
+ # self.connection_id = cur.fetchone()[0]
+ _logger.debug("Current connection id: %s", self.connection_id)
diff --git a/release.py b/release.py
new file mode 100644
index 0000000..f6beb88
--- /dev/null
+++ b/release.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python
+"""A script to publish a release of litecli to PyPI."""
+
+from __future__ import print_function
+import io
+from optparse import OptionParser
+import re
+import subprocess
+import sys
+
+import click
+
+DEBUG = False
+CONFIRM_STEPS = False
+DRY_RUN = False
+
+
+def skip_step():
+ """
+ Asks for user's response whether to run a step. Default is yes.
+ :return: boolean
+ """
+ global CONFIRM_STEPS
+
+ if CONFIRM_STEPS:
+ return not click.confirm("--- Run this step?", default=True)
+ return False
+
+
+def run_step(*args):
+ """
+ Prints out the command and asks if it should be run.
+ If yes (default), runs it.
+ :param args: list of strings (command and args)
+ """
+ global DRY_RUN
+
+ cmd = args
+ print(" ".join(cmd))
+ if skip_step():
+ print("--- Skipping...")
+ elif DRY_RUN:
+ print("--- Pretending to run...")
+ else:
+ subprocess.check_output(cmd)
+
+
+def version(version_file):
+ _version_re = re.compile(
+ r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)'
+ )
+
+ with io.open(version_file, encoding="utf-8") as f:
+ ver = _version_re.search(f.read()).group("version")
+
+ return ver
+
+
+def commit_for_release(version_file, ver):
+ run_step("git", "reset")
+ run_step("git", "add", version_file)
+ run_step("git", "commit", "--message", "Releasing version {}".format(ver))
+
+
+def create_git_tag(tag_name):
+ run_step("git", "tag", tag_name)
+
+
+def create_distribution_files():
+ run_step("python", "setup.py", "sdist", "bdist_wheel")
+
+
+def upload_distribution_files():
+ run_step("twine", "upload", "dist/*")
+
+
+def push_to_github():
+ run_step("git", "push", "origin", "main")
+
+
+def push_tags_to_github():
+ run_step("git", "push", "--tags", "origin")
+
+
+def checklist(questions):
+ for question in questions:
+ if not click.confirm("--- {}".format(question), default=False):
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ if DEBUG:
+ subprocess.check_output = lambda x: x
+
+ ver = version("litecli/__init__.py")
+ print("Releasing Version:", ver)
+
+ parser = OptionParser()
+ parser.add_option(
+ "-c",
+ "--confirm-steps",
+ action="store_true",
+ dest="confirm_steps",
+ default=False,
+ help=(
+ "Confirm every step. If the step is not " "confirmed, it will be skipped."
+ ),
+ )
+ parser.add_option(
+ "-d",
+ "--dry-run",
+ action="store_true",
+ dest="dry_run",
+ default=False,
+ help="Print out, but not actually run any steps.",
+ )
+
+ popts, pargs = parser.parse_args()
+ CONFIRM_STEPS = popts.confirm_steps
+ DRY_RUN = popts.dry_run
+
+ if not click.confirm("Are you sure?", default=False):
+ sys.exit(1)
+
+ commit_for_release("litecli/__init__.py", ver)
+ create_git_tag("v{}".format(ver))
+ create_distribution_files()
+ push_to_github()
+ push_tags_to_github()
+ upload_distribution_files()
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..c517d59
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,10 @@
+mock
+pytest>=3.6
+pytest-cov
+tox
+behave
+pexpect
+coverage
+codecov
+click
+black \ No newline at end of file
diff --git a/screenshots/litecli.gif b/screenshots/litecli.gif
new file mode 100644
index 0000000..9cfd80c
--- /dev/null
+++ b/screenshots/litecli.gif
Binary files differ
diff --git a/screenshots/litecli.png b/screenshots/litecli.png
new file mode 100644
index 0000000..6ca999e
--- /dev/null
+++ b/screenshots/litecli.png
Binary files differ
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..40eab0a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,18 @@
+[bdist_wheel]
+universal = 1
+
+[tool:pytest]
+addopts = --capture=sys
+ --showlocals
+ --doctest-modules
+ --doctest-ignore-import-errors
+ --ignore=setup.py
+ --ignore=litecli/magic.py
+ --ignore=litecli/packages/parseutils.py
+ --ignore=test/features
+
+[pep8]
+rev = master
+docformatter = True
+diff = True
+error-status = True
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000..0ff4eeb
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import ast
+from io import open
+import re
+from setuptools import setup, find_packages
+
+_version_re = re.compile(r"__version__\s+=\s+(.*)")
+
+with open("litecli/__init__.py", "rb") as f:
+ version = str(
+ ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
+ )
+
+
+def open_file(filename):
+ """Open and read the file *filename*."""
+ with open(filename) as f:
+ return f.read()
+
+
+readme = open_file("README.md")
+
+install_requirements = [
+ "click >= 4.1",
+ "Pygments>=1.6",
+ "prompt_toolkit>=3.0.3,<4.0.0",
+ "sqlparse",
+ "configobj >= 5.0.5",
+ "cli_helpers[styles] >= 2.2.1",
+]
+
+
+setup(
+ name="litecli",
+ author="dbcli",
+ author_email="litecli-users@googlegroups.com",
+ license="BSD",
+ version=version,
+ url="https://github.com/dbcli/litecli",
+ packages=find_packages(),
+ package_data={"litecli": ["liteclirc", "AUTHORS"]},
+ description="CLI for SQLite Databases with auto-completion and syntax "
+ "highlighting.",
+ long_description=readme,
+ long_description_content_type="text/markdown",
+ install_requires=install_requirements,
+ # cmdclass={"test": test, "lint": lint},
+ entry_points={
+ "console_scripts": ["litecli = litecli.main:cli"],
+ "distutils.commands": ["lint = tasks:lint", "test = tasks:test"],
+ },
+ classifiers=[
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: Unix",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: SQL",
+ "Topic :: Database",
+ "Topic :: Database :: Front-Ends",
+ "Topic :: Software Development",
+ "Topic :: Software Development :: Libraries :: Python Modules",
+ ],
+)
diff --git a/tasks.py b/tasks.py
new file mode 100644
index 0000000..1cd4b69
--- /dev/null
+++ b/tasks.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+"""Common development tasks for setup.py to use."""
+
+import re
+import subprocess
+import sys
+
+from setuptools import Command
+from setuptools.command.test import test as TestCommand
+
+
+class BaseCommand(Command, object):
+ """The base command for project tasks."""
+
+ user_options = []
+
+ default_cmd_options = ("verbose", "quiet", "dry_run")
+
+ def __init__(self, *args, **kwargs):
+ super(BaseCommand, self).__init__(*args, **kwargs)
+ self.verbose = False
+
+ def initialize_options(self):
+ """Override the distutils abstract method."""
+ pass
+
+ def finalize_options(self):
+ """Override the distutils abstract method."""
+ # Distutils uses incrementing integers for verbosity.
+ self.verbose = bool(self.verbose)
+
+ def call_and_exit(self, cmd, shell=True):
+ """Run the *cmd* and exit with the proper exit code."""
+ sys.exit(subprocess.call(cmd, shell=shell))
+
+ def call_in_sequence(self, cmds, shell=True):
+ """Run multiple commmands in a row, exiting if one fails."""
+ for cmd in cmds:
+ if subprocess.call(cmd, shell=shell) == 1:
+ sys.exit(1)
+
+ def apply_options(self, cmd, options=()):
+ """Apply command-line options."""
+ for option in self.default_cmd_options + options:
+ cmd = self.apply_option(cmd, option, active=getattr(self, option, False))
+ return cmd
+
+ def apply_option(self, cmd, option, active=True):
+ """Apply a command-line option."""
+ return re.sub(
+ r"{{{}\:(?P<option>[^}}]*)}}".format(option),
+ r"\g<option>" if active else "",
+ cmd,
+ )
+
+
+class lint(BaseCommand):
+ description = "check code using black (and fix violations)"
+
+ user_options = [("fix", "f", "fix the violations in place")]
+
+ def initialize_options(self):
+ """Set the default options."""
+ self.fix = False
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ cmd = "black"
+ if not self.fix:
+ cmd += " --check"
+ cmd += " ."
+ sys.exit(subprocess.call(cmd, shell=True))
+
+
+class test(TestCommand):
+
+ user_options = [("pytest-args=", "a", "Arguments to pass to pytest")]
+
+ def initialize_options(self):
+ TestCommand.initialize_options(self)
+ self.pytest_args = ""
+
+ def run_tests(self):
+ unit_test_errno = subprocess.call(
+ "pytest tests " + self.pytest_args, shell=True
+ )
+ # cli_errno = subprocess.call('behave test/features', shell=True)
+ # sys.exit(unit_test_errno or cli_errno)
+ sys.exit(unit_test_errno)
+
+
+# class test(BaseCommand):
+# """Run the test suites for this project."""
+
+# description = "run the test suite"
+
+# user_options = [
+# ("all", "a", "test against all supported versions of Python"),
+# ("coverage", "c", "measure test coverage"),
+# ]
+
+# unit_test_cmd = (
+# "py.test{quiet: -q}{verbose: -v}{dry_run: --setup-only}"
+# "{coverage: --cov-report= --cov=litecli}"
+# )
+# # cli_test_cmd = 'behave{quiet: -q}{verbose: -v}{dry_run: -d} test/features'
+# test_all_cmd = "tox{verbose: -v}{dry_run: --notest}"
+# coverage_cmd = "coverage combine && coverage report"
+
+# def initialize_options(self):
+# """Set the default options."""
+# self.all = False
+# self.coverage = False
+# super(test, self).initialize_options()
+
+# def run(self):
+# """Run the test suites."""
+# if self.all:
+# cmd = self.apply_options(self.test_all_cmd)
+# self.call_and_exit(cmd)
+# else:
+# cmds = (
+# self.apply_options(self.unit_test_cmd, ("coverage",)),
+# # self.apply_options(self.cli_test_cmd)
+# )
+# if self.coverage:
+# cmds += (self.apply_options(self.coverage_cmd),)
+# self.call_in_sequence(cmds)
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..dce0d7e
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,40 @@
+from __future__ import print_function
+
+import os
+import pytest
+from utils import create_db, db_connection, drop_tables
+import litecli.sqlexecute
+
+
+@pytest.yield_fixture(scope="function")
+def connection():
+ create_db("_test_db")
+ connection = db_connection("_test_db")
+ yield connection
+
+ drop_tables(connection)
+ connection.close()
+ os.remove("_test_db")
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return litecli.sqlexecute.SQLExecute(database="_test_db")
+
+
+@pytest.fixture
+def exception_formatter():
+ return lambda e: str(e)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def temp_config(tmpdir_factory):
+ # this function runs on start of test session.
+ # use temporary directory for config home so user config will not be used
+ os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))
diff --git a/tests/data/import_data.csv b/tests/data/import_data.csv
new file mode 100644
index 0000000..d68d655
--- /dev/null
+++ b/tests/data/import_data.csv
@@ -0,0 +1,2 @@
+t1,11
+t2,22
diff --git a/tests/liteclirc b/tests/liteclirc
new file mode 100644
index 0000000..da9b061
--- /dev/null
+++ b/tests/liteclirc
@@ -0,0 +1,137 @@
+[main]
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+# In Unix/Linux: ~/.config/litecli/log
+# In Windows: %USERPROFILE%\AppData\Local\dbcli\litecli\log
+# %USERPROFILE% is typically C:\Users\{username}
+log_file = default
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.litecli-audit.log
+
+# Default pager.
+# By default '$PAGER' environment variable is used
+# pager = less -SRXF
+
+# Table format. Possible values:
+# ascii, double, github, psql, plain, simple, grid, fancy_grid, pipe, orgtbl,
+# rst, mediawiki, html, latex, latex_booktabs, textile, moinmoin, jira,
+# vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# Autocompletion is on by default. This can be truned off by setting this
+# option to False. Pressing tab will still trigger completion.
+autocompletion = True
+
+# litecli prompt
+# \D - The full current date
+# \d - Database name
+# \f - File basename of the "main" database
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \R - The current time, in 24-hour military time (0-23)
+# \r - The current time, standard 12-hour time (1-12)
+# \s - Seconds of the current time
+# \x1b[...m - insert ANSI escape sequence
+prompt = "\t :\d> "
+prompt_continuation = "-> "
+
+# Show/hide the informational toolbar with function keymap at the footer.
+show_bottom_toolbar = True
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
+# disabled pager on startup
+enable_pager = True
+[colors]
+completion-menu.completion.current = "bg:#ffffff #000000"
+completion-menu.completion = "bg:#008888 #ffffff"
+completion-menu.meta.completion.current = "bg:#44aaaa #000000"
+completion-menu.meta.completion = "bg:#448888 #ffffff"
+completion-menu.multi-column-meta = "bg:#aaffff #000000"
+scrollbar.arrow = "bg:#003333"
+scrollbar = "bg:#00aaaa"
+selected = "#ffffff bg:#6666aa"
+search = "#ffffff bg:#4444aa"
+search.current = "#ffffff bg:#44aa44"
+bottom-toolbar = "bg:#222222 #aaaaaa"
+bottom-toolbar.off = "bg:#222222 #888888"
+bottom-toolbar.on = "bg:#222222 #ffffff"
+search-toolbar = noinherit bold
+search-toolbar.text = nobold
+system-toolbar = noinherit bold
+arg-toolbar = noinherit bold
+arg-toolbar.text = nobold
+bottom-toolbar.transaction.valid = "bg:#222222 #00ff5f bold"
+bottom-toolbar.transaction.failed = "bg:#222222 #ff005f bold"
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+Token.Menu.Completions.Completion.Current = "bg:#00aaaa #000000"
+Token.Menu.Completions.Completion = "bg:#008888 #ffffff"
+Token.Menu.Completions.MultiColumnMeta = "bg:#aaffff #000000"
+Token.Menu.Completions.ProgressButton = "bg:#003333"
+Token.Menu.Completions.ProgressBar = "bg:#00aaaa"
+Token.Output.Header = bold
+Token.Output.OddRow = ""
+Token.Output.EvenRow = ""
+Token.SelectedText = "#ffffff bg:#6666aa"
+Token.SearchMatch = "#ffffff bg:#4444aa"
+Token.SearchMatch.Current = "#ffffff bg:#44aa44"
+Token.Toolbar = "bg:#222222 #aaaaaa"
+Token.Toolbar.Off = "bg:#222222 #888888"
+Token.Toolbar.On = "bg:#222222 #ffffff"
+Token.Toolbar.Search = noinherit bold
+Token.Toolbar.Search.Text = nobold
+Token.Toolbar.System = noinherit bold
+Token.Toolbar.Arg = noinherit bold
+Token.Toolbar.Arg.Text = nobold
+[favorite_queries]
+q_param = select * from test where name=?
+sh_param = select * from test where id=$1
diff --git a/tests/test.txt b/tests/test.txt
new file mode 100644
index 0000000..1fa4cf0
--- /dev/null
+++ b/tests/test.txt
@@ -0,0 +1 @@
+litecli is awesome!
diff --git a/tests/test_clistyle.py b/tests/test_clistyle.py
new file mode 100644
index 0000000..c1177de
--- /dev/null
+++ b/tests/test_clistyle.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+"""Test the litecli.clistyle module."""
+import pytest
+
+from pygments.style import Style
+from pygments.token import Token
+
+from litecli.clistyle import style_factory
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory():
+ """Test that a Pygments Style class is created."""
+ header = "bold underline #ansired"
+ cli_style = {"Token.Output.Header": header}
+ style = style_factory("default", cli_style)
+
+ assert isinstance(style(), Style)
+ assert Token.Output.Header in style.styles
+ assert header == style.styles[Token.Output.Header]
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory_unknown_name():
+ """Test that an unrecognized name will not throw an error."""
+ style = style_factory("foobar", {})
+
+ assert isinstance(style(), Style)
diff --git a/tests/test_completion_engine.py b/tests/test_completion_engine.py
new file mode 100644
index 0000000..760f275
--- /dev/null
+++ b/tests/test_completion_engine.py
@@ -0,0 +1,657 @@
+from litecli.packages.completion_engine import suggest_type
+import pytest
+
+
+def sorted_dicts(dicts):
+ """input is a list of dicts."""
+ return sorted(tuple(x.items()) for x in dicts)
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [("sch", "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_order_by_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type(
+ "SELECT * FROM sch.tabl ORDER BY ", "SELECT * FROM sch.tabl ORDER BY "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [("sch", "tabl", None)]},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE (",
+ "SELECT * FROM tabl WHERE foo = ",
+ "SELECT * FROM tabl WHERE bar OR ",
+ "SELECT * FROM tabl WHERE foo = 1 AND ",
+ "SELECT * FROM tabl WHERE (bar > 10 AND ",
+ "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
+ "SELECT * FROM tabl WHERE 10 < ",
+ "SELECT * FROM tabl WHERE foo BETWEEN ",
+ "SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
+ ],
+)
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "],
+)
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = "SELECT * FROM tabl WHERE foo = ANY("
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_lparen_suggests_cols():
+ suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_operand_inside_function_suggests_cols1():
+ suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ")
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_operand_inside_function_suggests_cols2():
+ suggestion = suggest_type(
+ "SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + "
+ )
+ assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}]
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type("SELECT ", "SELECT ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": []},
+ {"type": "column", "tables": []},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM ",
+ "INSERT INTO ",
+ "COPY ",
+ "UPDATE ",
+ "DESCRIBE ",
+ "DESC ",
+ "EXPLAIN ",
+ "SELECT * FROM foo JOIN ",
+ ],
+)
+def test_expression_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM sch.",
+ "INSERT INTO sch.",
+ "COPY sch.",
+ "UPDATE sch.",
+ "DESCRIBE sch.",
+ "DESC sch.",
+ "EXPLAIN sch.",
+ "SELECT * FROM foo JOIN sch.",
+ ],
+)
+def test_expression_suggests_qualified_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}]
+ )
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": []}, {"type": "schema"}]
+ )
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "table", "schema": "sch"}]
+ )
+
+
+def test_distinct_suggests_cols():
+ suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ")
+ assert suggestions == [{"type": "column", "tables": []}]
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tbl"]},
+ {"type": "column", "tables": [(None, "tbl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_insert_into_lparen_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
+ assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}]
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type(
+ "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n"
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["tabl"]},
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", None)]},
+ {"type": "table", "schema": "tabl"},
+ {"type": "view", "schema": "tabl"},
+ {"type": "function", "schema": "tabl"},
+ ]
+ )
+
+
+def test_dot_suggests_cols_of_an_alias():
+ suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": "t1"},
+ {"type": "view", "schema": "t1"},
+ {"type": "column", "tables": [(None, "tabl1", "t1")]},
+ {"type": "function", "schema": "t1"},
+ ]
+ )
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type(
+ "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl2", "t2")]},
+ {"type": "table", "schema": "t2"},
+ {"type": "view", "schema": "t2"},
+ {"type": "function", "schema": "t2"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (",
+ "SELECT * FROM foo WHERE EXISTS (",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (",
+ "SELECT 1 AS",
+ ],
+)
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{"type": "keyword"}]
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (S",
+ "SELECT * FROM foo WHERE EXISTS (S",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
+ ],
+)
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{"type": "keyword"}]
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
+ suggestions = suggest_type(q, q)
+ assert suggestions == [
+ {"type": "column", "tables": [(None, "foo", "f")]},
+ {"type": "table", "schema": "f"},
+ {"type": "view", "schema": "f"},
+ {"type": "function", "schema": "f"},
+ ]
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (SELECT * FROM ",
+ "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
+ ],
+)
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["abc"]},
+ {"type": "column", "tables": [(None, "abc", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "abc", None)]},
+ {"type": "function", "schema": []},
+ ]
+ )
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "tabl", "t")]},
+ {"type": "table", "schema": "t"},
+ {"type": "view", "schema": "t"},
+ {"type": "function", "schema": "t"},
+ ]
+ )
+
+
+@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
+@pytest.mark.parametrize("tbl_alias", ["", "foo"])
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
+ suggestion = suggest_type(text, text)
+ assert sorted_dicts(suggestion) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
+ ],
+)
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "abc", "a")]},
+ {"type": "table", "schema": "a"},
+ {"type": "view", "schema": "a"},
+ {"type": "function", "schema": "a"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
+ ],
+)
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "column", "tables": [(None, "def", "d")]},
+ {"type": "table", "schema": "d"},
+ {"type": "view", "schema": "d"},
+ {"type": "function", "schema": "d"},
+ ]
+ )
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
+ ],
+)
+def test_on_suggests_aliases(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
+ ],
+)
+def test_on_suggests_tables(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on a.id = ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
+ ],
+)
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}]
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on ",
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
+ ],
+)
+def test_on_suggests_tables_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}]
+
+
+@pytest.mark.parametrize("col_list", ["", "col1, "])
+def test_join_using_suggests_common_columns(col_list):
+ text = "select * from abc inner join def using (" + col_list
+ assert suggest_type(text, text) == [
+ {
+ "type": "column",
+ "tables": [(None, "abc", None), (None, "def", None)],
+ "drop_unique": True,
+ }
+ ]
+
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ", "select * from a; select * from "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type(
+ "select * from a; select from b", "select * from a; select "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type(
+ "select * from; select * from ", "select * from; select * from "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type("select * from ; select * from b", "select * from ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type("select from a; select * from b", "select ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["a"]},
+ {"type": "column", "tables": [(None, "a", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ; select * from c",
+ "select * from a; select * from ",
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+ suggestions = suggest_type(
+ "select * from a; select from b; select * from c", "select * from a; select "
+ )
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "alias", "aliases": ["b"]},
+ {"type": "column", "tables": [(None, "b", None)]},
+ {"type": "function", "schema": []},
+ {"type": "keyword"},
+ ]
+ )
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type(
+ "create database foo with template ", "create database foo with template "
+ )
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
+
+
+@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"])
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [{"type": "keyword"}, {"type": "special"}]
+ )
+
+
+def test_specials_not_included_after_initial_token():
+ suggestions = suggest_type("create table foo (dt d", "create table foo (dt d")
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}])
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = "DROP TABLE schema_name.table_name"
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{"type": "table", "schema": "schema_name"}]
+
+
+@pytest.mark.parametrize("text", [",", " ,", "sel ,"])
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_cross_join():
+ text = "select * from v1 cross join v2 JOIN v1.id, "
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "\\. ",
+ "select 1; \\. ",
+ "select 1;\\. ",
+ "select 1 ; \\. ",
+ "source ",
+ "truncate table test; source ",
+ "truncate table test ; source ",
+ "truncate table test;source ",
+ ],
+)
+def test_source_is_file(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{"type": "file_name"}]
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
new file mode 100644
index 0000000..620a364
--- /dev/null
+++ b/tests/test_completion_refresher.py
@@ -0,0 +1,94 @@
+import time
+import pytest
+from mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from litecli.completion_refresher import CompletionRefresher
+
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """Refresher object should contain a few handlers.
+
+ :param refresher:
+ :return:
+
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = [
+ "databases",
+ "schemata",
+ "tables",
+ "functions",
+ "special_commands",
+ ]
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ sqlexecute = Mock()
+
+ with patch.object(refresher, "_bg_refresh") as bg_refresh:
+ actual = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == "Auto-completion refresh started in the background."
+ bg_refresh.assert_called_with(sqlexecute, callbacks, {})
+
+
+def test_refresh_called_twice(refresher):
+ """If refresh is called a second time, it should be restarted.
+
+ :param refresher:
+ :return:
+
+ """
+ callbacks = Mock()
+
+ sqlexecute = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == "Auto-completion refresh started in the background."
+
+ actual2 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == "Auto-completion refresh restarted."
+
+
+def test_refresh_with_callbacks(refresher):
+ """Callbacks must be called.
+
+ :param refresher:
+
+ """
+ callbacks = [Mock()]
+ sqlexecute_class = Mock()
+ sqlexecute = Mock()
+
+ with patch("litecli.completion_refresher.SQLExecute", sqlexecute_class):
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
diff --git a/tests/test_dbspecial.py b/tests/test_dbspecial.py
new file mode 100644
index 0000000..5128b5b
--- /dev/null
+++ b/tests/test_dbspecial.py
@@ -0,0 +1,76 @@
+from litecli.packages.completion_engine import suggest_type
+from test_completion_engine import sorted_dicts
+from litecli.packages.special.utils import format_uptime
+
+
+def test_import_first_argument():
+ test_cases = [
+ # text, expecting_arg_idx
+ [".import ", 1],
+ [".import ./da", 1],
+ [".import ./data.csv ", 2],
+ [".import ./data.csv t", 2],
+ [".import ./data.csv `t", 2],
+ ['.import ./data.csv "t', 2],
+ ]
+ for text, expecting_arg_idx in test_cases:
+ suggestions = suggest_type(text, text)
+ if expecting_arg_idx == 1:
+ assert suggestions == [{"type": "file_name"}]
+ else:
+ assert suggestions == [{"type": "table", "schema": []}]
+
+
+def test_u_suggests_databases():
+ suggestions = suggest_type("\\u ", "\\u ")
+ assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}])
+
+
+def test_describe_table():
+ suggestions = suggest_type("\\dt", "\\dt ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_list_or_show_create_tables():
+ suggestions = suggest_type("\\dt+", "\\dt+ ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
+
+
+def test_format_uptime():
+ seconds = 59
+ assert "59 sec" == format_uptime(seconds)
+
+ seconds = 120
+ assert "2 min 0 sec" == format_uptime(seconds)
+
+ seconds = 54890
+ assert "15 hours 14 min 50 sec" == format_uptime(seconds)
+
+ seconds = 598244
+ assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds)
+
+ seconds = 522600
+ assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds)
+
+
+def test_indexes():
+ suggestions = suggest_type(".indexes", ".indexes ")
+ assert sorted_dicts(suggestions) == sorted_dicts(
+ [
+ {"type": "table", "schema": []},
+ {"type": "view", "schema": []},
+ {"type": "schema"},
+ ]
+ )
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..d4d52af
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,262 @@
+import os
+from collections import namedtuple
+from textwrap import dedent
+from tempfile import NamedTemporaryFile
+import shutil
+
+import click
+from click.testing import CliRunner
+
+from litecli.main import cli, LiteCli
+from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from utils import dbtest, run
+
+test_dir = os.path.abspath(os.path.dirname(__file__))
+project_dir = os.path.dirname(test_dir)
+default_config_file = os.path.join(project_dir, "tests", "liteclirc")
+
+CLI_ARGS = ["--liteclirc", default_config_file, "_test_db"]
+
+
+@dbtest
+def test_execute_arg(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql])
+
+ assert result.exit_code == 0
+ assert "abc" in result.output
+
+ result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql])
+
+ assert result.exit_code == 0
+ assert "abc" in result.output
+
+ expected = "a\nabc\n"
+
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_table(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"])
+ expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n"
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_csv(executor):
+ run(executor, "create table test (a text)")
+ run(executor, 'insert into test values("abc")')
+
+ sql = "select * from test;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"])
+ expected = '"a"\n"abc"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
+
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS, input=sql)
+
+ assert result.exit_code == 0
+ assert "count(*)\n3\na\nabc\n" in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode_table(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc'), ('def'), ('ghi')""")
+
+ sql = "select count(*) from test;\n" "select * from test limit 1;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql)
+
+ expected = dedent(
+ """\
+ +----------+
+ | count(*) |
+ +----------+
+ | 3 |
+ +----------+
+ +-----+
+ | a |
+ +-----+
+ | abc |
+ +-----+"""
+ )
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_batch_mode_csv(executor):
+ run(executor, """create table test(a text, b text)""")
+ run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""")
+
+ sql = "select * from test;"
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
+
+ expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+def test_help_strings_end_with_periods():
+ """Make sure click options have help text that end with a period."""
+ for param in cli.params:
+ if isinstance(param, click.core.Option):
+ assert hasattr(param, "help")
+ assert param.help.endswith(".")
+
+
+def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
+ global clickoutput
+ clickoutput = ""
+ m = LiteCli(liteclirc=default_config_file)
+
+ class TestOutput:
+ def get_size(self):
+ size = namedtuple("Size", "rows columns")
+ size.columns, size.rows = terminal_size
+ return size
+
+ class TestExecute:
+ host = "test"
+ user = "test"
+ dbname = "test"
+ port = 0
+
+ def server_type(self):
+ return ["test"]
+
+ class PromptBuffer:
+ output = TestOutput()
+
+ m.prompt_app = PromptBuffer()
+ m.sqlexecute = TestExecute()
+ m.explicit_pager = explicit_pager
+
+ def echo_via_pager(s):
+ assert expect_pager
+ global clickoutput
+ clickoutput += s
+
+ def secho(s):
+ assert not expect_pager
+ global clickoutput
+ clickoutput += s + "\n"
+
+ monkeypatch.setattr(click, "echo_via_pager", echo_via_pager)
+ monkeypatch.setattr(click, "secho", secho)
+ m.output(testdata)
+ if clickoutput.endswith("\n"):
+ clickoutput = clickoutput[:-1]
+ assert clickoutput == "\n".join(testdata)
+
+
+def test_conditional_pager(monkeypatch):
+ testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
+ " "
+ )
+ # User didn't set pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=True,
+ )
+ # User didn't set pager, output fits screen -> no pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False,
+ )
+ # User manually configured pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True,
+ )
+ # User manually configured pager, output fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True,
+ )
+
+ SPECIAL_COMMANDS["nopager"].handler()
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False,
+ )
+ SPECIAL_COMMANDS["pager"].handler("")
+
+
+def test_reserved_space_is_integer():
+ """Make sure that reserved space is returned as an integer."""
+
+ def stub_terminal_size():
+ return (5, 5)
+
+ old_func = shutil.get_terminal_size
+
+ shutil.get_terminal_size = stub_terminal_size
+ lc = LiteCli()
+ assert isinstance(lc.get_reserved_space(), int)
+ shutil.get_terminal_size = old_func
+
+
+@dbtest
+def test_import_command(executor):
+ data_file = os.path.join(project_dir, "tests", "data", "import_data.csv")
+ run(executor, """create table tbl1(one varchar(10), two smallint)""")
+
+ # execute
+ run(executor, """.import %s tbl1""" % data_file)
+
+ # verify
+ sql = "select * from tbl1;"
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql)
+
+ expected = """one","two"
+"t1","11"
+"t2","22"
+"""
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py
new file mode 100644
index 0000000..cad7a8c
--- /dev/null
+++ b/tests/test_parseutils.py
@@ -0,0 +1,131 @@
+import pytest
+from litecli.packages.parseutils import (
+ extract_tables,
+ query_starts_with,
+ queries_start_with,
+ is_destructive,
+)
+
+
+def test_empty_string():
+ tables = extract_tables("")
+ assert tables == []
+
+
+def test_simple_select_single_table():
+ tables = extract_tables("select * from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_select_single_table_schema_qualified():
+ tables = extract_tables("select * from abc.def")
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables("select * from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables("select * from abc.def, ghi.jkl")
+ assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)]
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables("select a,b from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables("select a,b from abc.def")
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables("select a,b from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_simple_select_with_cols_multiple_tables_with_schema():
+ tables = extract_tables("select a,b from abc.def, def.ghi")
+ assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)]
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables("select a, from abc")
+ assert tables == [(None, "abc", None)]
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables("select a, from abc, def")
+ assert sorted(tables) == [(None, "abc", None), (None, "def", None)]
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
+ assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")]
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # assert tables == [(None, 'abc', None)]
+ assert tables == [(None, "abc", "abc")]
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == [("abc", "def", None)]
+
+
+def test_simple_update_table():
+ tables = extract_tables("update abc set id = 1")
+ assert tables == [(None, "abc", None)]
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables("update abc.def set id = 1")
+ assert tables == [("abc", "def", None)]
+
+
+def test_join_table():
+ tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num")
+ assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")]
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
+ assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")]
+
+
+def test_join_as_table():
+ tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
+ assert tables == [(None, "my_table", "m")]
+
+
+def test_query_starts_with():
+ query = "USE test;"
+ assert query_starts_with(query, ("use",)) is True
+
+ query = "DROP DATABASE test;"
+ assert query_starts_with(query, ("use",)) is False
+
+
+def test_query_starts_with_comment():
+ query = "# comment\nUSE test;"
+ assert query_starts_with(query, ("use",)) is True
+
+
+def test_queries_start_with():
+ sql = "# comment\n" "show databases;" "use foo;"
+ assert queries_start_with(sql, ("show", "select")) is True
+ assert queries_start_with(sql, ("use", "drop")) is True
+ assert queries_start_with(sql, ("delete", "update")) is False
+
+
+def test_is_destructive():
+ sql = "use test;\n" "show databases;\n" "drop database foo;"
+ assert is_destructive(sql) is True
diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py
new file mode 100644
index 0000000..2de74ce
--- /dev/null
+++ b/tests/test_prompt_utils.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+
+import click
+
+from litecli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream("stdin")
+ assert stdin.isatty() is False
+
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql) is None
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..e532118
--- /dev/null
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -0,0 +1,432 @@
+# coding: utf-8
+from __future__ import unicode_literals
+import pytest
+from mock import patch
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+
+metadata = {
+ "users": ["id", "email", "first_name", "last_name"],
+ "orders": ["id", "ordered_date", "status"],
+ "select": ["id", "insert", "ABC"],
+ "réveillé": ["id", "insert", "ABC"],
+}
+
+
+@pytest.fixture
+def completer():
+
+ import litecli.sqlcompleter as sqlcompleter
+
+ comp = sqlcompleter.SQLCompleter()
+
+ tables, columns = [], []
+
+ for table, cols in metadata.items():
+ tables.append((table,))
+ columns.extend([(table, col) for col in cols])
+
+ comp.set_dbname("test")
+ comp.extend_schemata("test")
+ comp.extend_relations(tables, kind="tables")
+ comp.extend_columns(columns, kind="tables")
+
+ return comp
+
+
+@pytest.fixture
+def complete_event():
+ from mock import Mock
+
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ""
+ position = 0
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(map(Completion, sorted(completer.keywords))) == result
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = "SEL"
+ position = len("SEL")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list([Completion(text="SELECT", start_position=-3)])
+
+
+def test_table_completion(completer, complete_event):
+ text = "SELECT * FROM "
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="`réveillé`", start_position=0),
+ Completion(text="`select`", start_position=0),
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_function_name_completion(completer, complete_event):
+ text = "SELECT MA"
+ position = len("SELECT MA")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="MAX", start_position=-2),
+ Completion(text="MATCH", start_position=-2),
+ ]
+ )
+
+
+def test_suggested_column_names(completer, complete_event):
+ """Suggest column and function names when selecting from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT from users"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="users", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def test_suggested_column_names_in_function(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT MAX( from users"
+ position = len("SELECT MAX(")
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ assert list(result) == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_column_names_with_table_dot(completer, complete_event):
+ """Suggest column names on table name and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT users. from users"
+ position = len("SELECT users.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT u. from users u"
+ position = len("SELECT u.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_multiple_column_names(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT id, from users u"
+ position = len("SELECT id, ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="u", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def test_suggested_multiple_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT u.id, u. from users u"
+ position = len("SELECT u.id, u.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_multiple_column_names_with_dot(completer, complete_event):
+ """Suggest column names on table names and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = "SELECT users.id, users. from users u"
+ position = len("SELECT users.id, users.")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="email", start_position=0),
+ Completion(text="first_name", start_position=0),
+ Completion(text="id", start_position=0),
+ Completion(text="last_name", start_position=0),
+ ]
+ )
+
+
+def test_suggested_aliases_after_on(completer, complete_event):
+ text = "SELECT u.name, o.id FROM users u JOIN orders o ON "
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [Completion(text="o", start_position=0), Completion(text="u", start_position=0)]
+ )
+
+
+def test_suggested_aliases_after_on_right_side(completer, complete_event):
+ text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = "
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [Completion(text="o", start_position=0), Completion(text="u", start_position=0)]
+ )
+
+
+def test_suggested_tables_after_on(completer, complete_event):
+ text = "SELECT users.name, orders.id FROM users JOIN orders ON "
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_suggested_tables_after_on_right_side(completer, complete_event):
+ text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ position = len(
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ )
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(result) == list(
+ [
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_table_names_after_from(completer, complete_event):
+ text = "SELECT * FROM "
+ position = len("SELECT * FROM ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert list(result) == list(
+ [
+ Completion(text="`réveillé`", start_position=0),
+ Completion(text="`select`", start_position=0),
+ Completion(text="orders", start_position=0),
+ Completion(text="users", start_position=0),
+ ]
+ )
+
+
+def test_auto_escaped_col_names(completer, complete_event):
+ text = "SELECT from `select`"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert (
+ result
+ == [
+ Completion(text="*", start_position=0),
+ Completion(text="`ABC`", start_position=0),
+ Completion(text="`insert`", start_position=0),
+ Completion(text="id", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="`select`", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def test_un_escaped_table_names(completer, complete_event):
+ text = "SELECT from réveillé"
+ position = len("SELECT ")
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == list(
+ [
+ Completion(text="*", start_position=0),
+ Completion(text="`ABC`", start_position=0),
+ Completion(text="`insert`", start_position=0),
+ Completion(text="id", start_position=0),
+ ]
+ + list(map(Completion, completer.functions))
+ + [Completion(text="réveillé", start_position=0)]
+ + list(map(Completion, sorted(completer.keywords)))
+ )
+
+
+def dummy_list_path(dir_name):
+ dirs = {
+ "/": ["dir1", "file1.sql", "file2.sql"],
+ "/dir1": ["subdir1", "subfile1.sql", "subfile2.sql"],
+ "/dir1/subdir1": ["lastfile.sql"],
+ }
+ return dirs.get(dir_name, [])
+
+
+@patch("litecli.packages.filepaths.list_path", new=dummy_list_path)
+@pytest.mark.parametrize(
+ "text,expected",
+ [
+ ("source ", [(".", 0), ("..", 0), ("/", 0), ("~", 0)]),
+ ("source /", [("dir1", 0), ("file1.sql", 0), ("file2.sql", 0)]),
+ ("source /dir1/", [("subdir1", 0), ("subfile1.sql", 0), ("subfile2.sql", 0)]),
+ ("source /dir1/subdir1/", [("lastfile.sql", 0)]),
+ ],
+)
+def test_file_name_completion(completer, complete_event, text, expected):
+ position = len(text)
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ expected = list([Completion(txt, pos) for txt, pos in expected])
+ assert result == expected
diff --git a/tests/test_sqlexecute.py b/tests/test_sqlexecute.py
new file mode 100644
index 0000000..e559bc6
--- /dev/null
+++ b/tests/test_sqlexecute.py
@@ -0,0 +1,405 @@
+# coding=UTF-8
+
+import os
+
+import pytest
+
+from utils import run, dbtest, set_expanded_output, is_expanded_output
+from sqlite3 import OperationalError, ProgrammingError
+
+
+def assert_result_equal(
+ result,
+ title=None,
+ rows=None,
+ headers=None,
+ status=None,
+ auto_status=True,
+ assert_contains=False,
+):
+ """Assert that an sqlexecute.run() result matches the expected values."""
+ if status is None and auto_status and rows:
+ status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "")
+ fields = {"title": title, "rows": rows, "headers": headers, "status": status}
+
+ if assert_contains:
+ # Do a loose match on the results using the *in* operator.
+ for key, field in fields.items():
+ if field:
+ assert field in result[0][key]
+ else:
+ # Do an exact match on the fields.
+ assert result == [fields]
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+ results = run(executor, """select * from test""")
+
+ assert_result_equal(results, headers=["a"], rows=[("abc",)])
+
+
+@dbtest
+def test_bools(executor):
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(1)""")
+ results = run(executor, """select * from test""")
+
+ assert_result_equal(results, headers=["a"], rows=[(1,)])
+
+
+@dbtest
+def test_binary(executor):
+ run(executor, """create table foo(blb BLOB NOT NULL)""")
+ run(executor, """INSERT INTO foo VALUES ('\x01\x01\x01\n')""")
+ results = run(executor, """select * from foo""")
+
+ expected = "\x01\x01\x01\n"
+
+ assert_result_equal(results, headers=["blb"], rows=[(expected,)])
+
+
+## Failing in Travis for some unknown reason.
+# @dbtest
+# def test_table_and_columns_query(executor):
+# run(executor, "create table a(x text, y text)")
+# run(executor, "create table b(z text)")
+
+# assert set(executor.tables()) == set([("a",), ("b",)])
+# assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")])
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert "main" in list(databases)
+
+
+@dbtest
+def test_invalid_syntax(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "invalid syntax!")
+ assert "syntax error" in str(excinfo.value)
+
+
+@dbtest
+def test_invalid_column_name(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "select invalid command")
+ assert "no such column: invalid" in str(excinfo.value)
+
+
+@dbtest
+def test_unicode_support_in_output(executor):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, u"insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ results = run(executor, u"select * from unicodechars")
+ assert_result_equal(results, headers=["t"], rows=[(u"é",)])
+
+
+@dbtest
+def test_invalid_unicode_values_dont_choke(executor):
+ run(executor, "create table unicodechars(t text)")
+ # \xc3 is not a valid utf-8 char. But we can insert it into the database
+ # which can break querying if not handled correctly.
+ run(executor, u"insert into unicodechars (t) values (cast(x'c3' as text))")
+
+ results = run(executor, u"select * from unicodechars")
+ assert_result_equal(results, headers=["t"], rows=[("\\xc3",)])
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ results = run(executor, "select 'foo'; select 'bar'")
+
+ expected = [
+ {
+ "title": None,
+ "headers": ["'foo'"],
+ "rows": [(u"foo",)],
+ "status": "1 row in set",
+ },
+ {
+ "title": None,
+ "headers": ["'bar'"],
+ "rows": [(u"bar",)],
+ "status": "1 row in set",
+ },
+ ]
+ assert expected == results
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor):
+ with pytest.raises(OperationalError) as excinfo:
+ run(executor, "select 'foo'; invalid syntax")
+ assert "syntax error" in str(excinfo.value)
+
+
+@dbtest
+def test_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor, "\\fs test-a select * from test where a like 'a%'")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-a")
+ assert_result_equal(
+ results,
+ title="> select * from test where a like 'a%'",
+ headers=["a"],
+ rows=[("abc",)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\fd test-a")
+ assert_result_equal(results, status="test-a: Deleted")
+
+
+@dbtest
+def test_bind_parameterized_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(name text, id integer)")
+ run(executor, "insert into test values('def', 2)")
+ run(executor, "insert into test values('two words', 3)")
+
+ results = run(executor, "\\fs q_param select * from test where name=?")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ q_param def")
+ assert_result_equal(
+ results,
+ title="> select * from test where name=?",
+ headers=["name", "id"],
+ rows=[("def", 2)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ q_param 'two words'")
+ assert_result_equal(
+ results,
+ title="> select * from test where name=?",
+ headers=["name", "id"],
+ rows=[("two words", 3)],
+ auto_status=False,
+ )
+
+ with pytest.raises(ProgrammingError):
+ results = run(executor, "\\f+ q_param")
+
+ with pytest.raises(ProgrammingError):
+ results = run(executor, "\\f+ q_param 1 2")
+
+
+@dbtest
+def test_verbose_feature_of_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text, id integer)")
+ run(executor, "insert into test values('abc', 1)")
+ run(executor, "insert into test values('def', 2)")
+
+ results = run(executor, "\\fs sh_param select * from test where id=$1")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f sh_param 1")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ sh_param 1")
+ assert_result_equal(
+ results,
+ title="> select * from test where id=1",
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+
+@dbtest
+def test_shell_parameterized_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text, id integer)")
+ run(executor, "insert into test values('abc', 1)")
+ run(executor, "insert into test values('def', 2)")
+
+ results = run(executor, "\\fs sh_param select * from test where id=$1")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ sh_param 1")
+ assert_result_equal(
+ results,
+ title="> select * from test where id=1",
+ headers=["a", "id"],
+ rows=[("abc", 1)],
+ auto_status=False,
+ )
+
+ results = run(executor, "\\f+ sh_param")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=None,
+ rows=None,
+ status="missing substitution for $1 in query:\n select * from test where id=$1",
+ )
+
+ results = run(executor, "\\f+ sh_param 1 2")
+ assert_result_equal(
+ results,
+ title=None,
+ headers=None,
+ rows=None,
+ status="Too many arguments.\nQuery does not have enough place holders to substitute.\nselect * from test where id=1",
+ )
+
+
+@dbtest
+def test_favorite_query_multiple_statement(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(
+ executor,
+ "\\fs test-ad select * from test where a like 'a%'; "
+ "select * from test where a like 'd%'",
+ )
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-ad")
+ expected = [
+ {
+ "title": "> select * from test where a like 'a%'",
+ "headers": ["a"],
+ "rows": [("abc",)],
+ "status": None,
+ },
+ {
+ "title": "> select * from test where a like 'd%'",
+ "headers": ["a"],
+ "rows": [("def",)],
+ "status": None,
+ },
+ ]
+ assert expected == results
+
+ results = run(executor, "\\fd test-ad")
+ assert_result_equal(results, status="test-ad: Deleted")
+
+
+@dbtest
+def test_favorite_query_expanded_output(executor):
+ set_expanded_output(False)
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+
+ results = run(executor, "\\fs test-ae select * from test")
+ assert_result_equal(results, status="Saved.")
+
+ results = run(executor, "\\f+ test-ae \G")
+ assert is_expanded_output() is True
+ assert_result_equal(
+ results,
+ title="> select * from test",
+ headers=["a"],
+ rows=[("abc",)],
+ auto_status=False,
+ )
+
+ set_expanded_output(False)
+
+ results = run(executor, "\\fd test-ae")
+ assert_result_equal(results, status="test-ae: Deleted")
+
+
+@dbtest
+def test_special_command(executor):
+ results = run(executor, "\\?")
+ assert_result_equal(
+ results,
+ rows=("quit", "\\q", "Quit."),
+ headers="Command",
+ assert_contains=True,
+ auto_status=False,
+ )
+
+
+@dbtest
+def test_cd_command_without_a_folder_name(executor):
+ results = run(executor, "system cd")
+ assert_result_equal(results, status="No folder name was provided.")
+
+
+@dbtest
+def test_system_command_not_found(executor):
+ results = run(executor, "system xyz")
+ assert_result_equal(
+ results, status="OSError: No such file or directory", assert_contains=True
+ )
+
+
+@dbtest
+def test_system_command_output(executor):
+ test_dir = os.path.abspath(os.path.dirname(__file__))
+ test_file_path = os.path.join(test_dir, "test.txt")
+ results = run(executor, "system cat {0}".format(test_file_path))
+ assert_result_equal(results, status="litecli is awesome!\n")
+
+
+@dbtest
+def test_cd_command_current_dir(executor):
+ test_path = os.path.abspath(os.path.dirname(__file__))
+ run(executor, "system cd {0}".format(test_path))
+ assert os.getcwd() == test_path
+ run(executor, "system cd ..")
+
+
+@dbtest
+def test_unicode_support(executor):
+ results = run(executor, u"SELECT '日本語' AS japanese;")
+ assert_result_equal(results, headers=["japanese"], rows=[(u"日本語",)])
+
+
+@dbtest
+def test_timestamp_null(executor):
+ run(executor, """create table ts_null(a timestamp null)""")
+ run(executor, """insert into ts_null values(null)""")
+ results = run(executor, """select * from ts_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_datetime_null(executor):
+ run(executor, """create table dt_null(a datetime null)""")
+ run(executor, """insert into dt_null values(null)""")
+ results = run(executor, """select * from dt_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_date_null(executor):
+ run(executor, """create table date_null(a date null)""")
+ run(executor, """insert into date_null values(null)""")
+ results = run(executor, """select * from date_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
+
+
+@dbtest
+def test_time_null(executor):
+ run(executor, """create table time_null(a time null)""")
+ run(executor, """insert into time_null values(null)""")
+ results = run(executor, """select * from time_null""")
+ assert_result_equal(results, headers=["a"], rows=[(None,)])
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..41bac9b
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+
+import os
+import time
+import signal
+import platform
+import multiprocessing
+from contextlib import closing
+
+import sqlite3
+import pytest
+
+from litecli.main import special
+
+DATABASE = os.getenv("PYTEST_DATABASE", "test.sqlite3")
+
+
+def db_connection(dbname=":memory:"):
+ conn = sqlite3.connect(database=dbname, isolation_level=None)
+ return conn
+
+
+try:
+ db_connection()
+ CAN_CONNECT_TO_DB = True
+except Exception as ex:
+ CAN_CONNECT_TO_DB = False
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB, reason="Error creating sqlite connection"
+)
+
+
+def create_db(dbname):
+ with closing(db_connection().cursor()) as cur:
+ try:
+ cur.execute("""DROP DATABASE IF EXISTS _test_db""")
+ cur.execute("""CREATE DATABASE _test_db""")
+ except:
+ pass
+
+
+def drop_tables(dbname):
+ with closing(db_connection().cursor()) as cur:
+ try:
+ cur.execute("""DROP DATABASE IF EXISTS _test_db""")
+ except:
+ pass
+
+
+def run(executor, sql, rows_as_list=True):
+ """Return string output for the sql to be run."""
+ result = []
+
+ for title, rows, headers, status in executor.run(sql):
+ rows = list(rows) if (rows_as_list and rows) else rows
+ result.append(
+ {"title": title, "rows": rows, "headers": headers, "status": status}
+ )
+
+ return result
+
+
+def set_expanded_output(is_expanded):
+ """Pass-through for the tests."""
+ return special.set_expanded_output(is_expanded)
+
+
+def is_expanded_output():
+ """Pass-through for the tests."""
+ return special.is_expanded_output()
+
+
+def send_ctrl_c_to_pid(pid, wait_seconds):
+ """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds`
+ seconds."""
+ time.sleep(wait_seconds)
+ system_name = platform.system()
+ if system_name == "Windows":
+ os.kill(pid, signal.CTRL_C_EVENT)
+ else:
+ os.kill(pid, signal.SIGINT)
+
+
+def send_ctrl_c(wait_seconds):
+ """Create a process that sends a Ctrl-C like signal to the current process
+ after `wait_seconds` seconds.
+
+ Returns the `multiprocessing.Process` created.
+
+ """
+ ctrl_c_process = multiprocessing.Process(
+ target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
+ )
+ ctrl_c_process.start()
+ return ctrl_c_process
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..c1c81ea
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,11 @@
+[tox]
+envlist = py37, py38, py39, py310
+
+[testenv]
+deps = pytest
+ mock
+ pexpect
+ behave
+ coverage
+commands = python setup.py test
+passenv = PYTEST_DATABASE