summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:39:33 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:39:33 +0000
commit98aa4c820d8dd9e1090590242ab408c1221b0ba8 (patch)
tree70b027a809ee8f8fea766316f8d52f56b1dc6f32
parentInitial commit. (diff)
downloadmycli-98aa4c820d8dd9e1090590242ab408c1221b0ba8.tar.xz
mycli-98aa4c820d8dd9e1090590242ab408c1221b0ba8.zip
Adding upstream version 1.26.1.upstream/1.26.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--.coveragerc3
-rw-r--r--.git-blame-ignore-revs0
-rw-r--r--.github/PULL_REQUEST_TEMPLATE.md9
-rw-r--r--.github/workflows/ci.yml64
-rw-r--r--.gitignore14
-rw-r--r--AUTHORS.rst3
-rw-r--r--CONTRIBUTING.md167
-rw-r--r--LICENSE.txt34
-rw-r--r--MANIFEST.in6
-rw-r--r--README.md236
-rw-r--r--SPONSORS.rst3
-rw-r--r--changelog.md935
-rw-r--r--doc/key_bindings.rst65
-rw-r--r--mycli/AUTHORS101
-rw-r--r--mycli/SPONSORS31
-rw-r--r--mycli/__init__.py1
-rw-r--r--mycli/clibuffer.py55
-rw-r--r--mycli/clistyle.py152
-rw-r--r--mycli/clitoolbar.py58
-rw-r--r--mycli/compat.py6
-rw-r--r--mycli/completion_refresher.py123
-rw-r--r--mycli/config.py344
-rw-r--r--mycli/key_bindings.py131
-rw-r--r--mycli/lexer.py12
-rw-r--r--mycli/magic.py54
-rwxr-xr-xmycli/main.py1468
-rw-r--r--mycli/myclirc159
-rw-r--r--mycli/packages/__init__.py0
-rw-r--r--mycli/packages/completion_engine.py294
-rw-r--r--mycli/packages/filepaths.py106
-rw-r--r--mycli/packages/paramiko_stub/__init__.py28
-rw-r--r--mycli/packages/parseutils.py266
-rw-r--r--mycli/packages/prompt_utils.py54
-rw-r--r--mycli/packages/special/__init__.py10
-rw-r--r--mycli/packages/special/dbcommands.py162
-rw-r--r--mycli/packages/special/delimitercommand.py80
-rw-r--r--mycli/packages/special/favoritequeries.py63
-rw-r--r--mycli/packages/special/iocommands.py543
-rw-r--r--mycli/packages/special/main.py120
-rw-r--r--mycli/packages/special/utils.py46
-rw-r--r--mycli/packages/tabular_output/__init__.py0
-rw-r--r--mycli/packages/tabular_output/sql_format.py62
-rw-r--r--mycli/sqlcompleter.py435
-rw-r--r--mycli/sqlexecute.py356
-rw-r--r--pytest.ini2
-rwxr-xr-xrelease.py119
-rw-r--r--requirements-dev.txt17
-rw-r--r--screenshots/main.gifbin0 -> 131158 bytes
-rw-r--r--screenshots/tables.pngbin0 -> 61064 bytes
-rw-r--r--setup.cfg18
-rwxr-xr-xsetup.py127
-rw-r--r--test/__init__.py0
-rw-r--r--test/conftest.py29
-rw-r--r--test/features/__init__.py0
-rw-r--r--test/features/auto_vertical.feature12
-rw-r--r--test/features/basic_commands.feature19
-rw-r--r--test/features/connection.feature35
-rw-r--r--test/features/crud_database.feature30
-rw-r--r--test/features/crud_table.feature49
-rw-r--r--test/features/db_utils.py93
-rw-r--r--test/features/environment.py176
-rw-r--r--test/features/fixture_data/help.txt24
-rw-r--r--test/features/fixture_data/help_commands.txt31
-rw-r--r--test/features/fixture_utils.py29
-rw-r--r--test/features/iocommands.feature47
-rw-r--r--test/features/named_queries.feature24
-rw-r--r--test/features/specials.feature7
-rw-r--r--test/features/steps/__init__.py0
-rw-r--r--test/features/steps/auto_vertical.py46
-rw-r--r--test/features/steps/basic_commands.py100
-rw-r--r--test/features/steps/connection.py71
-rw-r--r--test/features/steps/crud_database.py115
-rw-r--r--test/features/steps/crud_table.py112
-rw-r--r--test/features/steps/iocommands.py105
-rw-r--r--test/features/steps/named_queries.py90
-rw-r--r--test/features/steps/specials.py27
-rw-r--r--test/features/steps/utils.py12
-rw-r--r--test/features/steps/wrappers.py117
-rwxr-xr-xtest/features/wrappager.py16
-rw-r--r--test/myclirc12
-rw-r--r--test/mylogin.cnfbin0 -> 156 bytes
-rw-r--r--test/test.txt1
-rw-r--r--test/test_clistyle.py27
-rw-r--r--test/test_completion_engine.py555
-rw-r--r--test/test_completion_refresher.py88
-rw-r--r--test/test_config.py196
-rw-r--r--test/test_dbspecial.py42
-rw-r--r--test/test_main.py548
-rw-r--r--test/test_naive_completion.py63
-rw-r--r--test/test_parseutils.py190
-rw-r--r--test/test_plan.wiki38
-rw-r--r--test/test_prompt_utils.py11
-rw-r--r--test/test_smart_completion_public_schema_only.py385
-rw-r--r--test/test_special_iocommands.py287
-rw-r--r--test/test_sqlexecute.py295
-rw-r--r--test/test_tabular_output.py118
-rw-r--r--test/utils.py94
-rw-r--r--tox.ini15
98 files changed, 11523 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..8d3149f
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,3 @@
+[run]
+parallel = True
+source = mycli
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/.git-blame-ignore-revs
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..8d498ab
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,9 @@
+## Description
+<!--- Describe your changes in detail. -->
+
+
+
+## Checklist
+<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. -->
+- [ ] I've added this contribution to the `changelog.md`.
+- [ ] I've added my name to the `AUTHORS` file (or it's already there).
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..752ddb5
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,64 @@
+name: mycli
+
+on:
+ pull_request:
+ paths-ignore:
+ - '**.md'
+
+jobs:
+ linux:
+ strategy:
+ matrix:
+ python-version: ['3.7', '3.8', '3.9', '3.10']
+ include:
+ - python-version: '3.7'
+ os: ubuntu-18.04 # MySQL 5.7.32
+ - python-version: '3.8'
+ os: ubuntu-18.04 # MySQL 5.7.32
+ - python-version: '3.9'
+ os: ubuntu-20.04 # MySQL 8.0.22
+ - python-version: '3.10'
+ os: ubuntu-22.04 # MySQL 8.0.28
+
+ runs-on: ${{ matrix.os }}
+ steps:
+
+ - uses: actions/checkout@v2
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Start MySQL
+ run: |
+ sudo /etc/init.d/mysql start
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements-dev.txt
+ pip install --no-cache-dir -e .
+
+ - name: Wait for MySQL connection
+ run: |
+ while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do
+ sleep 5
+ done
+
+ - name: Pytest / behave
+ env:
+ PYTEST_PASSWORD: root
+ PYTEST_HOST: 127.0.0.1
+ run: |
+ ./setup.py test --pytest-args="--cov-report= --cov=mycli"
+
+ - name: Lint
+ run: |
+ ./setup.py lint --branch=HEAD
+
+ - name: Coverage
+ run: |
+ coverage combine
+ coverage report
+ codecov
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b13429e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,14 @@
+.idea/
+.vscode/
+/build
+/dist
+/mycli.egg-info
+/src
+/test/behave.ini
+
+.vagrant
+*.pyc
+*.deb
+.cache/
+.coverage
+.coverage.*
diff --git a/AUTHORS.rst b/AUTHORS.rst
new file mode 100644
index 0000000..995327f
--- /dev/null
+++ b/AUTHORS.rst
@@ -0,0 +1,3 @@
+Check out our `AUTHORS`_.
+
+.. _AUTHORS: mycli/AUTHORS
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..cac4f04
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,167 @@
+# Development Guide
+
+This is a guide for developers who would like to contribute to this project.
+
+If you're interested in contributing to mycli, thank you. We'd love your help!
+You'll always get credit for your work.
+
+## GitHub Workflow
+
+1. [Fork the repository](https://github.com/dbcli/mycli) on GitHub.
+
+2. Clone your fork locally:
+ ```bash
+ $ git clone <url-for-your-fork>
+ ```
+
+3. Add the official repository (`upstream`) as a remote repository:
+ ```bash
+ $ git remote add upstream git@github.com:dbcli/mycli.git
+ ```
+
+4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs)
+ for development:
+
+ ```bash
+ $ cd mycli
+ $ pip install virtualenv
+ $ virtualenv mycli_dev
+ ```
+
+ We've just created a virtual environment that we'll use to install all the dependencies
+ and tools we need to work on mycli. Whenever you want to work on mycli, you
+ need to activate the virtual environment:
+
+ ```bash
+ $ source mycli_dev/bin/activate
+ ```
+
+ When you're done working, you can deactivate the virtual environment:
+
+ ```bash
+ $ deactivate
+ ```
+
+5. Install the dependencies and development tools:
+
+ ```bash
+ $ pip install -r requirements-dev.txt
+ $ pip install --editable .
+ ```
+
+6. Create a branch for your bugfix or feature based off the `main` branch:
+
+ ```bash
+ $ git checkout -b <name-of-bugfix-or-feature> main
+ ```
+
+7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date:
+
+ ```bash
+ $ git pull upstream main
+ ```
+
+8. When your work is ready for the mycli team to review it, push your branch to your fork:
+
+ ```bash
+ $ git push origin <name-of-bugfix-or-feature>
+ ```
+
+9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/)
+ on GitHub.
+
+
+## Running the Tests
+
+While you work on mycli, it's important to run the tests to make sure your code
+hasn't broken any existing functionality. To run the tests, just type in:
+
+```bash
+$ ./setup.py test
+```
+
+Mycli supports Python 2.7 and 3.4+. You can test against multiple versions of
+Python by running tox:
+
+```bash
+$ tox
+```
+
+
+### Test Database Credentials
+
+The tests require a database connection to work. You can tell the tests which
+credentials to use by setting the applicable environment variables:
+
+```bash
+$ export PYTEST_HOST=localhost
+$ export PYTEST_USER=mycli
+$ export PYTEST_PASSWORD=myclirocks
+$ export PYTEST_PORT=3306
+$ export PYTEST_CHARSET=utf8
+```
+
+The default values are `localhost`, `root`, no password, `3306`, and `utf8`.
+You only need to set the values that differ from the defaults.
+
+If you would like to run the tests as a user with only the necessary privileges,
+create a `mycli` user and run the following grant statements.
+
+```sql
+GRANT ALL PRIVILEGES ON `mycli_%`.* TO 'mycli'@'localhost';
+GRANT SELECT ON mysql.* TO 'mycli'@'localhost';
+GRANT SELECT ON performance_schema.* TO 'mycli'@'localhost';
+```
+
+### CLI Tests
+
+Some CLI tests expect the program `ex` to be a symbolic link to `vim`.
+
+In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will
+change the output and therefore make some tests fail.
+
+You can check this by running:
+```bash
+$ readlink -f $(which ex)
+```
+
+
+## Coding Style
+
+Mycli requires code submissions to adhere to
+[PEP 8](https://www.python.org/dev/peps/pep-0008/).
+It's easy to check the style of your code, just run:
+
+```bash
+$ ./setup.py lint
+```
+
+If you see any PEP 8 style issues, you can automatically fix them by running:
+
+```bash
+$ ./setup.py lint --fix
+```
+
+Be sure to commit and push any PEP 8 fixes.
+
+## Releasing a new version of mycli
+
+You have been made the maintainer of `mycli`? Congratulations! We have a release script to help you:
+
+```sh
+> python release.py --help
+Usage: release.py [options]
+
+Options:
+ -h, --help show this help message and exit
+ -c, --confirm-steps Confirm every step. If the step is not confirmed, it
+ will be skipped.
+ -d, --dry-run Print out, but not actually run any steps.
+```
+
+To release a new version of the package:
+
+* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/mycli/pull/1043)).
+* Pull `main` and bump the version number inside `mycli/__init__.py`. Do not check in - the release script will do that.
+* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`.
+* Finally, run the release script: `python release.py`.
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..7b4904e
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,34 @@
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+ list of conditions and the following disclaimer in the documentation and/or
+ other materials provided with the distribution.
+
+* Neither the name of the {organization} nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+-------------------------------------------------------------------------------
+
+This program also bundles with it python-tabulate
+(https://pypi.python.org/pypi/tabulate) library. This library is licensed under
+MIT License.
+
+-------------------------------------------------------------------------------
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..04f4d9a
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,6 @@
+include LICENSE.txt *.md *.rst requirements-dev.txt screenshots/*
+include tasks.py .coveragerc tox.ini
+recursive-include test *.cnf
+recursive-include test *.feature
+recursive-include test *.py
+recursive-include test *.txt
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..9e177b7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,236 @@
+# mycli
+
+[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli)
+[![PyPI](https://img.shields.io/pypi/v/mycli.svg)](https://pypi.python.org/pypi/mycli)
+[![LGTM](https://img.shields.io/lgtm/grade/python/github/dbcli/mycli.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/dbcli/mycli/context:python)
+
+A command line client for MySQL that can do auto-completion and syntax highlighting.
+
+HomePage: [http://mycli.net](http://mycli.net)
+Documentation: [http://mycli.net/docs](http://mycli.net/docs)
+
+![Completion](screenshots/tables.png)
+![CompletionGif](screenshots/main.gif)
+
+Postgres Equivalent: [http://pgcli.com](http://pgcli.com)
+
+Quick Start
+-----------
+
+If you already know how to install python packages, then you can install it via pip:
+
+You might need sudo on linux.
+
+```
+$ pip install -U mycli
+```
+
+or
+
+```
+$ brew update && brew install mycli # Only on macOS
+```
+
+or
+
+```
+$ sudo apt-get install mycli # Only on debian or ubuntu
+```
+
+### Usage
+
+ $ mycli --help
+ Usage: mycli [OPTIONS] [DATABASE]
+
+ A MySQL terminal client with auto-completion and syntax highlighting.
+
+ Examples:
+ - mycli my_database
+ - mycli -u my_user -h my_host.com my_database
+ - mycli mysql://my_user@my_host.com:3306/my_database
+
+ Options:
+ -h, --host TEXT Host address of the database.
+ -P, --port INTEGER Port number to use for connection. Honors
+ $MYSQL_TCP_PORT.
+
+ -u, --user TEXT User name to connect to the database.
+ -S, --socket TEXT The socket file to use for connection.
+ -p, --password TEXT Password to connect to the database.
+ --pass TEXT Password to connect to the database.
+ --ssh-user TEXT User name to connect to ssh server.
+ --ssh-host TEXT Host name to connect to ssh server.
+ --ssh-port INTEGER Port to connect to ssh server.
+ --ssh-password TEXT Password to connect to ssh server.
+ --ssh-key-filename TEXT Private key filename (identify file) for the
+ ssh connection.
+
+ --ssh-config-path TEXT Path to ssh configuration.
+ --ssh-config-host TEXT Host to connect to ssh server reading from ssh
+ configuration.
+
+ --ssl Enable SSL for connection (automatically
+ enabled with other flags).
+ --ssl-ca PATH CA file in PEM format.
+ --ssl-capath TEXT CA directory.
+ --ssl-cert PATH X509 cert in PEM format.
+ --ssl-key PATH X509 key in PEM format.
+ --ssl-cipher TEXT SSL cipher to use.
+ --ssl-verify-server-cert Verify server's "Common Name" in its cert
+ against hostname used when connecting. This
+ option is disabled by default.
+
+ -V, --version Output mycli's version.
+ -v, --verbose Verbose output.
+ -D, --database TEXT Database to use.
+ -d, --dsn TEXT Use DSN configured into the [alias_dsn]
+ section of myclirc file.
+
+ --list-dsn list of DSN configured into the [alias_dsn]
+ section of myclirc file.
+
+ --list-ssh-config list ssh configurations in the ssh config
+ (requires paramiko).
+
+ -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> ").
+ -l, --logfile FILENAME Log every query and its results to a file.
+ --defaults-group-suffix TEXT Read MySQL config groups with the specified
+ suffix.
+
+ --defaults-file PATH Only read MySQL options from the given file.
+ --myclirc PATH Location of myclirc file.
+ --auto-vertical-output Automatically switch to vertical output mode
+ if the result is wider than the terminal
+ width.
+
+ -t, --table Display batch output in table format.
+ --csv Display batch output in CSV format.
+ --warn / --no-warn Warn before running a destructive query.
+ --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE.
+ -g, --login-path TEXT Read this path from the login file.
+ -e, --execute TEXT Execute command and quit.
+ --init-command TEXT SQL statement to execute after connecting.
+ --charset TEXT Character set for MySQL session.
+ --password-file PATH File or FIFO path containing the password
+ to connect to the db if not specified otherwise
+ --help Show this message and exit.
+
+
+Features
+--------
+
+`mycli` is written using [prompt_toolkit](https://github.com/jonathanslenders/python-prompt-toolkit/).
+
+* Auto-completion as you type for SQL keywords as well as tables, views and
+ columns in the database.
+* Syntax highlighting using Pygments.
+* Smart-completion (enabled by default) will suggest context-sensitive completion.
+ - `SELECT * FROM <tab>` will only show table names.
+ - `SELECT * FROM users WHERE <tab>` will only show column names.
+* Support for multiline queries.
+* Favorite queries with optional positional parameters. Save a query using
+ `\fs alias query` and execute it with `\f alias` whenever you need.
+* Timing of sql statements and table rendering.
+* Config file is automatically created at ``~/.myclirc`` at first launch.
+* Log every query and its results to a file (disabled by default).
+* Pretty prints tabular data (with colors!)
+* Support for SSL connections
+* Some features are only exposed as [key bindings](doc/key_bindings.rst)
+
+Contributions:
+--------------
+
+If you're interested in contributing to this project, first of all I would like
+to extend my heartfelt gratitude. I've written a small doc to describe how to
+get this running in a development setup.
+
+https://github.com/dbcli/mycli/blob/main/CONTRIBUTING.md
+
+Please feel free to reach out to me if you need help.
+
+My email: amjith.r@gmail.com
+
+Twitter: [@amjithr](http://twitter.com/amjithr)
+
+## Detailed Install Instructions:
+
+### Arch, Manjaro
+
+You can install the mycli package available in the AUR:
+
+```
+$ yay -S mycli
+```
+
+### Debian, Ubuntu
+
+On Debian, Ubuntu distributions, you can easily install the mycli package using apt:
+
+```
+$ sudo apt-get install mycli
+```
+
+### Fedora
+
+Fedora has a package available for mycli, install it using dnf:
+
+```
+$ sudo dnf install mycli
+```
+
+### RHEL, Centos
+
+I haven't built an RPM package for mycli for RHEL or Centos yet. So please use `pip` to install `mycli`. You can install pip on your system using:
+
+```
+$ sudo yum install python3-pip
+```
+
+Once that is installed, you can install mycli as follows:
+
+```
+$ sudo pip3 install mycli
+```
+
+### Windows
+
+Follow the instructions on this blogpost: https://www.codewall.co.uk/installing-using-mycli-on-windows/
+
+### Cygwin
+
+1. Make sure the following Cygwin packages are installed:
+`python3`, `python3-pip`.
+2. Install mycli: `pip3 install mycli`
+
+### Thanks:
+
+This project was funded through kickstarter. My thanks to the [backers](http://mycli.net/sponsors) who supported the project.
+
+A special thanks to [Jonathan Slenders](https://twitter.com/jonathan_s) for
+creating [Python Prompt Toolkit](http://github.com/jonathanslenders/python-prompt-toolkit),
+which is quite literally the backbone library, that made this app possible.
+Jonathan has also provided valuable feedback and support during the development
+of this app.
+
+[Click](http://click.pocoo.org/) is used for command line option parsing
+and printing error messages.
+
+Thanks to [PyMysql](https://github.com/PyMySQL/PyMySQL) for a pure python adapter to MySQL database.
+
+
+### Compatibility
+
+Mycli is tested on macOS and Linux, and requires Python 3.7 or better.
+
+**Mycli is not tested on Windows**, but the libraries used in this app are Windows-compatible.
+This means it should work without any modifications. If you're unable to run it
+on Windows, please [file a bug](https://github.com/dbcli/mycli/issues/new).
+
+### Configuration and Usage
+
+For more information on using and configuring mycli, [check out our documentation](http://mycli.net/docs).
+
+Common topics include:
+- [Configuring mycli](http://mycli.net/config)
+- [Using/Disabling the pager](http://mycli.net/pager)
+- [Syntax colors](http://mycli.net/syntax)
diff --git a/SPONSORS.rst b/SPONSORS.rst
new file mode 100644
index 0000000..173555c
--- /dev/null
+++ b/SPONSORS.rst
@@ -0,0 +1,3 @@
+Check out our `SPONSORS`_.
+
+.. _SPONSORS: mycli/SPONSORS
diff --git a/changelog.md b/changelog.md
new file mode 100644
index 0000000..3dbdc1f
--- /dev/null
+++ b/changelog.md
@@ -0,0 +1,935 @@
+
+1.26.1 (2022/09/01)
+===
+
+Bug Fixes:
+----------
+* Require Python 3.7 in `setup.py`
+
+
+1.26.0 (2022/09/01)
+===================
+
+Features:
+---------
+
+* Add `--ssl` flag to enable ssl/tls.
+* Add `pager` option to `~/.myclirc`, for instance `pager = 'pspg --csv'` (Thanks: [BuonOmo])
+* Add prettify/unprettify keybindings to format the current statement using `sqlglot`.
+
+
+Internal:
+---------
+* Pin `cryptography` to suppress `paramiko` warning, helping CI complete and presumably affecting some users.
+* Upgrade some dev requirements
+* Change tests to always use databases prefixed with 'mycli_' for better security
+
+Bug Fixes:
+----------
+* Support for some MySQL compatible databases, which may not implement connection_id().
+* Fix the status command to work with missing 'Flush_commands' (mariadb)
+* Ignore the user of the system [myslqd] config.
+
+
+1.25.0 (2022/04/02)
+===================
+
+Features:
+---------
+* Add `beep_after_seconds` option to `~/.myclirc`, to ring the terminal bell after long queries.
+
+
+1.24.4 (2022/03/30)
+===================
+
+Internal:
+---------
+* Upgrade Ubuntu VM for runners as Github has deprecated it
+
+Bug Fixes:
+----------
+* Change in main.py - Replace the `click.get_terminal_size()` with `shutil.get_terminal_size()`
+
+
+
+1.24.3 (2022/01/20)
+===================
+
+Bug Fixes:
+----------
+* Upgrade cli_helpers to workaround Pygments regression.
+
+
+1.24.2 (2022/01/11)
+===================
+
+Bug Fixes:
+----------
+* Fix autocompletion for more than one JOIN
+* Fix the status command when connected to TiDB or other servers that don't implement 'Threads\_connected'
+* Pin pygments version to avoid a breaking change
+
+1.24.1:
+=======
+
+Bug Fixes:
+---------
+* Restore dependency on cryptography for the interactive password prompt
+
+Internal:
+---------
+* Deprecate Python mock
+
+
+1.24.0
+======
+
+Bug Fixes:
+----------
+* Allow `FileNotFound` exception for SSH config files.
+* Fix startup error on MySQL < 5.0.22
+* Check error code rather than message for Access Denied error
+* Fix login with ~/.my.cnf files
+
+Features:
+---------
+* Add `-g` shortcut to option `--login-path`.
+* Alt-Enter dispatches the command in multi-line mode.
+* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html)
+
+Internal:
+---------
+* Remove unused function is_open_quote()
+* Use importlib, instead of file links, to locate resources
+* Test various host-port combinations in command line arguments
+* Switched from Cryptography to pyaes for decrypting mylogin.cnf
+
+
+1.23.2
+======
+
+Bug Fixes:
+----------
+* Ensure `--port` is always an int.
+
+1.23.1
+======
+
+Bug Fixes:
+----------
+* Allow `--host` without `--port` to make a TCP connection.
+
+1.23.0
+======
+
+Bug Fixes:
+----------
+* Fix config file include logic
+
+Features:
+---------
+
+* Add an option `--init-command` to execute SQL after connecting (Thanks: [KITAGAWA Yasutaka]).
+* Use InputMode.REPLACE_SINGLE
+* Add support for ANSI escape sequences for coloring the prompt.
+* Allow customization of Pygments SQL syntax-highlighting styles.
+* Add a `\clip` special command to copy queries to the system clipboard.
+* Add a special command `\pipe_once` to pipe output to a subprocess.
+* Add an option `--charset` to set the default charset when connect database.
+
+Bug Fixes:
+----------
+* Fixed compatibility with sqlparse 0.4 (Thanks: [mtorromeo]).
+* Fixed iPython magic (Thanks: [mwcm]).
+* Send "Connecting to socket" message to the standard error.
+* Respect empty string for prompt_continuation via `prompt_continuation = ''` in `.myclirc`
+* Fix \once -o to overwrite output whole, instead of line-by-line.
+* Dispatch lines ending with `\e` or `\clip` on return, even in multiline mode.
+* Restore working local `--socket=<UDS>` (Thanks: [xeron]).
+* Allow backtick quoting around the database argument to the `use` command.
+* Avoid opening `/dev/tty` when `--no-warn` is given.
+* Fixed some typo errors in `README.md`.
+
+1.22.2
+======
+
+Bug Fixes:
+----------
+
+* Make the `pwd` module optional.
+
+1.22.1
+======
+
+Bug Fixes:
+----------
+* Fix the breaking change introduced in PyMySQL 0.10.0. (Thanks: [Amjith]).
+
+Features:
+---------
+* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file.
+* Add an option `--list-ssh-config` to list ssh configurations.
+* Add an option `--ssh-config-path` to choose ssh configuration path.
+
+Bug Fixes:
+----------
+
+* Fix specifying empty password with `--password=''` when config file has a password set (Thanks: [Zach DeCook]).
+
+
+1.21.1
+======
+
+
+Bug Fixes:
+----------
+
+* Fix broken auto-completion for favorite queries (Thanks: [Amjith]).
+* Fix undefined variable exception when running with --no-warn (Thanks: [Georgy Frolov])
+* Support setting color for null value (Thanks: [laixintao])
+
+1.21.0
+======
+
+Features:
+---------
+* Added DSN alias name as a format specifier to the prompt (Thanks: [Georgy Frolov]).
+* Mark `update` without `where`-clause as destructive query (Thanks: [Klaus Wünschel]).
+* Added DELIMITER command (Thanks: [Georgy Frolov])
+* Added clearer error message when failing to connect to the default socket.
+* Extend main.is_dropping_database check with create after delete statement.
+* Search `${XDG_CONFIG_HOME}/mycli/myclirc` after `${HOME}/.myclirc` and before `/etc/myclirc` (Thanks: [Takeshi D. Itoh])
+
+Bug Fixes:
+----------
+
+* Allow \o command more than once per session (Thanks: [Georgy Frolov])
+* Fixed crash when the query dropping the current database starts with a comment (Thanks: [Georgy Frolov])
+
+Internal:
+---------
+* deprecate python versions 2.7, 3.4, 3.5; support python 3.8
+
+1.20.1
+======
+
+Bug Fixes:
+----------
+
+* Fix an error when using login paths with an explicit database name (Thanks: [Thomas Roten]).
+
+1.20.0
+======
+
+Features:
+----------
+* Auto find alias dsn when `://` not in `database` (Thanks: [QiaoHou Peng]).
+* Mention URL encoding as escaping technique for special characters in connection DSN (Thanks: [Aljosha Papsch]).
+* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: [Amjith Ramanujam]).
+* Use a generator to stream the output to the pager (Thanks: [Dick Marinus]).
+
+Bug Fixes:
+----------
+
+* Fix the missing completion for special commands (Thanks: [Amjith Ramanujam]).
+* Fix favorites queries being loaded/stored only from/in default config file and not --myclirc (Thanks: [Matheus Rosa])
+* Fix automatic vertical output with native syntax style (Thanks: [Thomas Roten]).
+* Update `cli_helpers` version, this will remove quotes from batch output like the official client (Thanks: [Dick Marinus])
+* Update `setup.py` to no longer require `sqlparse` to be less than 0.3.0 as that just came out and there are no notable changes. ([VVelox])
+* workaround for ConfigObj parsing strings containing "," as lists (Thanks: [Mike Palandra])
+
+Internal:
+---------
+* fix unhashable FormattedText from prompt toolkit in unit tests (Thanks: [Dick Marinus]).
+
+1.19.0
+======
+
+Internal:
+---------
+
+* Add Python 3.7 trove classifier (Thanks: [Thomas Roten]).
+* Fix pytest in Fedora mock (Thanks: [Dick Marinus]).
+* Require `prompt_toolkit>=2.0.6` (Thanks: [Dick Marinus]).
+
+Features:
+---------
+
+* Add Token.Prompt/Continuation (Thanks: [Dick Marinus]).
+* Don't reconnect when switching databases using use (Thanks: [Angelo Lupo]).
+* Handle MemoryErrors while trying to pipe in large files and exit gracefully with an error (Thanks: [Amjith Ramanujam])
+
+Bug Fixes:
+----------
+
+* Enable Ctrl-Z to suspend the app (Thanks: [Amjith Ramanujam]).
+
+1.18.2
+======
+
+Bug Fixes:
+----------
+
+* Fixes database reconnecting feature (Thanks: [Yang Zou]).
+
+Internal:
+---------
+
+* Update Twine version to 1.12.1 (Thanks: [Thomas Roten]).
+* Fix warnings for running tests on Python 3.7 (Thanks: [Dick Marinus]).
+* Clean up and add behave logging (Thanks: [Dick Marinus]).
+
+1.18.1
+======
+
+Features:
+---------
+
+* Add Keywords: TINYINT, SMALLINT, MEDIUMINT, INT, BIGINT (Thanks: [QiaoHou Peng]).
+
+Internal:
+---------
+
+* Update prompt toolkit (Thanks: [Jonathan Slenders], [Irina Truong], [Dick Marinus]).
+
+1.18.0
+======
+
+Features:
+---------
+
+* Display server version in welcome message (Thanks: [Irina Truong]).
+* Set `program_name` connection attribute (Thanks: [Dick Marinus]).
+* Use `return` to terminate a generator for better Python 3.7 support (Thanks: [Zhongyang Guan]).
+* Add `SAVEPOINT` to SQLCompleter (Thanks: [Huachao Mao]).
+* Connect using a SSH transport (Thanks: [Dick Marinus]).
+* Add `FROM_UNIXTIME` and `UNIX_TIMESTAMP` to SQLCompleter (Thanks: [QiaoHou Peng])
+* Search `${PWD}/.myclirc`, then `${HOME}/.myclirc`, lastly `/etc/myclirc` (Thanks: [QiaoHao Peng])
+
+Bug Fixes:
+----------
+
+* When DSN is used, allow overrides from mycli arguments (Thanks: [Dick Marinus]).
+* A DSN without password should be allowed (Thanks: [Dick Marinus])
+
+Bug Fixes:
+----------
+
+* Convert `sql_format` to unicode strings for py27 compatibility (Thanks: [Dick Marinus]).
+* Fixes mycli compatibility with pbr (Thanks: [Thomas Roten]).
+* Don't align decimals for `sql_format` (Thanks: [Dick Marinus]).
+
+Internal:
+---------
+
+* Use fileinput (Thanks: [Dick Marinus]).
+* Enable tests for Python 3.7 (Thanks: [Thomas Roten]).
+* Remove `*.swp` from gitignore (Thanks: [Dick Marinus]).
+
+1.17.0:
+=======
+
+Features:
+----------
+
+* Add `CONCAT` to SQLCompleter and remove unused code (Thanks: [caitinggui])
+* Do not quit when aborting a confirmation prompt (Thanks: [Thomas Roten]).
+* Add option list-dsn (Thanks: [Frederic Aoustin]).
+* Add verbose option for list-dsn, add tests and clean up code (Thanks: [Dick Marinus]).
+
+Bug Fixes:
+----------
+
+* Add enable_pager to the config file (Thanks: [Frederic Aoustin]).
+* Mark `test_sql_output` as a dbtest (Thanks: [Dick Marinus]).
+* Don't crash if the log/history file directories don't exist (Thanks: [Thomas Roten]).
+* Unquote dsn username and password (Thanks: [Dick Marinus]).
+* Output `Password:` prompt to stderr (Thanks: [ushuz]).
+* Mark `alter` as a destructive query (Thanks: [Dick Marinus]).
+* Quote CSV fields (Thanks: [Thomas Roten]).
+* Fix `thanks_picker` (Thanks: [Dick Marinus]).
+
+Internal:
+---------
+
+* Refactor Destructive Warning behave tests (Thanks: [Dick Marinus]).
+
+
+1.16.0:
+=======
+
+Features:
+---------
+
+* Add DSN aliases to the config file (Thanks: [Frederic Aoustin]).
+
+Bug Fixes:
+----------
+
+* Do not try to connect to a unix socket on Windows (Thanks: [Thomas Roten]).
+
+1.15.0:
+=======
+
+Features:
+---------
+
+* Add sql-update/insert output format. (Thanks: [Dick Marinus]).
+* Also complete aliases in WHERE. (Thanks: [Dick Marinus]).
+
+1.14.0:
+=======
+
+Features:
+---------
+
+* Add `watch [seconds] query` command to repeat a query every [seconds] seconds (by default 5). (Thanks: [David Caro](https://github.com/Terseus))
+* Default to unix socket connection if host and port are unspecified. This simplifies authentication on some systems and matches mysql behaviour.
+* Add support for positional parameters to favorite queries. (Thanks: [Scrappy Soft](https://github.com/scrappysoft))
+
+Bug Fixes:
+----------
+
+* Fix source command for script in current working directory. (Thanks: [Dick Marinus]).
+* Fix issue where the `tee` command did not work on Python 2.7 (Thanks: [Thomas Roten]).
+
+Internal Changes:
+-----------------
+
+* Drop support for Python 3.3 (Thanks: [Thomas Roten]).
+
+* Make tests more compatible between different build environments. (Thanks: [David Caro])
+* Merge `_on_completions_refreshed` and `_swap_completer_objects` functions (Thanks: [Dick Marinus]).
+
+1.13.1:
+=======
+
+Bug Fixes:
+----------
+
+* Fix keyword completion suggestion for `SHOW` (Thanks: [Thomas Roten]).
+* Prevent mycli from crashing when failing to read login path file (Thanks: [Thomas Roten]).
+
+Internal Changes:
+-----------------
+
+* Make tests ignore user config files (Thanks: [Thomas Roten]).
+
+1.13.0:
+=======
+
+Features:
+---------
+
+* Add file name completion for source command (issue #500). (Thanks: [Irina Truong]).
+
+Bug Fixes:
+----------
+
+* Fix UnicodeEncodeError when editing sql command in external editor (Thanks: Klaus Wünschel).
+* Fix MySQL4 version comment retrieval (Thanks: [François Pietka])
+* Fix error that occurred when outputting JSON and NULL data (Thanks: [Thomas Roten]).
+
+1.12.1:
+=======
+
+Bug Fixes:
+----------
+
+* Prevent missing MySQL help database from causing errors in completions (Thanks: [Thomas Roten]).
+* Fix mycli from crashing with small terminal windows under Python 2 (Thanks: [Thomas Roten]).
+* Prevent an error from displaying when you drop the current database (Thanks: [Thomas Roten]).
+
+Internal Changes:
+-----------------
+
+* Use less memory when formatting results for display (Thanks: [Dick Marinus]).
+* Preliminary work for a future change in outputting results that uses less memory (Thanks: [Dick Marinus]).
+
+1.12.0:
+=======
+
+Features:
+---------
+
+* Add fish-style auto-suggestion from history. (Thanks: [Amjith Ramanujam])
+
+
+1.11.0:
+=======
+
+Features:
+---------
+
+* Handle reserved space for completion menu better in small windows. (Thanks: [Thomas Roten]).
+* Display current vi mode in toolbar. (Thanks: [Thomas Roten]).
+* Opening an external editor will edit the last-run query. (Thanks: [Thomas Roten]).
+* Output once special command. (Thanks: [Dick Marinus]).
+* Add special command to show create table statement. (Thanks: [Ryan Smith])
+* Display all result sets returned by stored procedures (Thanks: [Thomas Roten]).
+* Add current time to prompt options (Thanks: [Thomas Roten]).
+* Output status text in a more intuitive way (Thanks: [Thomas Roten]).
+* Add colored/styled headers and odd/even rows (Thanks: [Thomas Roten]).
+* Keyword completion casing (upper/lower/auto) (Thanks: [Irina Truong]).
+
+Bug Fixes:
+----------
+
+* Fixed incorrect timekeeping when running queries from a file. (Thanks: [Thomas Roten]).
+* Do not display time and empty line for blank queries (Thanks: [Thomas Roten]).
+* Fixed issue where quit command would sometimes not work (Thanks: [Thomas Roten]).
+* Remove shebang from main.py (Thanks: [Dick Marinus]).
+* Only use pager if output doesn't fit. (Thanks: [Dick Marinus]).
+* Support tilde user directory for output file names (Thanks: [Thomas Roten]).
+* Auto vertical output is a little bit better at its calculations (Thanks: [Thomas Roten]).
+
+Internal Changes:
+-----------------
+
+* Rename tests/ to test/. (Thanks: [Dick Marinus]).
+* Move AUTHORS and SPONSORS to mycli directory. (Thanks: [Terje Røsten] []).
+* Switch from pycryptodome to cryptography (Thanks: [Thomas Roten]).
+* Add pager wrapper for behave tests (Thanks: [Dick Marinus]).
+* Behave test source command (Thanks: [Dick Marinus]).
+* Test using behave the tee command (Thanks: [Dick Marinus]).
+* Behave fix clean up. (Thanks: [Dick Marinus]).
+* Remove output formatter code in favor of CLI Helpers dependency (Thanks: [Thomas Roten]).
+* Better handle common before/after scenarios in behave. (Thanks: [Dick Marinus])
+* Added a regression test for sqlparse >= 0.2.3 (Thanks: [Dick Marinus]).
+* Reverted removal of temporary hack for sqlparse (Thanks: [Dick Marinus]).
+* Add setup.py commands to simplify development tasks (Thanks: [Thomas Roten]).
+* Add behave tests to tox (Thanks: [Dick Marinus]).
+* Add missing @dbtest to tests (Thanks: [Dick Marinus]).
+* Standardizes punctuation/grammar for help strings (Thanks: [Thomas Roten]).
+
+1.10.0:
+=======
+
+Features:
+---------
+
+* Add ability to specify alternative myclirc file. (Thanks: [Dick Marinus]).
+* Add new display formats for pretty printing query results. (Thanks: [Amjith
+ Ramanujam], [Dick Marinus], [Thomas Roten]).
+* Add logic to shorten the default prompt if it becomes too long once generated. (Thanks: [John Sterling]).
+
+Bug Fixes:
+----------
+
+* Fix external editor bug (issue #377). (Thanks: [Irina Truong]).
+* Fixed bug so that favorite queries can include unicode characters. (Thanks:
+ [Thomas Roten]).
+* Fix requirements and remove old compatibility code (Thanks: [Dick Marinus])
+* Fix bug where mycli would not start due to the thanks/credit intro text.
+ (Thanks: [Thomas Roten]).
+* Use pymysql default conversions (issue #375). (Thanks: [Dick Marinus]).
+
+Internal Changes:
+-----------------
+
+* Upload mycli distributions in a safer manner (using twine). (Thanks: [Thomas
+ Roten]).
+* Test mycli using pexpect/python-behave (Thanks: [Dick Marinus]).
+* Run pep8 checks in travis (Thanks: [Irina Truong]).
+* Remove temporary hack for sqlparse (Thanks: [Dick Marinus]).
+
+1.9.0:
+======
+
+Features:
+---------
+
+* Add tee/notee commands for outputing results to a file. (Thanks: [Dick Marinus]).
+* Add date, port, and whitespace options to prompt configuration. (Thanks: [Matheus Rosa]).
+* Allow user to specify LESS pager flags. (Thanks: [John Sterling]).
+* Add support for auto-reconnect. (Thanks: [Jialong Liu]).
+* Add CSV batch output. (Thanks: [Matheus Rosa]).
+* Add `auto_vertical_output` config to myclirc. (Thanks: [Matheus Rosa]).
+* Improve Fedora install instructions. (Thanks: [Dick Marinus]).
+
+Bug Fixes:
+----------
+
+* Fix crashes occuring from commands starting with #. (Thanks: [Zhidong]).
+* Fix broken PyMySQL link in README. (Thanks: [Daniël van Eeden]).
+* Add various missing keywords for highlighting and autocompletion. (Thanks: [zer09]).
+* Add the missing REGEXP keyword for highlighting and autocompletion. (Thanks: [cxbig]).
+* Fix duplicate username entries in completion list. (Thanks: [John Sterling]).
+* Remove extra spaces in TSV table format output. (Thanks: [Dick Marinus]).
+* Kill running query when interrupted via Ctrl-C. (Thanks: [chainkite]).
+* Read the `smart_completion` config from myclirc. (Thanks: [Thomas Roten]).
+
+Internal Changes:
+-----------------
+
+* Improve handling of test database credentials. (Thanks: [Dick Marinus]).
+* Add Python 3.6 to test environments and PyPI metadata. (Thanks: [Thomas Roten]).
+* Drop Python 2.6 support. (Thanks: [Thomas Roten]).
+* Swap pycrypto dependency for pycryptodome. (Thanks: [Michał Górny]).
+* Bump sqlparse version so pgcli and mycli can be installed together. (Thanks: [darikg]).
+
+1.8.1:
+======
+
+Bug Fixes:
+----------
+* Remove duplicate listing of DISTINCT keyword. (Thanks: [Amjith Ramanujam]).
+* Add an try/except for AS keyword crash. (Thanks: [Amjith Ramanujam]).
+* Support python-sqlparse 0.2. (Thanks: [Dick Marinus]).
+* Fallback to the raw object for invalid time values. (Thanks: [Amjith Ramanujam]).
+* Reset the show items when completion is refreshed. (Thanks: [Amjith Ramanujam]).
+
+Internal Changes:
+-----------------
+* Make the dependency of sqlparse slightly more liberal. (Thanks: [Amjith Ramanujam]).
+
+1.8.0:
+======
+
+Features:
+---------
+
+* Add support for --execute/-e commandline arg. (Thanks: [Matheus Rosa]).
+* Add `less_chatty` config option to skip the intro messages. (Thanks: [Scrappy Soft]).
+* Support `MYCLI_HISTFILE` environment variable to specify where to write the history file. (Thanks: [Scrappy Soft]).
+* Add `prompt_continuation` config option to allow configuring the continuation prompt for multi-line queries. (Thanks: [Scrappy Soft]).
+* Display login-path instead of host in prompt. (Thanks: [Irina Truong]).
+
+Bug Fixes:
+----------
+
+* Pin sqlparse to version 0.1.19 since the new version is breaking completion. (Thanks: [Amjith Ramanujam]).
+* Remove unsupported keywords. (Thanks: [Matheus Rosa]).
+* Fix completion suggestion inside functions with operands. (Thanks: [Irina Truong]).
+
+1.7.0:
+======
+
+Features:
+---------
+
+* Add stdin batch mode. (Thanks: [Thomas Roten]).
+* Add warn/no-warn command-line options. (Thanks: [Thomas Roten]).
+* Upgrade sqlparse dependency to 0.1.19. (Thanks: [Amjith Ramanujam]).
+* Update features list in README.md. (Thanks: [Matheus Rosa]).
+* Remove extra \n in features list in README.md. (Thanks: [Matheus Rosa]).
+
+Bug Fixes:
+----------
+
+* Enable history search via <C-r>. (Thanks: [Amjith Ramanujam]).
+
+Internal Changes:
+-----------------
+
+* Upgrade `prompt_toolkit` to 1.0.0. (Thanks: [Jonathan Slenders])
+
+1.6.0:
+======
+
+Features:
+---------
+
+* Change continuation prompt for multi-line mode to match default mysql.
+* Add `status` command to match mysql's `status` command. (Thanks: [Thomas Roten]).
+* Add SSL support for `mycli`. (Thanks: [Artem Bezsmertnyi]).
+* Add auto-completion and highlight support for OFFSET keyword. (Thanks: [Matheus Rosa]).
+* Add support for `MYSQL_TEST_LOGIN_FILE` env variable to specify alternate login file. (Thanks: [Thomas Roten]).
+* Add support for `--auto-vertical-output` to automatically switch to vertical output if the output doesn't fit in the table format.
+* Add support for system-wide config. Now /etc/myclirc will be honored. (Thanks: [Thomas Roten]).
+* Add support for `nopager` and `\n` to turn off the pager. (Thanks: [Thomas Roten]).
+* Add support for `--local-infile` command-line option. (Thanks: [Thomas Roten]).
+
+Bug Fixes:
+----------
+
+* Remove -S from `less` option which was clobbering the scroll back in history. (Thanks: [Thomas Roten]).
+* Make system command work with Python 3. (Thanks: [Thomas Roten]).
+* Support \G terminator for \f queries. (Thanks: [Terseus]).
+
+Internal Changes:
+-----------------
+
+* Upgrade `prompt_toolkit` to 0.60.
+* Add Python 3.5 to test environments. (Thanks: [Thomas Roten]).
+* Remove license meta-data. (Thanks: [Thomas Roten]).
+* Skip binary tests if PyMySQL version does not support it. (Thanks: [Thomas Roten]).
+* Refactor pager handling. (Thanks: [Thomas Roten])
+* Capture warnings to log file. (Thanks: [Mikhail Borisov]).
+* Make `syntax_style` a tiny bit more intuitive. (Thanks: [Phil Cohen]).
+
+1.5.2:
+======
+
+Bug Fixes:
+----------
+
+* Protect against port number being None when no port is specified in command line.
+
+1.5.1:
+======
+
+Bug Fixes:
+----------
+
+* Cast the value of port read from my.cnf to int.
+
+1.5.0:
+======
+
+Features:
+---------
+
+* Make a config option to enable `audit_log`. (Thanks: [Matheus Rosa]).
+* Add support for reading .mylogin.cnf to get user credentials. (Thanks: [Thomas Roten]).
+ This feature is only available when `pycrypto` package is installed.
+* Register the special command `prompt` with the `\R` as alias. (Thanks: [Matheus Rosa]).
+ Users can now change the mysql prompt at runtime using `prompt` command.
+ eg:
+ ```
+ mycli> prompt \u@\h>
+ Changed prompt format to \u@\h>
+ Time: 0.001s
+ amjith@localhost>
+ ```
+* Perform completion refresh in a background thread. Now mycli can handle
+ databases with thousands of tables without blocking.
+* Add support for `system` command. (Thanks: [Matheus Rosa]).
+ Users can now run a system command from within mycli as follows:
+ ```
+ amjith@localhost:(none)>system cat tmp.sql
+ select 1;
+ select * from django_migrations;
+ ```
+* Caught and hexed binary fields in MySQL. (Thanks: [Daniel West]).
+ Geometric fields stored in a database will be displayed as hexed strings.
+* Treat enter key as tab when the suggestion menu is open. (Thanks: [Matheus Rosa])
+* Add "delete" and "truncate" as destructive commands. (Thanks: [Martijn Engler]).
+* Change \dt syntax to add an optional table name. (Thanks: [Shoma Suzuki]).
+ `\dt [tablename]` will describe the columns in a table.
+* Add TRANSACTION related keywords.
+* Treat DESC and EXPLAIN as DESCRIBE. (Thanks: [spacewander]).
+
+Bug Fixes:
+----------
+
+* Fix the removal of whitespace from table output.
+* Add ability to make suggestions for compound join clauses. (Thanks: [Matheus Rosa]).
+* Fix the incorrect reporting of command time.
+* Add type validation for port argument. (Thanks [Matheus Rosa])
+
+Internal Changes:
+-----------------
+* Make pycrypto optional and only install it in \*nix systems. (Thanks: [Irina Truong]).
+* Add badge for PyPI version to README. (Thanks: [Shoma Suzuki]).
+* Updated release script with a --dry-run and --confirm-steps option. (Thanks: [Irina Truong]).
+* Adds support for PyMySQL 0.6.2 and above. This is useful for debian package builders. (Thanks: [Thomas Roten]).
+* Disable click warning.
+
+1.4.0:
+======
+
+Features:
+---------
+
+* Add `source` command. This allows running sql statement from a file.
+
+ eg:
+ ```
+ mycli> source filename.sql
+ ```
+
+* Added a config option to make the warning before destructive commands optional. (Thanks: [Daniel West](https://github.com/danieljwest))
+
+ In the config file ~/.myclirc set `destructive_warning = False` which will
+ disable the warning before running `DROP` commands.
+
+* Add completion support for CHANGE TO and other master/slave commands. This is
+ still preliminary and it will be enhanced in the future.
+
+* Add custom styles to color the menus and toolbars.
+
+* Upgrade `prompt_toolkit` to 0.46. (Thanks: [Jonathan Slenders])
+
+ Multi-line queries are automatically indented.
+
+Bug Fixes:
+----------
+
+* Fix keyword completion after the `WHERE` clause.
+* Add `\g` and `\G` as valid query terminators. Previously in multi-line mode
+ ending a query with a `\G` wouldn't run the query. This is now fixed.
+
+1.3.0:
+======
+
+Features:
+---------
+* Add a new special command (\T) to change the table format on the fly. (Thanks: [Jonathan Bruno](https://github.com/brewneaux))
+ eg:
+ ```
+ mycli> \T tsv
+ ```
+* Add `--defaults-group-suffix` to the command line. This lets the user specify
+ a group to use in the my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet))
+
+ In the my.cnf file a user can specify credentials for different databases and
+ invoke mycli with the group name to use the appropriate credentials.
+ eg:
+ ```
+ # my.cnf
+ [client]
+ user = 'root'
+ socket = '/tmp/mysql.sock'
+ pager = 'less -RXSF'
+ database = 'account'
+
+ [clientamjith]
+ user = 'amjith'
+ database = 'user_management'
+
+ $ mycli --defaults-group-suffix=amjith # uses the [clientamjith] section in my.cnf
+ ```
+
+* Add `--defaults-file` option to the command line. This allows specifying a
+ `my.cnf` to use at launch. This also makes it play nice with mysql sandbox.
+
+* Make `-p` and `--password` take the password in commandline. This makes mycli
+ a drop in replacement for mysql.
+
+1.2.0:
+======
+
+Features:
+---------
+
+* Add support for wider completion menus in the config file.
+
+ Add `wider_completion_menu = True` in the config file (~/.myclirc) to enable this feature.
+
+Bug Fixes:
+---------
+
+* Prevent Ctrl-C from quitting mycli while the pager is active.
+* Refresh auto-completions after the database is changed via a CONNECT command.
+
+Internal Changes:
+-----------------
+
+* Upgrade `prompt_toolkit` dependency version to 0.45.
+* Added Travis CI to run the tests automatically.
+
+1.1.1:
+======
+
+Bug Fixes:
+----------
+
+* Change dictonary comprehension used in mycnf reader to list comprehension to make it compatible with Python 2.6.
+
+
+1.1.0:
+======
+
+Features:
+---------
+
+* Fuzzy completion is now case-insensitive. (Thanks: [bjarnagin](https://github.com/bjarnagin))
+* Added new-line (`\n`) to the list of special characters to use in prompt. (Thanks: [brewneaux](https://github.com/brewneaux))
+* Honor the `pager` setting in my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet))
+
+Bug Fixes:
+----------
+
+* Fix a crashing bug in completion engine for cross joins.
+* Make `<null>` value consistent between tabular and vertical output.
+
+Internal Changes:
+-----------------
+
+* Changed pymysql version to be greater than 0.6.6.
+* Upgrade `prompt_toolkit` version to 0.42. (Thanks: [Yasuhiro Matsumoto](https://github.com/mattn))
+* Removed the explicit dependency on six.
+
+2015/06/10:
+===========
+
+Features:
+---------
+
+* Customizable prompt. (Thanks [Steve Robbins](https://github.com/steverobbins))
+* Make `\G` formatting to behave more like mysql.
+
+Bug Fixes:
+----------
+
+* Formatting issue in \G for really long column values.
+
+
+2015/06/07:
+===========
+
+Features:
+---------
+
+* Upgrade `prompt_toolkit` to 0.38. This improves the performance of pasting long queries.
+* Add support for reading my.cnf files.
+* Add editor command \e.
+* Replace ConfigParser with ConfigObj.
+* Add \dt to show all tables.
+* Add fuzzy completion for table names and column names.
+* Automatically reconnect when connection is lost to the database.
+
+Bug Fixes:
+----------
+
+* Fix a bug with reconnect failure.
+* Fix the issue with `use` command not changing the prompt.
+* Fix the issue where `\\r` shortcut was not recognized.
+
+
+2015/05/24
+==========
+
+Features:
+---------
+
+* Add support for connecting via socket.
+* Add completion for SQL functions.
+* Add completion support for SHOW statements.
+* Made the timing of sql statements human friendly.
+* Automatically prompt for a password if needed.
+
+Bug Fixes:
+----------
+* Fixed the installation issues with PyMySQL dependency on case-sensitive file systems.
+
+[Amjith Ramanujam]: https://blog.amjith.com
+[Artem Bezsmertnyi]: https://github.com/mrdeathless
+[BuonOmo]: https://github.com/BuonOmo
+[Carlos Afonso]: https://github.com/afonsocarlos
+[Casper Langemeijer]: https://github.com/langemeijer
+[Daniel West]: http://github.com/danieljwest
+[Dick Marinus]: https://github.com/meeuw
+[François Pietka]: https://github.com/fpietka
+[Frederic Aoustin]: https://github.com/fraoustin
+[Georgy Frolov]: https://github.com/pasenor
+[Irina Truong]: https://github.com/j-bennet
+[Jonathan Slenders]: https://github.com/jonathanslenders
+[Kacper Kwapisz]: https://github.com/KKKas
+[laixintao]: https://github.com/laixintao
+[Lennart Weller]: https://github.com/lhw
+[Martijn Engler]: https://github.com/martijnengler
+[Matheus Rosa]: https://github.com/mdsrosa
+[Mikhail Borisov]: https://github.com/borman
+[mtorromeo]: https://github.com/mtorromeo
+[mwcm]: https://github.com/mwcm
+[Phil Cohen]: https://github.com/phlipper
+[Scrappy Soft]: https://github.com/scrappysoft
+[Shoma Suzuki]: https://github.com/shoma
+[spacewander]: https://github.com/spacewander
+[Terseus]: https://github.com/Terseus
+[Thomas Roten]: https://github.com/tsroten
+[William GARCIA]: https://github.com/willgarcia
+[xeron]: https://github.com/xeron
+[Zach DeCook]: https://zachdecook.com
diff --git a/doc/key_bindings.rst b/doc/key_bindings.rst
new file mode 100644
index 0000000..0534870
--- /dev/null
+++ b/doc/key_bindings.rst
@@ -0,0 +1,65 @@
+*************
+Key Bindings:
+*************
+
+Most key bindings are simply inherited from `prompt-toolkit <https://python-prompt-toolkit.readthedocs.io/en/master/index.html>`_ .
+
+The following key bindings are special to mycli:
+
+###
+F2
+###
+
+Enable/Disable SmartCompletion Mode.
+
+###
+F3
+###
+
+Enable/Disable Multiline Mode.
+
+###
+F4
+###
+
+Toggle between Vi and Emacs mode.
+
+###
+Tab
+###
+
+Force autocompletion at cursor.
+
+#######
+C-space
+#######
+
+Initialize autocompletion at cursor.
+
+If the autocompletion menu is not showing, display it with the appropriate completions for the context.
+
+If the menu is showing, select the next completion.
+
+#########
+ESC Enter
+#########
+
+Introduce a line break in multi-line mode, or dispatch the command in single-line mode.
+
+The sequence ESC-Enter is often sent by Alt-Enter.
+
+#################################
+C-x p (Emacs-mode) or > (Vi-mode)
+#################################
+
+Prettify and indent current statement, usually into multiple lines.
+
+Only accepts buffers containing single SQL statements.
+
+#################################
+C-x u (Emacs-mode) or < (Vi-mode)
+#################################
+
+Unprettify and dedent current statement, usually into one line.
+
+Only accepts buffers containing single SQL statements.
diff --git a/mycli/AUTHORS b/mycli/AUTHORS
new file mode 100644
index 0000000..a805465
--- /dev/null
+++ b/mycli/AUTHORS
@@ -0,0 +1,101 @@
+Project Lead:
+-------------
+ * Thomas Roten
+
+
+Core Developers:
+----------------
+
+ * Irina Truong
+ * Matheus Rosa
+ * Darik Gamble
+ * Dick Marinus
+ * Amjith Ramanujam
+
+Contributors:
+-------------
+
+ * 0xflotus
+ * Abirami P
+ * Adam Chainz
+ * Aljosha Papsch
+ * Andy Teijelo Pérez
+ * Angelo Lupo
+ * Artem Bezsmertnyi
+ * bitkeen
+ * bjarnagin
+ * BuonOmo
+ * caitinggui
+ * Carlos Afonso
+ * Casper Langemeijer
+ * chainkite
+ * Claude Becker
+ * Colin Caine
+ * cxbig
+ * Daniel Black
+ * Daniel West
+ * Daniël van Eeden
+ * François Pietka
+ * Frederic Aoustin
+ * Georgy Frolov
+ * Heath Naylor
+ * Huachao Mao
+ * Ishaan Bhimwal
+ * Jakub Boukal
+ * jbruno
+ * Jerome Provensal
+ * Jialong Liu
+ * Johannes Hoff
+ * John Sterling
+ * Jonathan Bruno
+ * Jonathan Lloyd
+ * Jonathan Slenders
+ * Kacper Kwapisz
+ * Karthikeyan Singaravelan
+ * kevinhwang91
+ * KITAGAWA Yasutaka
+ * Klaus Wünschel
+ * laixintao
+ * Lennart Weller
+ * Martijn Engler
+ * Massimiliano Torromeo
+ * Michał Górny
+ * Mike Palandra
+ * Mikhail Borisov
+ * Morgan Mitchell
+ * mrdeathless
+ * Nathan Huang
+ * Nicolas Palumbo
+ * Phil Cohen
+ * QiaoHou Peng
+ * Roland Walker
+ * Ryan Smith
+ * Scrappy Soft
+ * Seamile
+ * Shoma Suzuki
+ * spacewander
+ * Steve Robbins
+ * Takeshi D. Itoh
+ * Terje Røsten
+ * Terseus
+ * Tyler Kuipers
+ * ushuz
+ * William GARCIA
+ * xeron
+ * Yang Zou
+ * Yasuhiro Matsumoto
+ * Yuanchun Shang
+ * Zach DeCook
+ * Zane C. Bowers-Hadley
+ * zer09
+ * Zhaolong Zhu
+ * Zhidong
+ * Zhongyang Guan
+ * Arvind Mishra
+ * Kevin Schmeichel
+ * Mel Dafert
+
+Created by:
+-----------
+
+Amjith Ramanujam
diff --git a/mycli/SPONSORS b/mycli/SPONSORS
new file mode 100644
index 0000000..81b0904
--- /dev/null
+++ b/mycli/SPONSORS
@@ -0,0 +1,31 @@
+Many thanks to the following Kickstarter backers.
+
+* Tech Blue Software
+* jweiland.net
+
+# Silver Sponsors
+
+* Whitane Tech
+* Open Query Pty Ltd
+* Prathap Ramamurthy
+* Lincoln Loop
+
+# Sponsors
+
+* Nathan Taggart
+* Iryna Cherniavska
+* Sudaraka Wijesinghe
+* www.mysqlfanboy.com
+* Steve Robbins
+* Norbert Spichtig
+* orpharion bestheneme
+* Daniel Black
+* Anonymous
+* Magnus udd
+* Anonymous
+* Lewis Peckover
+* Cyrille Tabary
+* Heath Naylor
+* Ted Pennings
+* Chris Anderton
+* Jonathan Slenders
diff --git a/mycli/__init__.py b/mycli/__init__.py
new file mode 100644
index 0000000..1512b41
--- /dev/null
+++ b/mycli/__init__.py
@@ -0,0 +1 @@
+__version__ = '1.26.1'
diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py
new file mode 100644
index 0000000..81353b6
--- /dev/null
+++ b/mycli/clibuffer.py
@@ -0,0 +1,55 @@
+from prompt_toolkit.enums import DEFAULT_BUFFER
+from prompt_toolkit.filters import Condition
+from prompt_toolkit.application import get_app
+from .packages import special
+
+
+def cli_is_multiline(mycli):
+ @Condition
+ def cond():
+ doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document
+
+ if not mycli.multi_line:
+ return False
+ else:
+ return not _multiline_exception(doc.text)
+ return cond
+
+
+def _multiline_exception(text):
+ orig = text
+ text = text.strip()
+
+ # Multi-statement favorite query is a special case. Because there will
+ # be a semicolon separating statements, we can't consider semicolon an
+ # EOL. Let's consider an empty line an EOL instead.
+ if text.startswith('\\fs'):
+ return orig.endswith('\n')
+
+ return (
+ # Special Command
+ text.startswith('\\') or
+
+ # Delimiter declaration
+ text.lower().startswith('delimiter') or
+
+ # Ended with the current delimiter (usually a semi-column)
+ text.endswith(special.get_current_delimiter()) or
+
+ text.endswith('\\g') or
+ text.endswith('\\G') or
+ text.endswith(r'\e') or
+ text.endswith(r'\clip') or
+
+ # Exit doesn't need semi-column`
+ (text == 'exit') or
+
+ # Quit doesn't need semi-column
+ (text == 'quit') or
+
+ # To all teh vim fans out there
+ (text == ':q') or
+
+ # just a plain enter without any text
+ (text == '')
+ )
diff --git a/mycli/clistyle.py b/mycli/clistyle.py
new file mode 100644
index 0000000..b0ac992
--- /dev/null
+++ b/mycli/clistyle.py
@@ -0,0 +1,152 @@
+import logging
+
+import pygments.styles
+from pygments.token import string_to_tokentype, Token
+from pygments.style import Style as PygmentsStyle
+from pygments.util import ClassNotFound
+from prompt_toolkit.styles.pygments import style_from_pygments_cls
+from prompt_toolkit.styles import merge_styles, Style
+
+logger = logging.getLogger(__name__)
+
+# map Pygments tokens (ptk 1.0) to class names (ptk 2.0).
+TOKEN_TO_PROMPT_STYLE = {
+ Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current',
+ Token.Menu.Completions.Completion: 'completion-menu.completion',
+ Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current',
+ Token.Menu.Completions.Meta: 'completion-menu.meta.completion',
+ Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta',
+ Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess
+ Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess
+ Token.SelectedText: 'selected',
+ Token.SearchMatch: 'search',
+ Token.SearchMatch.Current: 'search.current',
+ Token.Toolbar: 'bottom-toolbar',
+ Token.Toolbar.Off: 'bottom-toolbar.off',
+ Token.Toolbar.On: 'bottom-toolbar.on',
+ Token.Toolbar.Search: 'search-toolbar',
+ Token.Toolbar.Search.Text: 'search-toolbar.text',
+ Token.Toolbar.System: 'system-toolbar',
+ Token.Toolbar.Arg: 'arg-toolbar',
+ Token.Toolbar.Arg.Text: 'arg-toolbar.text',
+ Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid',
+ Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed',
+ Token.Output.Header: 'output.header',
+ Token.Output.OddRow: 'output.odd-row',
+ Token.Output.EvenRow: 'output.even-row',
+ Token.Output.Null: 'output.null',
+ Token.Prompt: 'prompt',
+ Token.Continuation: 'continuation',
+}
+
+# reverse dict for cli_helpers, because they still expect Pygments tokens.
+PROMPT_STYLE_TO_TOKEN = {
+ v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()
+}
+
+# all tokens that the Pygments MySQL lexer can produce
+OVERRIDE_STYLE_TO_TOKEN = {
+ 'sql.comment': Token.Comment,
+ 'sql.comment.multi-line': Token.Comment.Multiline,
+ 'sql.comment.single-line': Token.Comment.Single,
+ 'sql.comment.optimizer-hint': Token.Comment.Special,
+ 'sql.escape': Token.Error,
+ 'sql.keyword': Token.Keyword,
+ 'sql.datatype': Token.Keyword.Type,
+ 'sql.literal': Token.Literal,
+ 'sql.literal.date': Token.Literal.Date,
+ 'sql.symbol': Token.Name,
+ 'sql.quoted-schema-object': Token.Name.Quoted,
+ 'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape,
+ 'sql.constant': Token.Name.Constant,
+ 'sql.function': Token.Name.Function,
+ 'sql.variable': Token.Name.Variable,
+ 'sql.number': Token.Number,
+ 'sql.number.binary': Token.Number.Bin,
+ 'sql.number.float': Token.Number.Float,
+ 'sql.number.hex': Token.Number.Hex,
+ 'sql.number.integer': Token.Number.Integer,
+ 'sql.operator': Token.Operator,
+ 'sql.punctuation': Token.Punctuation,
+ 'sql.string': Token.String,
+ 'sql.string.double-quouted': Token.String.Double,
+ 'sql.string.escape': Token.String.Escape,
+ 'sql.string.single-quoted': Token.String.Single,
+ 'sql.whitespace': Token.Text,
+}
+
+def parse_pygments_style(token_name, style_object, style_dict):
+ """Parse token type and style string.
+
+ :param token_name: str name of Pygments token. Example: "Token.String"
+ :param style_object: pygments.style.Style instance to use as base
+ :param style_dict: dict of token names and their styles, customized to this cli
+
+ """
+ token_type = string_to_tokentype(token_name)
+ try:
+ other_token_type = string_to_tokentype(style_dict[token_name])
+ return token_type, style_object.styles[other_token_type]
+ except AttributeError as err:
+ return token_type, style_dict[token_name]
+
+
+def style_factory(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name)
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name('native')
+
+ prompt_styles = []
+ # prompt-toolkit used pygments tokens for styling before, switched to style
+ # names in 2.0. Convert old token types to new style names, for backwards compatibility.
+ for token in cli_style:
+ if token.startswith('Token.'):
+ # treat as pygments token (1.0)
+ token_type, style_value = parse_pygments_style(
+ token, style, cli_style)
+ if token_type in TOKEN_TO_PROMPT_STYLE:
+ prompt_style = TOKEN_TO_PROMPT_STYLE[token_type]
+ prompt_styles.append((prompt_style, style_value))
+ else:
+ # we don't want to support tokens anymore
+ logger.error('Unhandled style / class name: %s', token)
+ else:
+ # treat as prompt style name (2.0). See default style names here:
+ # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+ prompt_styles.append((token, cli_style[token]))
+
+ override_style = Style([('bottom-toolbar', 'noreverse')])
+ return merge_styles([
+ style_from_pygments_cls(style),
+ override_style,
+ Style(prompt_styles)
+ ])
+
+
+def style_factory_output(name, cli_style):
+ try:
+ style = pygments.styles.get_style_by_name(name).styles
+ except ClassNotFound:
+ style = pygments.styles.get_style_by_name('native').styles
+
+ for token in cli_style:
+ if token.startswith('Token.'):
+ token_type, style_value = parse_pygments_style(
+ token, style, cli_style)
+ style.update({token_type: style_value})
+ elif token in PROMPT_STYLE_TO_TOKEN:
+ token_type = PROMPT_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ elif token in OVERRIDE_STYLE_TO_TOKEN:
+ token_type = OVERRIDE_STYLE_TO_TOKEN[token]
+ style.update({token_type: cli_style[token]})
+ else:
+ # TODO: cli helpers will have to switch to ptk.Style
+ logger.error('Unhandled style / class name: %s', token)
+
+ class OutputStyle(PygmentsStyle):
+ default_style = ""
+ styles = style
+
+ return OutputStyle
diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py
new file mode 100644
index 0000000..24d1108
--- /dev/null
+++ b/mycli/clitoolbar.py
@@ -0,0 +1,58 @@
+from prompt_toolkit.key_binding.vi_state import InputMode
+from prompt_toolkit.application import get_app
+from prompt_toolkit.enums import EditingMode
+from .packages import special
+
+
+def create_toolbar_tokens_func(mycli, show_fish_help):
+ """Return a function that generates the toolbar tokens."""
+ def get_toolbar_tokens():
+ result = []
+ result.append(('class:bottom-toolbar', ' '))
+
+ if mycli.multi_line:
+ delimiter = special.get_current_delimiter()
+ result.append(
+ (
+ 'class:bottom-toolbar',
+ ' ({} [{}] will end the line) '.format(
+ 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter)
+ ))
+
+ if mycli.multi_line:
+ result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON '))
+ else:
+ result.append(('class:bottom-toolbar.off',
+ '[F3] Multiline: OFF '))
+ if mycli.prompt_app.editing_mode == EditingMode.VI:
+ result.append((
+ 'class:botton-toolbar.on',
+ 'Vi-mode ({})'.format(_get_vi_mode())
+ ))
+
+ if mycli.toolbar_error_message:
+ result.append(
+ ('class:bottom-toolbar', ' ' + mycli.toolbar_error_message))
+ mycli.toolbar_error_message = None
+
+ if show_fish_help():
+ result.append(
+ ('class:bottom-toolbar', ' Right-arrow to complete suggestion'))
+
+ if mycli.completion_refresher.is_refreshing():
+ result.append(
+ ('class:bottom-toolbar', ' Refreshing completions...'))
+
+ return result
+ return get_toolbar_tokens
+
+
+def _get_vi_mode():
+ """Get the current vi mode for display."""
+ return {
+ InputMode.INSERT: 'I',
+ InputMode.NAVIGATION: 'N',
+ InputMode.REPLACE: 'R',
+ InputMode.REPLACE_SINGLE: 'R',
+ InputMode.INSERT_MULTIPLE: 'M',
+ }[get_app().vi_state.input_mode]
diff --git a/mycli/compat.py b/mycli/compat.py
new file mode 100644
index 0000000..2ebfe07
--- /dev/null
+++ b/mycli/compat.py
@@ -0,0 +1,6 @@
+"""Platform and Python version compatibility support."""
+
+import sys
+
+
+WIN = sys.platform in ('win32', 'cygwin')
diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py
new file mode 100644
index 0000000..8eb3de9
--- /dev/null
+++ b/mycli/completion_refresher.py
@@ -0,0 +1,123 @@
+import threading
+from .packages.special.main import COMMANDS
+from collections import OrderedDict
+
+from .sqlcompleter import SQLCompleter
+from .sqlexecute import SQLExecute
+
+class CompletionRefresher(object):
+
+ refreshers = OrderedDict()
+
+ def __init__(self):
+ self._completer_thread = None
+ self._restart_refresh = threading.Event()
+
+ def refresh(self, executor, callbacks, completer_options=None):
+ """Creates a SQLCompleter object and populates it with the relevant
+ completion suggestions in a background thread.
+
+ executor - SQLExecute object, used to extract the credentials to connect
+ to the database.
+ callbacks - A function or a list of functions to call after the thread
+ has completed the refresh. The newly created completion
+ object will be passed in as an argument to each callback.
+ completer_options - dict of options to pass to SQLCompleter.
+
+ """
+ if completer_options is None:
+ completer_options = {}
+
+ if self.is_refreshing():
+ self._restart_refresh.set()
+ return [(None, None, None, 'Auto-completion refresh restarted.')]
+ else:
+ self._completer_thread = threading.Thread(
+ target=self._bg_refresh,
+ args=(executor, callbacks, completer_options),
+ name='completion_refresh')
+ self._completer_thread.daemon = True
+ self._completer_thread.start()
+ return [(None, None, None,
+ 'Auto-completion refresh started in the background.')]
+
+ def is_refreshing(self):
+ return self._completer_thread and self._completer_thread.is_alive()
+
+ def _bg_refresh(self, sqlexecute, callbacks, completer_options):
+ completer = SQLCompleter(**completer_options)
+
+ # Create a new pgexecute method to populate the completions.
+ e = sqlexecute
+ executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port,
+ e.socket, e.charset, e.local_infile, e.ssl,
+ e.ssh_user, e.ssh_host, e.ssh_port,
+ e.ssh_password, e.ssh_key_filename)
+
+ # If callbacks is a single function then push it into a list.
+ if callable(callbacks):
+ callbacks = [callbacks]
+
+ while 1:
+ for refresher in self.refreshers.values():
+ refresher(completer, executor)
+ if self._restart_refresh.is_set():
+ self._restart_refresh.clear()
+ break
+ else:
+ # Break out of while loop if the for loop finishes natually
+ # without hitting the break statement.
+ break
+
+ # Start over the refresh from the beginning if the for loop hit the
+ # break statement.
+ continue
+
+ for callback in callbacks:
+ callback(completer)
+
+def refresher(name, refreshers=CompletionRefresher.refreshers):
+ """Decorator to add the decorated function to the dictionary of
+ refreshers. Any function decorated with a @refresher will be executed as
+ part of the completion refresh routine."""
+ def wrapper(wrapped):
+ refreshers[name] = wrapped
+ return wrapped
+ return wrapper
+
+@refresher('databases')
+def refresh_databases(completer, executor):
+ completer.extend_database_names(executor.databases())
+
+@refresher('schemata')
+def refresh_schemata(completer, executor):
+ # schemata - In MySQL Schema is the same as database. But for mycli
+ # schemata will be the name of the current database.
+ completer.extend_schemata(executor.dbname)
+ completer.set_dbname(executor.dbname)
+
+@refresher('tables')
+def refresh_tables(completer, executor):
+ completer.extend_relations(executor.tables(), kind='tables')
+ completer.extend_columns(executor.table_columns(), kind='tables')
+
+@refresher('users')
+def refresh_users(completer, executor):
+ completer.extend_users(executor.users())
+
+# @refresher('views')
+# def refresh_views(completer, executor):
+# completer.extend_relations(executor.views(), kind='views')
+# completer.extend_columns(executor.view_columns(), kind='views')
+
+@refresher('functions')
+def refresh_functions(completer, executor):
+ completer.extend_functions(executor.functions())
+
+@refresher('special_commands')
+def refresh_special(completer, executor):
+ completer.extend_special_commands(COMMANDS.keys())
+
+@refresher('show_commands')
+def refresh_show_commands(completer, executor):
+ completer.extend_show_items(executor.show_candidates())
diff --git a/mycli/config.py b/mycli/config.py
new file mode 100644
index 0000000..5d71109
--- /dev/null
+++ b/mycli/config.py
@@ -0,0 +1,344 @@
+from copy import copy
+from io import BytesIO, TextIOWrapper
+import logging
+import os
+from os.path import exists
+import struct
+import sys
+from typing import Union, IO
+
+from configobj import ConfigObj, ConfigObjError
+import pyaes
+
+try:
+ import importlib.resources as resources
+except ImportError:
+ # Python < 3.7
+ import importlib_resources as resources
+
+try:
+ basestring
+except NameError:
+ basestring = str
+
+
+logger = logging.getLogger(__name__)
+
+
+def log(logger, level, message):
+ """Logs message to stderr if logging isn't initialized."""
+
+ if logger.parent.name != 'root':
+ logger.log(level, message)
+ else:
+ print(message, file=sys.stderr)
+
+
+def read_config_file(f, list_values=True):
+ """Read a config file.
+
+ *list_values* set to `True` is the default behavior of ConfigObj.
+ Disabling it causes values to not be parsed for lists,
+ (e.g. 'a,b,c' -> ['a', 'b', 'c']. Additionally, the config values are
+ not unquoted. We are disabling list_values when reading MySQL config files
+ so we can correctly interpret commas in passwords.
+
+ """
+
+ if isinstance(f, basestring):
+ f = os.path.expanduser(f)
+
+ try:
+ config = ConfigObj(f, interpolation=False, encoding='utf8',
+ list_values=list_values)
+ except ConfigObjError as e:
+ log(logger, logging.WARNING, "Unable to parse line {0} of config file "
+ "'{1}'.".format(e.line_number, f))
+ log(logger, logging.WARNING, "Using successfully parsed config values.")
+ return e.config
+ except (IOError, OSError) as e:
+ log(logger, logging.WARNING, "You don't have permission to read "
+ "config file '{0}'.".format(e.filename))
+ return None
+
+ return config
+
+
+def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
+ """Get a list of configuration files that are included into config_path
+ with !includedir directive.
+
+ "Normal" configs should be passed as file paths. The only exception
+ is .mylogin which is decoded into a stream. However, it never
+ contains include directives and so will be ignored by this
+ function.
+
+ """
+ if not isinstance(config_file, str) or not os.path.isfile(config_file):
+ return []
+ included_configs = []
+
+ try:
+ with open(config_file) as f:
+ include_directives = filter(
+ lambda s: s.startswith('!includedir'),
+ f
+ )
+ dirs = map(lambda s: s.strip().split()[-1], include_directives)
+ dirs = filter(os.path.isdir, dirs)
+ for dir in dirs:
+ for filename in os.listdir(dir):
+ if filename.endswith('.cnf'):
+ included_configs.append(os.path.join(dir, filename))
+ except (PermissionError, UnicodeDecodeError):
+ pass
+ return included_configs
+
+
+def read_config_files(files, list_values=True):
+ """Read and merge a list of config files."""
+
+ config = create_default_config(list_values=list_values)
+ _files = copy(files)
+ while _files:
+ _file = _files.pop(0)
+ _config = read_config_file(_file, list_values=list_values)
+
+ # expand includes only if we were able to parse config
+ # (otherwise we'll just encounter the same errors again)
+ if config is not None:
+ _files = get_included_configs(_file) + _files
+ if bool(_config) is True:
+ config.merge(_config)
+ config.filename = _config.filename
+
+ return config
+
+
+def create_default_config(list_values=True):
+ import mycli
+ default_config_file = resources.open_text(mycli, 'myclirc')
+ return read_config_file(default_config_file, list_values=list_values)
+
+
+def write_default_config(destination, overwrite=False):
+ import mycli
+ default_config = resources.read_text(mycli, 'myclirc')
+ destination = os.path.expanduser(destination)
+ if not overwrite and exists(destination):
+ return
+
+ with open(destination, 'w') as f:
+ f.write(default_config)
+
+
+def get_mylogin_cnf_path():
+ """Return the path to the login path file or None if it doesn't exist."""
+ mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE')
+
+ if mylogin_cnf_path is None:
+ app_data = os.getenv('APPDATA')
+ default_dir = os.path.join(app_data, 'MySQL') if app_data else '~'
+ mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf')
+
+ mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path)
+
+ if exists(mylogin_cnf_path):
+ logger.debug("Found login path file at '{0}'".format(mylogin_cnf_path))
+ return mylogin_cnf_path
+ return None
+
+
+def open_mylogin_cnf(name):
+ """Open a readable version of .mylogin.cnf.
+
+ Returns the file contents as a TextIOWrapper object.
+
+ :param str name: The pathname of the file to be opened.
+ :return: the login path file or None
+ """
+
+ try:
+ with open(name, 'rb') as f:
+ plaintext = read_and_decrypt_mylogin_cnf(f)
+ except (OSError, IOError, ValueError):
+ logger.error('Unable to open login path file.')
+ return None
+
+ if not isinstance(plaintext, BytesIO):
+ logger.error('Unable to read login path file.')
+ return None
+
+ return TextIOWrapper(plaintext)
+
+
+# TODO reuse code between encryption an decryption
+def encrypt_mylogin_cnf(plaintext: IO[str]):
+ """Encryption of .mylogin.cnf file, analogous to calling
+ mysql_config_editor.
+
+ Code is based on the python implementation by Kristian Koehntopp
+ https://github.com/isotopp/mysql-config-coder
+
+ """
+ def realkey(key):
+ """Create the AES key from the login key."""
+ rkey = bytearray(16)
+ for i in range(len(key)):
+ rkey[i % 16] ^= key[i]
+ return bytes(rkey)
+
+ def encode_line(plaintext, real_key, buf_len):
+ aes = pyaes.AESModeOfOperationECB(real_key)
+ text_len = len(plaintext)
+ pad_len = buf_len - text_len
+ pad_chr = bytes(chr(pad_len), "utf8")
+ plaintext = plaintext.encode() + pad_chr * pad_len
+ encrypted_text = b''.join(
+ [aes.encrypt(plaintext[i: i + 16])
+ for i in range(0, len(plaintext), 16)]
+ )
+ return encrypted_text
+
+ LOGIN_KEY_LENGTH = 20
+ key = os.urandom(LOGIN_KEY_LENGTH)
+ real_key = realkey(key)
+
+ outfile = BytesIO()
+
+ outfile.write(struct.pack("i", 0))
+ outfile.write(key)
+
+ while True:
+ line = plaintext.readline()
+ if not line:
+ break
+ real_len = len(line)
+ pad_len = (int(real_len / 16) + 1) * 16
+
+ outfile.write(struct.pack("i", pad_len))
+ x = encode_line(line, real_key, pad_len)
+ outfile.write(x)
+
+ outfile.seek(0)
+ return outfile
+
+
+def read_and_decrypt_mylogin_cnf(f):
+ """Read and decrypt the contents of .mylogin.cnf.
+
+ This decryption algorithm mimics the code in MySQL's
+ mysql_config_editor.cc.
+
+ The login key is 20-bytes of random non-printable ASCII.
+ It is written to the actual login path file. It is used
+ to generate the real key used in the AES cipher.
+
+ :param f: an I/O object opened in binary mode
+ :return: the decrypted login path file
+ :rtype: io.BytesIO or None
+ """
+
+ # Number of bytes used to store the length of ciphertext.
+ MAX_CIPHER_STORE_LEN = 4
+
+ LOGIN_KEY_LEN = 20
+
+ # Move past the unused buffer.
+ buf = f.read(4)
+
+ if not buf or len(buf) != 4:
+ logger.error('Login path file is blank or incomplete.')
+ return None
+
+ # Read the login key.
+ key = f.read(LOGIN_KEY_LEN)
+
+ # Generate the real key.
+ rkey = [0] * 16
+ for i in range(LOGIN_KEY_LEN):
+ try:
+ rkey[i % 16] ^= ord(key[i:i+1])
+ except TypeError:
+ # ord() was unable to get the value of the byte.
+ logger.error('Unable to generate login path AES key.')
+ return None
+ rkey = struct.pack('16B', *rkey)
+
+ # Create a bytes buffer to hold the plaintext.
+ plaintext = BytesIO()
+ aes = pyaes.AESModeOfOperationECB(rkey)
+
+ while True:
+ # Read the length of the ciphertext.
+ len_buf = f.read(MAX_CIPHER_STORE_LEN)
+ if len(len_buf) < MAX_CIPHER_STORE_LEN:
+ break
+ cipher_len, = struct.unpack("<i", len_buf)
+
+ # Read cipher_len bytes from the file and decrypt.
+ cipher = f.read(cipher_len)
+ plain = _remove_pad(
+ b''.join([aes.decrypt(cipher[i: i + 16])
+ for i in range(0, cipher_len, 16)])
+ )
+ if plain is False:
+ continue
+ plaintext.write(plain)
+
+ if plaintext.tell() == 0:
+ logger.error('No data successfully decrypted from login path file.')
+ return None
+
+ plaintext.seek(0)
+ return plaintext
+
+
+def str_to_bool(s):
+ """Convert a string value to its corresponding boolean value."""
+ if isinstance(s, bool):
+ return s
+ elif not isinstance(s, basestring):
+ raise TypeError('argument must be a string')
+
+ true_values = ('true', 'on', '1')
+ false_values = ('false', 'off', '0')
+
+ if s.lower() in true_values:
+ return True
+ elif s.lower() in false_values:
+ return False
+ else:
+ raise ValueError('not a recognized boolean value: {0}'.format(s))
+
+
+def strip_matching_quotes(s):
+ """Remove matching, surrounding quotes from a string.
+
+ This is the same logic that ConfigObj uses when parsing config
+ values.
+
+ """
+ if (isinstance(s, basestring) and len(s) >= 2 and
+ s[0] == s[-1] and s[0] in ('"', "'")):
+ s = s[1:-1]
+ return s
+
+
+def _remove_pad(line):
+ """Remove the pad from the *line*."""
+ try:
+ # Determine pad length.
+ pad_length = ord(line[-1:])
+ except TypeError:
+ # ord() was unable to get the value of the byte.
+ logger.warning('Unable to remove pad.')
+ return False
+
+ if pad_length > len(line) or len(set(line[-pad_length:])) != 1:
+ # Pad length should be less than or equal to the length of the
+ # plaintext. The pad should have a single unique byte.
+ logger.warning('Invalid pad found in login path file.')
+ return False
+
+ return line[:-pad_length]
diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py
new file mode 100644
index 0000000..03e4ace
--- /dev/null
+++ b/mycli/key_bindings.py
@@ -0,0 +1,131 @@
+import logging
+from prompt_toolkit.enums import EditingMode
+from prompt_toolkit.filters import completion_is_selected, emacs_mode, vi_mode
+from prompt_toolkit.key_binding import KeyBindings
+
+_logger = logging.getLogger(__name__)
+
+
+def mycli_bindings(mycli):
+ """Custom key bindings for mycli."""
+ kb = KeyBindings()
+
+ @kb.add('f2')
+ def _(event):
+ """Enable/Disable SmartCompletion Mode."""
+ _logger.debug('Detected F2 key.')
+ mycli.completer.smart_completion = not mycli.completer.smart_completion
+
+ @kb.add('f3')
+ def _(event):
+ """Enable/Disable Multiline Mode."""
+ _logger.debug('Detected F3 key.')
+ mycli.multi_line = not mycli.multi_line
+
+ @kb.add('f4')
+ def _(event):
+ """Toggle between Vi and Emacs mode."""
+ _logger.debug('Detected F4 key.')
+ if mycli.key_bindings == "vi":
+ event.app.editing_mode = EditingMode.EMACS
+ mycli.key_bindings = "emacs"
+ else:
+ event.app.editing_mode = EditingMode.VI
+ mycli.key_bindings = "vi"
+
+ @kb.add('tab')
+ def _(event):
+ """Force autocompletion at cursor."""
+ _logger.debug('Detected <Tab> key.')
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=True)
+
+ @kb.add('c-space')
+ def _(event):
+ """
+ Initialize autocompletion at cursor.
+
+ If the autocompletion menu is not showing, display it with the
+ appropriate completions for the context.
+
+ If the menu is showing, select the next completion.
+ """
+ _logger.debug('Detected <C-Space> key.')
+
+ b = event.app.current_buffer
+ if b.complete_state:
+ b.complete_next()
+ else:
+ b.start_completion(select_first=False)
+
+ @kb.add('>', filter=vi_mode)
+ @kb.add('c-x', 'p', filter=emacs_mode)
+ def _(event):
+ """
+ Prettify and indent current statement, usually into multiple lines.
+
+ Only accepts buffers containing single SQL statements.
+ """
+ _logger.debug('Detected <C-x p>/> key.')
+
+ b = event.app.current_buffer
+ cursorpos_relative = b.cursor_position / len(b.text)
+ pretty_text = mycli.handle_prettify_binding(b.text)
+ if len(pretty_text) > 0:
+ b.text = pretty_text
+ cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
+ while 0 < cursorpos_abs < len(b.text) \
+ and b.text[cursorpos_abs] in (' ', '\n'):
+ cursorpos_abs -= 1
+ b.cursor_position = min(cursorpos_abs, len(b.text))
+
+ @kb.add('<', filter=vi_mode)
+ @kb.add('c-x', 'u', filter=emacs_mode)
+ def _(event):
+ """
+ Unprettify and dedent current statement, usually into one line.
+
+ Only accepts buffers containing single SQL statements.
+ """
+ _logger.debug('Detected <C-x u>/< key.')
+
+ b = event.app.current_buffer
+ cursorpos_relative = b.cursor_position / len(b.text)
+ unpretty_text = mycli.handle_unprettify_binding(b.text)
+ if len(unpretty_text) > 0:
+ b.text = unpretty_text
+ cursorpos_abs = int(round(cursorpos_relative * len(b.text)))
+ while 0 < cursorpos_abs < len(b.text) \
+ and b.text[cursorpos_abs] in (' ', '\n'):
+ cursorpos_abs -= 1
+ b.cursor_position = min(cursorpos_abs, len(b.text))
+
+ @kb.add('enter', filter=completion_is_selected)
+ def _(event):
+ """Makes the enter key work as the tab key only when showing the menu.
+
+ In other words, don't execute query when enter is pressed in
+ the completion dropdown menu, instead close the dropdown menu
+ (accept current selection).
+
+ """
+ _logger.debug('Detected enter key.')
+
+ event.current_buffer.complete_state = None
+ b = event.app.current_buffer
+ b.complete_state = None
+
+ @kb.add('escape', 'enter')
+ def _(event):
+ """Introduces a line break in multi-line mode, or dispatches the
+ command in single-line mode."""
+ _logger.debug('Detected alt-enter key.')
+ if mycli.multi_line:
+ event.app.current_buffer.validate_and_handle()
+ else:
+ event.app.current_buffer.insert_text('\n')
+
+ return kb
diff --git a/mycli/lexer.py b/mycli/lexer.py
new file mode 100644
index 0000000..4b14d72
--- /dev/null
+++ b/mycli/lexer.py
@@ -0,0 +1,12 @@
+from pygments.lexer import inherit
+from pygments.lexers.sql import MySqlLexer
+from pygments.token import Keyword
+
+
+class MyCliLexer(MySqlLexer):
+ """Extends MySQL lexer to add keywords."""
+
+ tokens = {
+ 'root': [(r'\brepair\b', Keyword),
+ (r'\boffset\b', Keyword), inherit],
+ }
diff --git a/mycli/magic.py b/mycli/magic.py
new file mode 100644
index 0000000..aad229a
--- /dev/null
+++ b/mycli/magic.py
@@ -0,0 +1,54 @@
+from .main import MyCli
+import sql.parse
+import sql.connection
+import logging
+
+_logger = logging.getLogger(__name__)
+
+def load_ipython_extension(ipython):
+
+ # This is called via the ipython command '%load_ext mycli.magic'.
+
+ # First, load the sql magic if it isn't already loaded.
+ if not ipython.find_line_magic('sql'):
+ ipython.run_line_magic('load_ext', 'sql')
+
+ # Register our own magic.
+ ipython.register_magic_function(mycli_line_magic, 'line', 'mycli')
+
+def mycli_line_magic(line):
+ _logger.debug('mycli magic called: %r', line)
+ parsed = sql.parse.parse(line, {})
+ conn = sql.connection.Connection(parsed['connection'])
+
+ try:
+ # A corresponding mycli object already exists
+ mycli = conn._mycli
+ _logger.debug('Reusing existing mycli')
+ except AttributeError:
+ mycli = MyCli()
+ u = conn.session.engine.url
+ _logger.debug('New mycli: %r', str(u))
+
+ mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None)
+ conn._mycli = mycli
+
+ # For convenience, print the connection alias
+ print('Connected: {}'.format(conn.name))
+
+ try:
+ mycli.run_cli()
+ except SystemExit:
+ pass
+
+ if not mycli.query_history:
+ return
+
+ q = mycli.query_history[-1]
+ if q.mutating:
+ _logger.debug('Mutating query detected -- ignoring')
+ return
+
+ if q.successful:
+ ipython = get_ipython()
+ return ipython.run_cell_magic('sql', line, q.query)
diff --git a/mycli/main.py b/mycli/main.py
new file mode 100755
index 0000000..208572d
--- /dev/null
+++ b/mycli/main.py
@@ -0,0 +1,1468 @@
+from collections import defaultdict
+from io import open
+import os
+import sys
+import shutil
+import traceback
+import logging
+import threading
+import re
+import stat
+import fileinput
+from collections import namedtuple
+try:
+ from pwd import getpwuid
+except ImportError:
+ pass
+from time import time
+from datetime import datetime
+from random import choice
+
+from pymysql import OperationalError
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output import preprocessors
+from cli_helpers.utils import strip_ansi
+import click
+import sqlparse
+import sqlglot
+from mycli.packages.parseutils import is_dropping_database, is_destructive
+from prompt_toolkit.completion import DynamicCompleter
+from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
+from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
+from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
+from prompt_toolkit.document import Document
+from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
+from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor,
+ ConditionalProcessor)
+from prompt_toolkit.lexers import PygmentsLexer
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+
+from .packages.special.main import NO_QUERY
+from .packages.prompt_utils import confirm, confirm_destructive_query
+from .packages.tabular_output import sql_format
+from .packages import special
+from .packages.special.favoritequeries import FavoriteQueries
+from .sqlcompleter import SQLCompleter
+from .clitoolbar import create_toolbar_tokens_func
+from .clistyle import style_factory, style_factory_output
+from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED
+from .clibuffer import cli_is_multiline
+from .completion_refresher import CompletionRefresher
+from .config import (write_default_config, get_mylogin_cnf_path,
+ open_mylogin_cnf, read_config_files, str_to_bool,
+ strip_matching_quotes)
+from .key_bindings import mycli_bindings
+from .lexer import MyCliLexer
+from . import __version__
+from .compat import WIN
+from .packages.filepaths import dir_path_exists, guess_socket_location
+
+import itertools
+
+click.disable_unicode_literals_warning = True
+
+try:
+ from urlparse import urlparse
+ from urlparse import unquote
+except ImportError:
+ from urllib.parse import urlparse
+ from urllib.parse import unquote
+
+try:
+ import importlib.resources as resources
+except ImportError:
+ # Python < 3.7
+ import importlib_resources as resources
+
+try:
+ import paramiko
+except ImportError:
+ from mycli.packages.paramiko_stub import paramiko
+
+# Query tuples are used for maintaining history
+Query = namedtuple('Query', ['query', 'successful', 'mutating'])
+
+SUPPORT_INFO = (
+ 'Home: http://mycli.net\n'
+ 'Bug tracker: https://github.com/dbcli/mycli/issues'
+)
+
+
+class MyCli(object):
+
+ default_prompt = '\\t \\u@\\h:\\d> '
+ max_len_prompt = 45
+ defaults_suffix = None
+
+ # In order of being loaded. Files lower in list override earlier ones.
+ cnf_files = [
+ '/etc/my.cnf',
+ '/etc/mysql/my.cnf',
+ '/usr/local/etc/my.cnf',
+ os.path.expanduser('~/.my.cnf'),
+ ]
+
+ # check XDG_CONFIG_HOME exists and not an empty string
+ if os.environ.get("XDG_CONFIG_HOME"):
+ xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
+ else:
+ xdg_config_home = "~/.config"
+ system_config_files = [
+ '/etc/myclirc',
+ os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
+ ]
+
+ pwd_config_file = os.path.join(os.getcwd(), ".myclirc")
+
+ def __init__(self, sqlexecute=None, prompt=None,
+ logfile=None, defaults_suffix=None, defaults_file=None,
+ login_path=None, auto_vertical_output=False, warn=None,
+ myclirc="~/.myclirc"):
+ self.sqlexecute = sqlexecute
+ self.logfile = logfile
+ self.defaults_suffix = defaults_suffix
+ self.login_path = login_path
+ self.toolbar_error_message = None
+
+ # self.cnf_files is a class variable that stores the list of mysql
+ # config files to read in at launch.
+ # If defaults_file is specified then override the class variable with
+ # defaults_file.
+ if defaults_file:
+ self.cnf_files = [defaults_file]
+
+ # Load config.
+ config_files = (self.system_config_files +
+ [myclirc] + [self.pwd_config_file])
+ c = self.config = read_config_files(config_files)
+ self.multi_line = c['main'].as_bool('multi_line')
+ self.key_bindings = c['main']['key_bindings']
+ special.set_timing_enabled(c['main'].as_bool('timing'))
+ self.beep_after_seconds = float(c['main']['beep_after_seconds'] or 0)
+
+ FavoriteQueries.instance = FavoriteQueries.from_config(self.config)
+
+ self.dsn_alias = None
+ self.formatter = TabularOutputFormatter(
+ format_name=c['main']['table_format'])
+ sql_format.register_new_formatter(self.formatter)
+ self.formatter.mycli = self
+ self.syntax_style = c['main']['syntax_style']
+ self.less_chatty = c['main'].as_bool('less_chatty')
+ self.cli_style = c['colors']
+ self.output_style = style_factory_output(
+ self.syntax_style,
+ self.cli_style
+ )
+ self.wider_completion_menu = c['main'].as_bool('wider_completion_menu')
+ c_dest_warning = c['main'].as_bool('destructive_warning')
+ self.destructive_warning = c_dest_warning if warn is None else warn
+ self.login_path_as_host = c['main'].as_bool('login_path_as_host')
+
+ # read from cli argument or user config file
+ self.auto_vertical_output = auto_vertical_output or \
+ c['main'].as_bool('auto_vertical_output')
+
+ # Write user config if system config wasn't the last config loaded.
+ if c.filename not in self.system_config_files and not os.path.exists(myclirc):
+ write_default_config(myclirc)
+
+ # audit log
+ if self.logfile is None and 'audit_log' in c['main']:
+ try:
+ self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a')
+ except (IOError, OSError) as e:
+ self.echo('Error: Unable to open the audit log file. Your queries will not be logged.',
+ err=True, fg='red')
+ self.logfile = False
+
+ self.completion_refresher = CompletionRefresher()
+
+ self.logger = logging.getLogger(__name__)
+ self.initialize_logging()
+
+ prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt']
+ self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \
+ self.default_prompt
+ self.multiline_continuation_char = c['main']['prompt_continuation']
+ keyword_casing = c['main'].get('keyword_casing', 'auto')
+
+ self.query_history = []
+
+ # Initialize completer.
+ self.smart_completion = c['main'].as_bool('smart_completion')
+ self.completer = SQLCompleter(
+ self.smart_completion,
+ supported_formats=self.formatter.supported_formats,
+ keyword_casing=keyword_casing)
+ self._completer_lock = threading.Lock()
+
+ # Register custom special commands.
+ self.register_special_commands()
+
+ # Load .mylogin.cnf if it exists.
+ mylogin_cnf_path = get_mylogin_cnf_path()
+ if mylogin_cnf_path:
+ mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path)
+ if mylogin_cnf_path and mylogin_cnf:
+ # .mylogin.cnf gets read last, even if defaults_file is specified.
+ self.cnf_files.append(mylogin_cnf)
+ elif mylogin_cnf_path and not mylogin_cnf:
+ # There was an error reading the login path file.
+ print('Error: Unable to read login path file.')
+
+ self.prompt_app = None
+
+ def register_special_commands(self):
+ special.register_special_command(self.change_db, 'use',
+ '\\u', 'Change to a new database.', aliases=('\\u',))
+ special.register_special_command(self.change_db, 'connect',
+ '\\r', 'Reconnect to the database. Optional database argument.',
+ aliases=('\\r', ), case_sensitive=True)
+ special.register_special_command(self.refresh_completions, 'rehash',
+ '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',))
+ special.register_special_command(
+ self.change_table_format, 'tableformat', '\\T',
+ 'Change the table format used to output results.',
+ aliases=('\\T',), case_sensitive=True)
+ special.register_special_command(self.execute_from_file, 'source', '\\. filename',
+ 'Execute commands from file.', aliases=('\\.',))
+ special.register_special_command(self.change_prompt_format, 'prompt',
+ '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True)
+
+ def change_table_format(self, arg, **_):
+ try:
+ self.formatter.format_name = arg
+ yield (None, None, None,
+ 'Changed table format to {}'.format(arg))
+ except ValueError:
+ msg = 'Table format {} not recognized. Allowed formats:'.format(
+ arg)
+ for table_type in self.formatter.supported_formats:
+ msg += "\n\t{}".format(table_type)
+ yield (None, None, None, msg)
+
+ def change_db(self, arg, **_):
+ if not arg:
+ click.secho(
+ "No database selected",
+ err=True, fg="red"
+ )
+ return
+
+ if arg.startswith('`') and arg.endswith('`'):
+ arg = re.sub(r'^`(.*)`$', r'\1', arg)
+ arg = re.sub(r'``', r'`', arg)
+ self.sqlexecute.change_db(arg)
+
+ yield (None, None, None, 'You are now connected to database "%s" as '
+ 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user))
+
+ def execute_from_file(self, arg, **_):
+ if not arg:
+ message = 'Missing required argument, filename.'
+ return [(None, None, None, message)]
+ try:
+ with open(os.path.expanduser(arg)) as f:
+ query = f.read()
+ except IOError as e:
+ return [(None, None, None, str(e))]
+
+ if (self.destructive_warning and
+ confirm_destructive_query(query) is False):
+ message = 'Wise choice. Command execution stopped.'
+ return [(None, None, None, message)]
+
+ return self.sqlexecute.run(query)
+
+ def change_prompt_format(self, arg, **_):
+ """
+ Change the prompt format.
+ """
+ if not arg:
+ message = 'Missing required argument, format.'
+ return [(None, None, None, message)]
+
+ self.prompt_format = self.get_prompt(arg)
+ return [(None, None, None, "Changed prompt format to %s" % arg)]
+
+ def initialize_logging(self):
+
+ log_file = os.path.expanduser(self.config['main']['log_file'])
+ log_level = self.config['main']['log_level']
+
+ level_map = {'CRITICAL': logging.CRITICAL,
+ 'ERROR': logging.ERROR,
+ 'WARNING': logging.WARNING,
+ 'INFO': logging.INFO,
+ 'DEBUG': logging.DEBUG
+ }
+
+ # Disable logging if value is NONE by switching to a no-op handler
+ # Set log level to a high value so it doesn't even waste cycles getting called.
+ if log_level.upper() == "NONE":
+ handler = logging.NullHandler()
+ log_level = "CRITICAL"
+ elif dir_path_exists(log_file):
+ handler = logging.FileHandler(log_file)
+ else:
+ self.echo(
+ 'Error: Unable to open the log file "{}".'.format(log_file),
+ err=True, fg='red')
+ return
+
+ formatter = logging.Formatter(
+ '%(asctime)s (%(process)d/%(threadName)s) '
+ '%(name)s %(levelname)s - %(message)s')
+
+ handler.setFormatter(formatter)
+
+ root_logger = logging.getLogger('mycli')
+ root_logger.addHandler(handler)
+ root_logger.setLevel(level_map[log_level.upper()])
+
+ logging.captureWarnings(True)
+
+ root_logger.debug('Initializing mycli logging.')
+ root_logger.debug('Log file %r.', log_file)
+
+
+ def read_my_cnf_files(self, files, keys):
+ """
+ Reads a list of config files and merges them. The last one will win.
+ :param files: list of files to read
+ :param keys: list of keys to retrieve
+ :returns: tuple, with None for missing keys.
+ """
+ cnf = read_config_files(files, list_values=False)
+
+ sections = ['client', 'mysqld']
+ key_transformations = {
+ 'mysqld': {
+ 'socket': 'default_socket',
+ 'port': 'default_port',
+ 'user': 'default_user',
+ },
+ }
+
+ if self.login_path and self.login_path != 'client':
+ sections.append(self.login_path)
+
+ if self.defaults_suffix:
+ sections.extend([sect + self.defaults_suffix for sect in sections])
+
+ configuration = defaultdict(lambda: None)
+ for key in keys:
+ for section in cnf:
+ if (
+ section not in sections or
+ key not in cnf[section]
+ ):
+ continue
+ new_key = key_transformations.get(section, {}).get(key) or key
+ configuration[new_key] = strip_matching_quotes(
+ cnf[section][key])
+
+ return configuration
+
+
+ def merge_ssl_with_cnf(self, ssl, cnf):
+ """Merge SSL configuration dict with cnf dict"""
+
+ merged = {}
+ merged.update(ssl)
+ prefix = 'ssl-'
+ for k, v in cnf.items():
+ # skip unrelated options
+ if not k.startswith(prefix):
+ continue
+ if v is None:
+ continue
+ # special case because PyMySQL argument is significantly different
+ # from commandline
+ if k == 'ssl-verify-server-cert':
+ merged['check_hostname'] = v
+ else:
+ # use argument name just strip "ssl-" prefix
+ arg = k[len(prefix):]
+ merged[arg] = v
+
+ return merged
+
+ def connect(self, database='', user='', passwd='', host='', port='',
+ socket='', charset='', local_infile='', ssl='',
+ ssh_user='', ssh_host='', ssh_port='',
+ ssh_password='', ssh_key_filename='', init_command='', password_file=''):
+
+ cnf = {'database': None,
+ 'user': None,
+ 'password': None,
+ 'host': None,
+ 'port': None,
+ 'socket': None,
+ 'default_socket': None,
+ 'default-character-set': None,
+ 'local-infile': None,
+ 'loose-local-infile': None,
+ 'ssl-ca': None,
+ 'ssl-cert': None,
+ 'ssl-key': None,
+ 'ssl-cipher': None,
+ 'ssl-verify-serer-cert': None,
+ }
+
+ cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys())
+
+ # Fall back to config values only if user did not specify a value.
+ database = database or cnf['database']
+ user = user or cnf['user'] or os.getenv('USER')
+ host = host or cnf['host']
+ port = port or cnf['port']
+ ssl = ssl or {}
+
+ port = port and int(port)
+ if not port:
+ port = 3306
+ if not host or host == 'localhost':
+ socket = (
+ cnf['socket'] or
+ cnf['default_socket'] or
+ guess_socket_location()
+ )
+
+
+ passwd = passwd if isinstance(passwd, str) else cnf['password']
+ charset = charset or cnf['default-character-set'] or 'utf8'
+
+ # Favor whichever local_infile option is set.
+ for local_infile_option in (local_infile, cnf['local-infile'],
+ cnf['loose-local-infile'], False):
+ try:
+ local_infile = str_to_bool(local_infile_option)
+ break
+ except (TypeError, ValueError):
+ pass
+
+ ssl = self.merge_ssl_with_cnf(ssl, cnf)
+ # prune lone check_hostname=False
+ if not any(v for v in ssl.values()):
+ ssl = None
+
+ # if the passwd is not specified try to set it using the password_file option
+ password_from_file = self.get_password_from_file(password_file)
+ passwd = passwd or password_from_file
+
+ # Connect to the database.
+
+ def _connect():
+ try:
+ self.sqlexecute = SQLExecute(
+ database, user, passwd, host, port, socket, charset,
+ local_infile, ssl, ssh_user, ssh_host, ssh_port,
+ ssh_password, ssh_key_filename, init_command
+ )
+ except OperationalError as e:
+ if e.args[0] == ERROR_CODE_ACCESS_DENIED:
+ if password_from_file:
+ new_passwd = password_from_file
+ else:
+ new_passwd = click.prompt('Password', hide_input=True,
+ show_default=False, type=str, err=True)
+ self.sqlexecute = SQLExecute(
+ database, user, new_passwd, host, port, socket,
+ charset, local_infile, ssl, ssh_user, ssh_host,
+ ssh_port, ssh_password, ssh_key_filename, init_command
+ )
+ else:
+ raise e
+
+ try:
+ if not WIN and socket:
+ socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
+ self.echo(
+ f"Connecting to socket {socket}, owned by user {socket_owner}", err=True)
+ try:
+ _connect()
+ except OperationalError as e:
+ # These are "Can't open socket" and 2x "Can't connect"
+ if [code for code in (2001, 2002, 2003) if code == e.args[0]]:
+ self.logger.debug('Database connection failed: %r.', e)
+ self.logger.error(
+ "traceback: %r", traceback.format_exc())
+ self.logger.debug('Retrying over TCP/IP')
+ self.echo(
+ "Failed to connect to local MySQL server through socket '{}':".format(socket))
+ self.echo(str(e), err=True)
+ self.echo(
+ 'Retrying over TCP/IP', err=True)
+
+ # Else fall back to TCP/IP localhost
+ socket = ""
+ host = 'localhost'
+ port = 3306
+ _connect()
+ else:
+ raise e
+ else:
+ host = host or 'localhost'
+ port = port or 3306
+
+ # Bad ports give particularly daft error messages
+ try:
+ port = int(port)
+ except ValueError as e:
+ self.echo("Error: Invalid port number: '{0}'.".format(port),
+ err=True, fg='red')
+ exit(1)
+
+ _connect()
+ except Exception as e: # Connecting to a database could fail.
+ self.logger.debug('Database connection failed: %r.', e)
+ self.logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ exit(1)
+
+ def get_password_from_file(self, password_file):
+ password_from_file = None
+ if password_file:
+ if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \
+ and os.access(password_file, os.R_OK):
+ with open(password_file) as fp:
+ password_from_file = fp.readline()
+ password_from_file = password_from_file.rstrip().lstrip()
+
+ return password_from_file
+
+ def handle_editor_command(self, text):
+ r"""Editor command is any query that is prefixed or suffixed by a '\e'.
+ The reason for a while loop is because a user might edit a query
+ multiple times. For eg:
+
+ "select * from \e"<enter> to edit it in vim, then come
+ back to the prompt with the edited query "select * from
+ blah where q = 'abc'\e" to edit it again.
+ :param text: Document
+ :return: Document
+
+ """
+
+ while special.editor_command(text):
+ filename = special.get_filename(text)
+ query = (special.get_editor_query(text) or
+ self.get_last_query())
+ sql, message = special.open_external_editor(filename, sql=query)
+ if message:
+ # Something went wrong. Raise an exception and bail.
+ raise RuntimeError(message)
+ while True:
+ try:
+ text = self.prompt_app.prompt(default=sql)
+ break
+ except KeyboardInterrupt:
+ sql = ""
+
+ continue
+ return text
+
+ def handle_clip_command(self, text):
+ r"""A clip command is any query that is prefixed or suffixed by a
+ '\clip'.
+
+ :param text: Document
+ :return: Boolean
+
+ """
+
+ if special.clip_command(text):
+ query = (special.get_clip_query(text) or
+ self.get_last_query())
+ message = special.copy_query_to_clipboard(sql=query)
+ if message:
+ raise RuntimeError(message)
+ return True
+ return False
+
+ def handle_prettify_binding(self, text):
+ try:
+ statements = sqlglot.parse(text, read='mysql')
+ except Exception as e:
+ statements = []
+ if len(statements) == 1:
+ pretty_text = statements[0].sql(pretty=True, pad=4, dialect='mysql')
+ else:
+ pretty_text = ''
+ self.toolbar_error_message = 'Prettify failed to parse statement'
+ if len(pretty_text) > 0:
+ pretty_text = pretty_text + ';'
+ return pretty_text
+
+ def handle_unprettify_binding(self, text):
+ try:
+ statements = sqlglot.parse(text, read='mysql')
+ except Exception as e:
+ statements = []
+ if len(statements) == 1:
+ unpretty_text = statements[0].sql(pretty=False, dialect='mysql')
+ else:
+ unpretty_text = ''
+ self.toolbar_error_message = 'Unprettify failed to parse statement'
+ if len(unpretty_text) > 0:
+ unpretty_text = unpretty_text + ';'
+ return unpretty_text
+
+ def run_cli(self):
+ iterations = 0
+ sqlexecute = self.sqlexecute
+ logger = self.logger
+ self.configure_pager()
+
+ if self.smart_completion:
+ self.refresh_completions()
+
+ history_file = os.path.expanduser(
+ os.environ.get('MYCLI_HISTFILE', '~/.mycli-history'))
+ if dir_path_exists(history_file):
+ history = FileHistory(history_file)
+ else:
+ history = None
+ self.echo(
+ 'Error: Unable to open the history file "{}". '
+ 'Your query history will not be saved.'.format(history_file),
+ err=True, fg='red')
+
+ key_bindings = mycli_bindings(self)
+
+ if not self.less_chatty:
+ print(sqlexecute.server_info)
+ print('mycli', __version__)
+ print(SUPPORT_INFO)
+ print('Thanks to the contributor -', thanks_picker())
+
+ def get_message():
+ prompt = self.get_prompt(self.prompt_format)
+ if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
+ prompt = self.get_prompt('\\d> ')
+ prompt = prompt.replace("\\x1b", "\x1b")
+ return ANSI(prompt)
+
+ def get_continuation(width, *_):
+ if self.multiline_continuation_char == '':
+ continuation = ''
+ elif self.multiline_continuation_char:
+ left_padding = width - len(self.multiline_continuation_char)
+ continuation = " " * \
+ max((left_padding - 1), 0) + \
+ self.multiline_continuation_char + " "
+ else:
+ continuation = " "
+ return [('class:continuation', continuation)]
+
+ def show_suggestion_tip():
+ return iterations < 2
+
+ def one_iteration(text=None):
+ if text is None:
+ try:
+ text = self.prompt_app.prompt()
+ except KeyboardInterrupt:
+ return
+
+ special.set_expanded_output(False)
+
+ try:
+ text = self.handle_editor_command(text)
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ return
+
+ try:
+ if self.handle_clip_command(text):
+ return
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ return
+
+ if not text.strip():
+ return
+
+ if self.destructive_warning:
+ destroy = confirm_destructive_query(text)
+ if destroy is None:
+ pass # Query was not destructive. Nothing to do here.
+ elif destroy is True:
+ self.echo('Your call!')
+ else:
+ self.echo('Wise choice!')
+ return
+ else:
+ destroy = True
+
+ # Keep track of whether or not the query is mutating. In case
+ # of a multi-statement query, the overall query is considered
+ # mutating if any one of the component statements is mutating
+ mutating = False
+
+ try:
+ logger.debug('sql: %r', text)
+
+ special.write_tee(self.get_prompt(self.prompt_format) + text)
+ if self.logfile:
+ self.logfile.write('\n# %s\n' % datetime.now())
+ self.logfile.write(text)
+ self.logfile.write('\n')
+
+ successful = False
+ start = time()
+ res = sqlexecute.run(text)
+ self.formatter.query = text
+ successful = True
+ result_count = 0
+ for title, cur, headers, status in res:
+ logger.debug("headers: %r", headers)
+ logger.debug("rows: %r", cur)
+ logger.debug("status: %r", status)
+ threshold = 1000
+ if (is_select(status) and
+ cur and cur.rowcount > threshold):
+ self.echo('The result set has more than {} rows.'.format(
+ threshold), fg='red')
+ if not confirm('Do you want to continue?'):
+ self.echo("Aborted!", err=True, fg='red')
+ break
+
+ if self.auto_vertical_output:
+ max_width = self.prompt_app.output.get_size().columns
+ else:
+ max_width = None
+
+ formatted = self.format_output(
+ title, cur, headers, special.is_expanded_output(),
+ max_width)
+
+ t = time() - start
+ try:
+ if result_count > 0:
+ self.echo('')
+ try:
+ self.output(formatted, status)
+ except KeyboardInterrupt:
+ pass
+ if self.beep_after_seconds > 0 and t >= self.beep_after_seconds:
+ self.bell()
+ if special.is_timing_enabled():
+ self.echo('Time: %0.03fs' % t)
+ except KeyboardInterrupt:
+ pass
+
+ start = time()
+ result_count += 1
+ mutating = mutating or destroy or is_mutating(status)
+ special.unset_once_if_written()
+ special.unset_pipe_once_if_written()
+ except EOFError as e:
+ raise e
+ except KeyboardInterrupt:
+ # get last connection id
+ connection_id_to_kill = sqlexecute.connection_id
+ # some mysql compatible databases may not implemente connection_id()
+ if connection_id_to_kill > 0:
+ logger.debug("connection id to kill: %r", connection_id_to_kill)
+ # Restart connection to the database
+ sqlexecute.connect()
+ try:
+ for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill):
+ status_str = str(status).lower()
+ if status_str.find('ok') > -1:
+ logger.debug("cancelled query, connection id: %r, sql: %r",
+ connection_id_to_kill, text)
+ self.echo("cancelled query", err=True, fg='red')
+ except Exception as e:
+ self.echo('Encountered error while cancelling query: {}'.format(e),
+ err=True, fg='red')
+ else:
+ logger.debug("Did not get a connection id, skip cancelling query")
+ except NotImplementedError:
+ self.echo('Not Yet Implemented.', fg="yellow")
+ except OperationalError as e:
+ logger.debug("Exception: %r", e)
+ if (e.args[0] in (2003, 2006, 2013)):
+ logger.debug('Attempting to reconnect.')
+ self.echo('Reconnecting...', fg='yellow')
+ try:
+ sqlexecute.connect()
+ logger.debug('Reconnected successfully.')
+ one_iteration(text)
+ return # OK to just return, cuz the recursion call runs to the end.
+ except OperationalError as e:
+ logger.debug('Reconnect failed. e: %r', e)
+ self.echo(str(e), err=True, fg='red')
+ # If reconnection failed, don't proceed further.
+ return
+ else:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ except Exception as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ else:
+ if is_dropping_database(text, self.sqlexecute.dbname):
+ self.sqlexecute.dbname = None
+ self.sqlexecute.connect()
+
+ # Refresh the table names and column names if necessary.
+ if need_completion_refresh(text):
+ self.refresh_completions(
+ reset=need_completion_reset(text))
+ finally:
+ if self.logfile is False:
+ self.echo("Warning: This query was not logged.",
+ err=True, fg='red')
+ query = Query(text, successful, mutating)
+ self.query_history.append(query)
+
+ get_toolbar_tokens = create_toolbar_tokens_func(
+ self, show_suggestion_tip)
+ if self.wider_completion_menu:
+ complete_style = CompleteStyle.MULTI_COLUMN
+ else:
+ complete_style = CompleteStyle.COLUMN
+
+ with self._completer_lock:
+
+ if self.key_bindings == 'vi':
+ editing_mode = EditingMode.VI
+ else:
+ editing_mode = EditingMode.EMACS
+
+ self.prompt_app = PromptSession(
+ lexer=PygmentsLexer(MyCliLexer),
+ reserve_space_for_menu=self.get_reserved_space(),
+ message=get_message,
+ prompt_continuation=get_continuation,
+ bottom_toolbar=get_toolbar_tokens,
+ complete_style=complete_style,
+ input_processors=[ConditionalProcessor(
+ processor=HighlightMatchingBracketProcessor(
+ chars='[](){}'),
+ filter=HasFocus(DEFAULT_BUFFER) & ~IsDone()
+ )],
+ tempfile_suffix='.sql',
+ completer=DynamicCompleter(lambda: self.completer),
+ history=history,
+ auto_suggest=AutoSuggestFromHistory(),
+ complete_while_typing=True,
+ multiline=cli_is_multiline(self),
+ style=style_factory(self.syntax_style, self.cli_style),
+ include_default_pygments_style=False,
+ key_bindings=key_bindings,
+ enable_open_in_editor=True,
+ enable_system_prompt=True,
+ enable_suspend=True,
+ editing_mode=editing_mode,
+ search_ignore_case=True
+ )
+
+ try:
+ while True:
+ one_iteration()
+ iterations += 1
+ except EOFError:
+ special.close_tee()
+ if not self.less_chatty:
+ self.echo('Goodbye!')
+
+ def log_output(self, output):
+ """Log the output in the audit log, if it's enabled."""
+ if self.logfile:
+ click.echo(output, file=self.logfile)
+
+ def echo(self, s, **kwargs):
+ """Print a message to stdout.
+
+ The message will be logged in the audit log, if enabled.
+
+ All keyword arguments are passed to click.echo().
+
+ """
+ self.log_output(s)
+ click.secho(s, **kwargs)
+
+ def bell(self):
+ """Print a bell on the stderr.
+ """
+ click.secho('\a', err=True, nl=False)
+
+ def get_output_margin(self, status=None):
+ """Get the output margin (number of rows for the prompt, footer and
+ timing message."""
+ margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1
+ if special.is_timing_enabled():
+ margin += 1
+ if status:
+ margin += 1 + status.count('\n')
+
+ return margin
+
+
+ def output(self, output, status=None):
+ """Output text to stdout or a pager command.
+
+ The status text is not outputted to pager or files.
+
+ The message will be logged in the audit log, if enabled. The
+ message will be written to the tee file, if enabled. The
+ message will be written to the output file, if enabled.
+
+ """
+ if output:
+ size = self.prompt_app.output.get_size()
+
+ margin = self.get_output_margin(status)
+
+ fits = True
+ buf = []
+ output_via_pager = self.explicit_pager and special.is_pager_enabled()
+ for i, line in enumerate(output, 1):
+ self.log_output(line)
+ special.write_tee(line)
+ special.write_once(line)
+ special.write_pipe_once(line)
+
+ if fits or output_via_pager:
+ # buffering
+ buf.append(line)
+ if len(line) > size.columns or i > (size.rows - margin):
+ fits = False
+ if not self.explicit_pager and special.is_pager_enabled():
+ # doesn't fit, use pager
+ output_via_pager = True
+
+ if not output_via_pager:
+ # doesn't fit, flush buffer
+ for buf_line in buf:
+ click.secho(buf_line)
+ buf = []
+ else:
+ click.secho(line)
+
+ if buf:
+ if output_via_pager:
+ def newlinewrapper(text):
+ for line in text:
+ yield line + "\n"
+ click.echo_via_pager(newlinewrapper(buf))
+ else:
+ for line in buf:
+ click.secho(line)
+
+ if status:
+ self.log_output(status)
+ click.secho(status)
+
+ def configure_pager(self):
+ # Provide sane defaults for less if they are empty.
+ if not os.environ.get('LESS'):
+ os.environ['LESS'] = '-RXF'
+
+ cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager'])
+ cnf_pager = cnf['pager'] or self.config['main']['pager']
+ if cnf_pager:
+ special.set_pager(cnf_pager)
+ self.explicit_pager = True
+ else:
+ self.explicit_pager = False
+
+ if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'):
+ special.disable_pager()
+
+ def refresh_completions(self, reset=False):
+ if reset:
+ with self._completer_lock:
+ self.completer.reset_completions()
+ self.completion_refresher.refresh(
+ self.sqlexecute, self._on_completions_refreshed,
+ {'smart_completion': self.smart_completion,
+ 'supported_formats': self.formatter.supported_formats,
+ 'keyword_casing': self.completer.keyword_casing})
+
+ return [(None, None, None,
+ 'Auto-completion refresh started in the background.')]
+
+ def _on_completions_refreshed(self, new_completer):
+ """Swap the completer object in cli with the newly created completer.
+ """
+ with self._completer_lock:
+ self.completer = new_completer
+
+ if self.prompt_app:
+ # After refreshing, redraw the CLI to clear the statusbar
+ # "Refreshing completions..." indicator
+ self.prompt_app.app.invalidate()
+
+ def get_completions(self, text, cursor_positition):
+ with self._completer_lock:
+ return self.completer.get_completions(
+ Document(text=text, cursor_position=cursor_positition), None)
+
+ def get_prompt(self, string):
+ sqlexecute = self.sqlexecute
+ host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host
+ now = datetime.now()
+ string = string.replace('\\u', sqlexecute.user or '(none)')
+ string = string.replace('\\h', host or '(none)')
+ string = string.replace('\\d', sqlexecute.dbname or '(none)')
+ string = string.replace('\\t', sqlexecute.server_info.species.name)
+ string = string.replace('\\n', "\n")
+ string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y'))
+ string = string.replace('\\m', now.strftime('%M'))
+ string = string.replace('\\P', now.strftime('%p'))
+ string = string.replace('\\R', now.strftime('%H'))
+ string = string.replace('\\r', now.strftime('%I'))
+ string = string.replace('\\s', now.strftime('%S'))
+ string = string.replace('\\p', str(sqlexecute.port))
+ string = string.replace('\\A', self.dsn_alias or '(none)')
+ string = string.replace('\\_', ' ')
+ return string
+
+ def run_query(self, query, new_line=True):
+ """Runs *query*."""
+ results = self.sqlexecute.run(query)
+ for result in results:
+ title, cur, headers, status = result
+ self.formatter.query = query
+ output = self.format_output(title, cur, headers)
+ for line in output:
+ click.echo(line, nl=new_line)
+
+ def format_output(self, title, cur, headers, expanded=False,
+ max_width=None):
+ expanded = expanded or self.formatter.format_name == 'vertical'
+ output = []
+
+ output_kwargs = {
+ 'dialect': 'unix',
+ 'disable_numparse': True,
+ 'preserve_whitespace': True,
+ 'style': self.output_style
+ }
+
+ if not self.formatter.format_name in sql_format.supported_formats:
+ output_kwargs["preprocessors"] = (preprocessors.align_decimals, )
+
+ if title: # Only print the title if it's not None.
+ output = itertools.chain(output, [title])
+
+ if cur:
+ column_types = None
+ if hasattr(cur, 'description'):
+ def get_col_type(col):
+ col_type = FIELD_TYPES.get(col[1], str)
+ return col_type if type(col_type) is type else str
+ column_types = [get_col_type(col) for col in cur.description]
+
+ if max_width is not None:
+ cur = list(cur)
+
+ formatted = self.formatter.format_output(
+ cur, headers, format_name='vertical' if expanded else None,
+ column_types=column_types,
+ **output_kwargs)
+
+ if isinstance(formatted, str):
+ formatted = formatted.splitlines()
+ formatted = iter(formatted)
+
+ if (not expanded and max_width and headers and cur):
+ first_line = next(formatted)
+ if len(strip_ansi(first_line)) > max_width:
+ formatted = self.formatter.format_output(
+ cur, headers, format_name='vertical', column_types=column_types, **output_kwargs)
+ if isinstance(formatted, str):
+ formatted = iter(formatted.splitlines())
+ else:
+ formatted = itertools.chain([first_line], formatted)
+
+ output = itertools.chain(output, formatted)
+
+
+ return output
+
+ def get_reserved_space(self):
+ """Get the number of lines to reserve for the completion menu."""
+ reserved_space_ratio = .45
+ max_reserved_space = 8
+ _, height = shutil.get_terminal_size()
+ return min(int(round(height * reserved_space_ratio)), max_reserved_space)
+
+ def get_last_query(self):
+ """Get the last query executed or None."""
+ return self.query_history[-1][0] if self.query_history else None
+
+
+@click.command()
+@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.')
+@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors '
+ '$MYSQL_TCP_PORT.')
+@click.option('-u', '--user', help='User name to connect to the database.')
+@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.')
+@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str,
+ help='Password to connect to the database.')
+@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str,
+ help='Password to connect to the database.')
+@click.option('--ssh-user', help='User name to connect to ssh server.')
+@click.option('--ssh-host', help='Host name to connect to ssh server.')
+@click.option('--ssh-port', default=22, help='Port to connect to ssh server.')
+@click.option('--ssh-password', help='Password to connect to ssh server.')
+@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.')
+@click.option('--ssh-config-path', help='Path to ssh configuration.',
+ default=os.path.expanduser('~') + '/.ssh/config')
+@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.')
+@click.option('--ssl', 'ssl_enable', is_flag=True,
+ help='Enable SSL for connection (automatically enabled with other flags).')
+@click.option('--ssl-ca', help='CA file in PEM format.',
+ type=click.Path(exists=True))
+@click.option('--ssl-capath', help='CA directory.')
+@click.option('--ssl-cert', help='X509 cert in PEM format.',
+ type=click.Path(exists=True))
+@click.option('--ssl-key', help='X509 key in PEM format.',
+ type=click.Path(exists=True))
+@click.option('--ssl-cipher', help='SSL cipher to use.')
+@click.option('--ssl-verify-server-cert', is_flag=True,
+ help=('Verify server\'s "Common Name" in its cert against '
+ 'hostname used when connecting. This option is disabled '
+ 'by default.'))
+# as of 2016-02-15 revocation list is not supported by underling PyMySQL
+# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client)
+@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.')
+@click.option('-v', '--verbose', is_flag=True, help='Verbose output.')
+@click.option('-D', '--database', 'dbname', help='Database to use.')
+@click.option('-d', '--dsn', default='', envvar='DSN',
+ help='Use DSN configured into the [alias_dsn] section of myclirc file.')
+@click.option('--list-dsn', 'list_dsn', is_flag=True,
+ help='list of DSN configured into the [alias_dsn] section of myclirc file.')
+@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True,
+ help='list ssh configurations in the ssh config (requires paramiko).')
+@click.option('-R', '--prompt', 'prompt',
+ help='Prompt format (Default: "{0}").'.format(
+ MyCli.default_prompt))
+@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'),
+ help='Log every query and its results to a file.')
+@click.option('--defaults-group-suffix', type=str,
+ help='Read MySQL config groups with the specified suffix.')
+@click.option('--defaults-file', type=click.Path(),
+ help='Only read MySQL options from the given file.')
+@click.option('--myclirc', type=click.Path(), default="~/.myclirc",
+ help='Location of myclirc file.')
+@click.option('--auto-vertical-output', is_flag=True,
+ help='Automatically switch to vertical output mode if the result is wider than the terminal width.')
+@click.option('-t', '--table', is_flag=True,
+ help='Display batch output in table format.')
+@click.option('--csv', is_flag=True,
+ help='Display batch output in CSV format.')
+@click.option('--warn/--no-warn', default=None,
+ help='Warn before running a destructive query.')
+@click.option('--local-infile', type=bool,
+ help='Enable/disable LOAD DATA LOCAL INFILE.')
+@click.option('-g', '--login-path', type=str,
+ help='Read this path from the login file.')
+@click.option('-e', '--execute', type=str,
+ help='Execute command and quit.')
+@click.option('--init-command', type=str,
+ help='SQL statement to execute after connecting.')
+@click.option('--charset', type=str,
+ help='Character set for MySQL session.')
+@click.option('--password-file', type=click.Path(),
+ help='File or FIFO path containing the password to connect to the db if not specified otherwise.')
+@click.argument('database', default='', nargs=1)
+def cli(database, user, host, port, socket, password, dbname,
+ version, verbose, prompt, logfile, defaults_group_suffix,
+ defaults_file, login_path, auto_vertical_output, local_infile,
+ ssl_enable, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
+ ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
+ list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
+ ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
+ init_command, charset, password_file):
+ """A MySQL terminal client with auto-completion and syntax highlighting.
+
+ \b
+ Examples:
+ - mycli my_database
+ - mycli -u my_user -h my_host.com my_database
+ - mycli mysql://my_user@my_host.com:3306/my_database
+
+ """
+
+ if version:
+ print('Version:', __version__)
+ sys.exit(0)
+
+ mycli = MyCli(prompt=prompt, logfile=logfile,
+ defaults_suffix=defaults_group_suffix,
+ defaults_file=defaults_file, login_path=login_path,
+ auto_vertical_output=auto_vertical_output, warn=warn,
+ myclirc=myclirc)
+ if list_dsn:
+ try:
+ alias_dsn = mycli.config['alias_dsn']
+ except KeyError as err:
+ click.secho('Invalid DSNs found in the config file. '\
+ 'Please check the "[alias_dsn]" section in myclirc.',
+ err=True, fg='red')
+ exit(1)
+ except Exception as e:
+ click.secho(str(e), err=True, fg='red')
+ exit(1)
+ for alias, value in alias_dsn.items():
+ if verbose:
+ click.secho("{} : {}".format(alias, value))
+ else:
+ click.secho(alias)
+ sys.exit(0)
+ if list_ssh_config:
+ ssh_config = read_ssh_config(ssh_config_path)
+ for host in ssh_config.get_hostnames():
+ if verbose:
+ host_config = ssh_config.lookup(host)
+ click.secho("{} : {}".format(
+ host, host_config.get('hostname')))
+ else:
+ click.secho(host)
+ sys.exit(0)
+ # Choose which ever one has a valid value.
+ database = dbname or database
+
+ ssl = {
+ 'enable': ssl_enable,
+ 'ca': ssl_ca and os.path.expanduser(ssl_ca),
+ 'cert': ssl_cert and os.path.expanduser(ssl_cert),
+ 'key': ssl_key and os.path.expanduser(ssl_key),
+ 'capath': ssl_capath,
+ 'cipher': ssl_cipher,
+ 'check_hostname': ssl_verify_server_cert,
+ }
+
+ # remove empty ssl options
+ ssl = {k: v for k, v in ssl.items() if v is not None}
+
+ dsn_uri = None
+
+ # Treat the database argument as a DSN alias if we're missing
+ # other connection information.
+ if (mycli.config['alias_dsn'] and database and '://' not in database
+ and not any([user, password, host, port, login_path])):
+ dsn, database = database, ''
+
+ if database and '://' in database:
+ dsn_uri, database = database, ''
+
+ if dsn:
+ try:
+ dsn_uri = mycli.config['alias_dsn'][dsn]
+ except KeyError:
+ click.secho('Could not find the specified DSN in the config file. '
+ 'Please check the "[alias_dsn]" section in your '
+ 'myclirc.', err=True, fg='red')
+ exit(1)
+ else:
+ mycli.dsn_alias = dsn
+
+ if dsn_uri:
+ uri = urlparse(dsn_uri)
+ if not database:
+ database = uri.path[1:] # ignore the leading fwd slash
+ if not user:
+ user = unquote(uri.username)
+ if not password and uri.password is not None:
+ password = unquote(uri.password)
+ if not host:
+ host = uri.hostname
+ if not port:
+ port = uri.port
+
+ if ssh_config_host:
+ ssh_config = read_ssh_config(
+ ssh_config_path
+ ).lookup(ssh_config_host)
+ ssh_host = ssh_host if ssh_host else ssh_config.get('hostname')
+ ssh_user = ssh_user if ssh_user else ssh_config.get('user')
+ if ssh_config.get('port') and ssh_port == 22:
+ # port has a default value, overwrite it if it's in the config
+ ssh_port = int(ssh_config.get('port'))
+ ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get(
+ 'identityfile', [None])[0]
+
+ ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename)
+
+ mycli.connect(
+ database=database,
+ user=user,
+ passwd=password,
+ host=host,
+ port=port,
+ socket=socket,
+ local_infile=local_infile,
+ ssl=ssl,
+ ssh_user=ssh_user,
+ ssh_host=ssh_host,
+ ssh_port=ssh_port,
+ ssh_password=ssh_password,
+ ssh_key_filename=ssh_key_filename,
+ init_command=init_command,
+ charset=charset,
+ password_file=password_file
+ )
+
+ mycli.logger.debug('Launch Params: \n'
+ '\tdatabase: %r'
+ '\tuser: %r'
+ '\thost: %r'
+ '\tport: %r', database, user, host, port)
+
+ # --execute argument
+ if execute:
+ try:
+ if csv:
+ mycli.formatter.format_name = 'csv'
+ elif not table:
+ mycli.formatter.format_name = 'tsv'
+
+ mycli.run_query(execute)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg='red')
+ exit(1)
+
+ if sys.stdin.isatty():
+ mycli.run_cli()
+ else:
+ stdin = click.get_text_stream('stdin')
+ try:
+ stdin_text = stdin.read()
+ except MemoryError:
+ click.secho('Failed! Ran out of memory.', err=True, fg='red')
+ click.secho('You might want to try the official mysql client.', err=True, fg='red')
+ click.secho('Sorry... :(', err=True, fg='red')
+ exit(1)
+
+ if mycli.destructive_warning and is_destructive(stdin_text):
+ try:
+ sys.stdin = open('/dev/tty')
+ warn_confirmed = confirm_destructive_query(stdin_text)
+ except (IOError, OSError):
+ mycli.logger.warning('Unable to open TTY as stdin.')
+ if not warn_confirmed:
+ exit(0)
+
+ try:
+ new_line = True
+
+ if csv:
+ mycli.formatter.format_name = 'csv'
+ elif not table:
+ mycli.formatter.format_name = 'tsv'
+
+ mycli.run_query(stdin_text, new_line=new_line)
+ exit(0)
+ except Exception as e:
+ click.secho(str(e), err=True, fg='red')
+ exit(1)
+
+
+def need_completion_refresh(queries):
+ """Determines if the completion needs a refresh by checking if the sql
+ statement is an alter, create, drop or change db."""
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ('alter', 'create', 'use', '\\r',
+ '\\u', 'connect', 'drop', 'rename'):
+ return True
+ except Exception:
+ return False
+
+
+def need_completion_reset(queries):
+ """Determines if the statement is a database switch such as 'use' or '\\u'.
+ When a database is changed the existing completions must be reset before we
+ start the completion refresh for the new database.
+ """
+ for query in sqlparse.split(queries):
+ try:
+ first_token = query.split()[0]
+ if first_token.lower() in ('use', '\\u'):
+ return True
+ except Exception:
+ return False
+
+
+def is_mutating(status):
+ """Determines if the statement is mutating based on the status."""
+ if not status:
+ return False
+
+ mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop',
+ 'replace', 'truncate', 'load', 'rename'])
+ return status.split(None, 1)[0].lower() in mutating
+
+
+def is_select(status):
+ """Returns true if the first word in status is 'select'."""
+ if not status:
+ return False
+ return status.split(None, 1)[0].lower() == 'select'
+
+
+def thanks_picker():
+ import mycli
+ lines = (
+ resources.read_text(mycli, 'AUTHORS') +
+ resources.read_text(mycli, 'SPONSORS')
+ ).split('\n')
+
+ contents = []
+ for line in lines:
+ m = re.match(r'^ *\* (.*)', line)
+ if m:
+ contents.append(m.group(1))
+ return choice(contents)
+
+
+@prompt_register('edit-and-execute-command')
+def edit_and_execute(event):
+ """Different from the prompt-toolkit default, we want to have a choice not
+ to execute a query after editing, hence validate_and_handle=False."""
+ buff = event.current_buffer
+ buff.open_in_editor(validate_and_handle=False)
+
+
+def read_ssh_config(ssh_config_path):
+ ssh_config = paramiko.config.SSHConfig()
+ try:
+ with open(ssh_config_path) as f:
+ ssh_config.parse(f)
+ except FileNotFoundError as e:
+ click.secho(str(e), err=True, fg='red')
+ sys.exit(1)
+ # Paramiko prior to version 2.7 raises Exception on parse errors.
+ # In 2.7 it has become paramiko.ssh_exception.SSHException,
+ # but let's catch everything for compatibility
+ except Exception as err:
+ click.secho(
+ f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ',
+ err=True, fg='red'
+ )
+ sys.exit(1)
+ else:
+ return ssh_config
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/mycli/myclirc b/mycli/myclirc
new file mode 100644
index 0000000..cd58dfe
--- /dev/null
+++ b/mycli/myclirc
@@ -0,0 +1,159 @@
+# vi: ft=dosini
+[main]
+
+# Enables context sensitive auto-completion. If this is disabled the all
+# possible completions will be listed.
+smart_completion = True
+
+# Multi-line mode allows breaking up the sql statements into multiple lines. If
+# this is set to True, then the end of the statements must have a semi-colon.
+# If this is set to False then sql statements can't be split into multiple
+# lines. End of line (return) is considered as the end of the statement.
+multi_line = False
+
+# Destructive warning mode will alert you before executing a sql statement
+# that may cause harm to the database such as "drop table", "drop database"
+# or "shutdown".
+destructive_warning = True
+
+# log_file location.
+log_file = ~/.mycli.log
+
+# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO"
+# and "DEBUG". "NONE" disables logging.
+log_level = INFO
+
+# Log every query and its results to a file. Enable this by uncommenting the
+# line below.
+# audit_log = ~/.mycli-audit.log
+
+# Timing of sql statements and table rendering.
+timing = True
+
+# Beep after long-running queries are completed; 0 to disable.
+beep_after_seconds = 0
+
+# Table format. Possible values: ascii, double, github,
+# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html,
+# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv.
+# Recommended: ascii
+table_format = ascii
+
+# Syntax coloring style. Possible values (many support the "-dark" suffix):
+# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs,
+# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default,
+# fruity.
+# Screenshots at http://mycli.net/syntax
+# Can be further modified in [colors]
+syntax_style = default
+
+# Keybindings: Possible values: emacs, vi.
+# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL.
+# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL.
+key_bindings = emacs
+
+# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested.
+wider_completion_menu = False
+
+# MySQL prompt
+# \D - The full current date
+# \d - Database name
+# \h - Hostname of the server
+# \m - Minutes of the current time
+# \n - Newline
+# \P - AM/PM
+# \p - Port
+# \R - The current time, in 24-hour military time (0-23)
+# \r - The current time, standard 12-hour time (1-12)
+# \s - Seconds of the current time
+# \t - Product type (Percona, MySQL, MariaDB, TiDB)
+# \A - DSN alias name (from the [alias_dsn] section)
+# \u - Username
+# \x1b[...m - insert ANSI escape sequence
+prompt = '\t \u@\h:\d> '
+prompt_continuation = '->'
+
+# Skip intro info on startup and outro info on exit
+less_chatty = False
+
+# Use alias from --login-path instead of host name in prompt
+login_path_as_host = False
+
+# Cause result sets to be displayed vertically if they are too wide for the current window,
+# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.)
+auto_vertical_output = False
+
+# keyword casing preference. Possible values "lower", "upper", "auto"
+keyword_casing = auto
+
+# disabled pager on startup
+enable_pager = True
+
+# Choose a specific pager
+pager = 'less'
+
+# Custom colors for the completion menu, toolbar, etc.
+[colors]
+completion-menu.completion.current = 'bg:#ffffff #000000'
+completion-menu.completion = 'bg:#008888 #ffffff'
+completion-menu.meta.completion.current = 'bg:#44aaaa #000000'
+completion-menu.meta.completion = 'bg:#448888 #ffffff'
+completion-menu.multi-column-meta = 'bg:#aaffff #000000'
+scrollbar.arrow = 'bg:#003333'
+scrollbar = 'bg:#00aaaa'
+selected = '#ffffff bg:#6666aa'
+search = '#ffffff bg:#4444aa'
+search.current = '#ffffff bg:#44aa44'
+bottom-toolbar = 'bg:#222222 #aaaaaa'
+bottom-toolbar.off = 'bg:#222222 #888888'
+bottom-toolbar.on = 'bg:#222222 #ffffff'
+search-toolbar = 'noinherit bold'
+search-toolbar.text = 'nobold'
+system-toolbar = 'noinherit bold'
+arg-toolbar = 'noinherit bold'
+arg-toolbar.text = 'nobold'
+bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
+bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
+
+# style classes for colored table output
+output.header = "#00ff5f bold"
+output.odd-row = ""
+output.even-row = ""
+output.null = "#808080"
+
+# SQL syntax highlighting overrides
+# sql.comment = 'italic #408080'
+# sql.comment.multi-line = ''
+# sql.comment.single-line = ''
+# sql.comment.optimizer-hint = ''
+# sql.escape = 'border:#FF0000'
+# sql.keyword = 'bold #008000'
+# sql.datatype = 'nobold #B00040'
+# sql.literal = ''
+# sql.literal.date = ''
+# sql.symbol = ''
+# sql.quoted-schema-object = ''
+# sql.quoted-schema-object.escape = ''
+# sql.constant = '#880000'
+# sql.function = '#0000FF'
+# sql.variable = '#19177C'
+# sql.number = '#666666'
+# sql.number.binary = ''
+# sql.number.float = ''
+# sql.number.hex = ''
+# sql.number.integer = ''
+# sql.operator = '#666666'
+# sql.punctuation = ''
+# sql.string = '#BA2121'
+# sql.string.double-quouted = ''
+# sql.string.escape = 'bold #BB6622'
+# sql.string.single-quoted = ''
+# sql.whitespace = ''
+
+# Favorite queries.
+[favorite_queries]
+
+# Use the -d option to reference a DSN.
+# Special characters in passwords and other strings can be escaped with URL encoding.
+[alias_dsn]
+# example_dsn = mysql://[user[:password]@][host][:port][/dbname]
diff --git a/mycli/packages/__init__.py b/mycli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/__init__.py
diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py
new file mode 100644
index 0000000..2735f5b
--- /dev/null
+++ b/mycli/packages/completion_engine.py
@@ -0,0 +1,294 @@
+import sqlparse
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils import last_word, extract_tables, find_prev_keyword
+from .special import parse_special_command
+
+
+def suggest_type(full_text, text_before_cursor):
+ """Takes the full_text that is typed so far and also the text before the
+ cursor to suggest completion type and scope.
+
+ Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+ A scope for a column category will be a list of tables.
+ """
+
+ word_before_cursor = last_word(text_before_cursor,
+ include='many_punctuations')
+
+ identifier = None
+
+ # here should be removed once sqlparse has been fixed
+ try:
+ # If we've partially typed a word then word_before_cursor won't be an empty
+ # string. In that case we want to remove the partially typed string before
+ # sending it to the sqlparser. Otherwise the last token will always be the
+ # partially typed string which renders the smart completion useless because
+ # it will always return the list of keywords as completion.
+ if word_before_cursor:
+ if word_before_cursor.endswith(
+ '(') or word_before_cursor.startswith('\\'):
+ parsed = sqlparse.parse(text_before_cursor)
+ else:
+ parsed = sqlparse.parse(
+ text_before_cursor[:-len(word_before_cursor)])
+
+ # word_before_cursor may include a schema qualification, like
+ # "schema_name.partial_name" or "schema_name.", so parse it
+ # separately
+ p = sqlparse.parse(word_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[0], Identifier):
+ identifier = p.tokens[0]
+ else:
+ parsed = sqlparse.parse(text_before_cursor)
+ except (TypeError, AttributeError):
+ return [{'type': 'keyword'}]
+
+ if len(parsed) > 1:
+ # Multiple statements being edited -- isolate the current one by
+ # cumulatively summing statement lengths to find the one that bounds the
+ # current position
+ current_pos = len(text_before_cursor)
+ stmt_start, stmt_end = 0, 0
+
+ for statement in parsed:
+ stmt_len = len(str(statement))
+ stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+ if stmt_end >= current_pos:
+ text_before_cursor = full_text[stmt_start:current_pos]
+ full_text = full_text[stmt_start:]
+ break
+
+ elif parsed:
+ # A single statement
+ statement = parsed[0]
+ else:
+ # The empty string
+ statement = None
+
+ # Check for special commands and handle those separately
+ if statement:
+ # Be careful here because trivial whitespace is parsed as a statement,
+ # but the statement won't have a first token
+ tok1 = statement.token_first()
+ if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')):
+ return suggest_special(text_before_cursor)
+
+ last_token = statement and statement.token_prev(len(statement.tokens))[1] or ''
+
+ return suggest_based_on_last_token(last_token, text_before_cursor,
+ full_text, identifier)
+
+
+def suggest_special(text):
+ text = text.lstrip()
+ cmd, _, arg = parse_special_command(text)
+
+ if cmd == text:
+ # Trying to complete the special command itself
+ return [{'type': 'special'}]
+
+ if cmd in ('\\u', '\\r'):
+ return [{'type': 'database'}]
+
+ if cmd in ('\\T'):
+ return [{'type': 'table_format'}]
+
+ if cmd in ['\\f', '\\fs', '\\fd']:
+ return [{'type': 'favoritequery'}]
+
+ if cmd in ['\\dt', '\\dt+']:
+ return [
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'},
+ ]
+ elif cmd in ['\\.', 'source']:
+ return[{'type': 'file_name'}]
+
+ return [{'type': 'keyword'}, {'type': 'special'}]
+
+
+def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier):
+ if isinstance(token, str):
+ token_v = token.lower()
+ elif isinstance(token, Comparison):
+ # If 'token' is a Comparison type such as
+ # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+ # token.value on the comparison type will only return the lhs of the
+ # comparison. In this case a.id. So we need to do token.tokens to get
+ # both sides of the comparison and pick the last token out of that
+ # list.
+ token_v = token.tokens[-1].value.lower()
+ elif isinstance(token, Where):
+ # sqlparse groups all tokens from the where clause into a single token
+ # list. This means that token.value may be something like
+ # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+ # suggestions in complicated where clauses correctly
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ return suggest_based_on_last_token(prev_keyword, text_before_cursor,
+ full_text, identifier)
+ elif token is None:
+ return [{'type': 'keyword'}]
+ else:
+ token_v = token.value.lower()
+
+ is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']])
+
+ if not token:
+ return [{'type': 'keyword'}, {'type': 'special'}]
+ elif token_v.endswith('('):
+ p = sqlparse.parse(text_before_cursor)[0]
+
+ if p.tokens and isinstance(p.tokens[-1], Where):
+ # Four possibilities:
+ # 1 - Parenthesized clause like "WHERE foo AND ("
+ # Suggest columns/functions
+ # 2 - Function call like "WHERE foo("
+ # Suggest columns/functions
+ # 3 - Subquery expression like "WHERE EXISTS ("
+ # Suggest keywords, in order to do a subquery
+ # 4 - Subquery OR array comparison like "WHERE foo = ANY("
+ # Suggest columns/functions AND keywords. (If we wanted to be
+ # really fancy, we could suggest only array-typed columns)
+
+ column_suggestions = suggest_based_on_last_token('where',
+ text_before_cursor, full_text, identifier)
+
+ # Check for a subquery expression (cases 3 & 4)
+ where = p.tokens[-1]
+ idx, prev_tok = where.token_prev(len(where.tokens) - 1)
+
+ if isinstance(prev_tok, Comparison):
+ # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+ prev_tok = prev_tok.tokens[-1]
+
+ prev_tok = prev_tok.value.lower()
+ if prev_tok == 'exists':
+ return [{'type': 'keyword'}]
+ else:
+ return column_suggestions
+
+ # Get the token before the parens
+ idx, prev_tok = p.token_prev(len(p.tokens) - 1)
+ if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
+ # tbl1 INNER JOIN tbl2 USING (col1, col2)
+ tables = extract_tables(full_text)
+
+ # suggest columns that are present in more than one table
+ return [{'type': 'column', 'tables': tables, 'drop_unique': True}]
+ elif p.token_first().value.lower() == 'select':
+ # If the lparen is preceeded by a space chances are we're about to
+ # do a sub-select.
+ if last_word(text_before_cursor,
+ 'all_punctuations').startswith('('):
+ return [{'type': 'keyword'}]
+ elif p.token_first().value.lower() == 'show':
+ return [{'type': 'show'}]
+
+ # We're probably in a function argument list
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
+ elif token_v in ('set', 'order by', 'distinct'):
+ return [{'type': 'column', 'tables': extract_tables(full_text)}]
+ elif token_v == 'as':
+ # Don't suggest anything for an alias
+ return []
+ elif token_v in ('show'):
+ return [{'type': 'show'}]
+ elif token_v in ('to',):
+ p = sqlparse.parse(text_before_cursor)[0]
+ if p.token_first().value.lower() == 'change':
+ return [{'type': 'change'}]
+ else:
+ return [{'type': 'user'}]
+ elif token_v in ('user', 'for'):
+ return [{'type': 'user'}]
+ elif token_v in ('select', 'where', 'having'):
+ # Check for a table alias or schema qualification
+ parent = (identifier and identifier.get_parent_name()) or []
+
+ tables = extract_tables(full_text)
+ if parent:
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'table', 'schema': parent},
+ {'type': 'view', 'schema': parent},
+ {'type': 'function', 'schema': parent}]
+ else:
+ aliases = [alias or table for (schema, table, alias) in tables]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'function', 'schema': []},
+ {'type': 'alias', 'aliases': aliases},
+ {'type': 'keyword'}]
+ elif (token_v.endswith('join') and token.is_keyword) or (token_v in
+ ('copy', 'from', 'update', 'into', 'describe', 'truncate',
+ 'desc', 'explain')):
+ schema = (identifier and identifier.get_parent_name()) or []
+
+ # Suggest tables from either the currently-selected schema or the
+ # public schema if no schema has been specified
+ suggest = [{'type': 'table', 'schema': schema}]
+
+ if not schema:
+ # Suggest schemas
+ suggest.insert(0, {'type': 'schema'})
+
+ # Only tables can be TRUNCATED, otherwise suggest views
+ if token_v != 'truncate':
+ suggest.append({'type': 'view', 'schema': schema})
+
+ return suggest
+
+ elif token_v in ('table', 'view', 'function'):
+ # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
+ rel_type = token_v
+ schema = (identifier and identifier.get_parent_name()) or []
+ if schema:
+ return [{'type': rel_type, 'schema': schema}]
+ else:
+ return [{'type': 'schema'}, {'type': rel_type, 'schema': []}]
+ elif token_v == 'on':
+ tables = extract_tables(full_text) # [(schema, table, alias), ...]
+ parent = (identifier and identifier.get_parent_name()) or []
+ if parent:
+ # "ON parent.<suggestion>"
+ # parent can be either a schema name or table alias
+ tables = [t for t in tables if identifies(parent, *t)]
+ return [{'type': 'column', 'tables': tables},
+ {'type': 'table', 'schema': parent},
+ {'type': 'view', 'schema': parent},
+ {'type': 'function', 'schema': parent}]
+ else:
+ # ON <suggestion>
+ # Use table alias if there is one, otherwise the table name
+ aliases = [alias or table for (schema, table, alias) in tables]
+ suggest = [{'type': 'alias', 'aliases': aliases}]
+
+ # The lists of 'aliases' could be empty if we're trying to complete
+ # a GRANT query. eg: GRANT SELECT, INSERT ON <tab>
+ # In that case we just suggest all tables.
+ if not aliases:
+ suggest.append({'type': 'table', 'schema': parent})
+ return suggest
+
+ elif token_v in ('use', 'database', 'template', 'connect'):
+ # "\c <db", "use <db>", "DROP DATABASE <db>",
+ # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
+ return [{'type': 'database'}]
+ elif token_v == 'tableformat':
+ return [{'type': 'table_format'}]
+ elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']:
+ prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
+ if prev_keyword:
+ return suggest_based_on_last_token(
+ prev_keyword, text_before_cursor, full_text, identifier)
+ else:
+ return []
+ else:
+ return [{'type': 'keyword'}]
+
+
+def identifies(id, schema, table, alias):
+ return id == alias or id == table or (
+ schema and (id == schema + '.' + table))
diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py
new file mode 100644
index 0000000..79fe26d
--- /dev/null
+++ b/mycli/packages/filepaths.py
@@ -0,0 +1,106 @@
+import os
+import platform
+
+
+if os.name == "posix":
+ if platform.system() == "Darwin":
+ DEFAULT_SOCKET_DIRS = ("/tmp",)
+ else:
+ DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib")
+else:
+ DEFAULT_SOCKET_DIRS = ()
+
+
+def list_path(root_dir):
+ """List directory if exists.
+
+ :param root_dir: str
+ :return: list
+
+ """
+ res = []
+ if os.path.isdir(root_dir):
+ for name in os.listdir(root_dir):
+ res.append(name)
+ return res
+
+
+def complete_path(curr_dir, last_dir):
+ """Return the path to complete that matches the last entered component.
+
+ If the last entered component is ~, expanded path would not
+ match, so return all of the available paths.
+
+ :param curr_dir: str
+ :param last_dir: str
+ :return: str
+
+ """
+ if not last_dir or curr_dir.startswith(last_dir):
+ return curr_dir
+ elif last_dir == '~':
+ return os.path.join(last_dir, curr_dir)
+
+
+def parse_path(root_dir):
+ """Split path into head and last component for the completer.
+
+ Also return position where last component starts.
+
+ :param root_dir: str path
+ :return: tuple of (string, string, int)
+
+ """
+ base_dir, last_dir, position = '', '', 0
+ if root_dir:
+ base_dir, last_dir = os.path.split(root_dir)
+ position = -len(last_dir) if last_dir else 0
+ return base_dir, last_dir, position
+
+
+def suggest_path(root_dir):
+ """List all files and subdirectories in a directory.
+
+ If the directory is not specified, suggest root directory,
+ user directory, current and parent directory.
+
+ :param root_dir: string: directory to list
+ :return: list
+
+ """
+ if not root_dir:
+ return [os.path.abspath(os.sep), '~', os.curdir, os.pardir]
+
+ if '~' in root_dir:
+ root_dir = os.path.expanduser(root_dir)
+
+ if not os.path.exists(root_dir):
+ root_dir, _ = os.path.split(root_dir)
+
+ return list_path(root_dir)
+
+
+def dir_path_exists(path):
+ """Check if the directory path exists for a given file.
+
+ For example, for a file /home/user/.cache/mycli/log, check if
+ /home/user/.cache/mycli exists.
+
+ :param str path: The file path.
+ :return: Whether or not the directory path exists.
+
+ """
+ return os.path.exists(os.path.dirname(path))
+
+
+def guess_socket_location():
+ """Try to guess the location of the default mysql socket file."""
+ socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS)
+ for directory in socket_dirs:
+ for r, dirs, files in os.walk(directory, topdown=True):
+ for filename in files:
+ name, ext = os.path.splitext(filename)
+ if name.startswith("mysql") and ext in ('.socket', '.sock'):
+ return os.path.join(r, filename)
+ dirs[:] = [d for d in dirs if d.startswith("mysql")]
+ return None
diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py
new file mode 100644
index 0000000..045b00e
--- /dev/null
+++ b/mycli/packages/paramiko_stub/__init__.py
@@ -0,0 +1,28 @@
+"""A module to import instead of paramiko when it is not available (to avoid
+checking for paramiko all over the place).
+
+When paramiko is first envoked, it simply shuts down mycli, telling
+user they either have to install paramiko or should not use SSH
+features.
+
+"""
+
+
+class Paramiko:
+ def __getattr__(self, name):
+ import sys
+ from textwrap import dedent
+ print(dedent("""
+ To enable certain SSH features you need to install paramiko:
+
+ pip install paramiko
+
+ It is required for the following configuration options:
+ --list-ssh-config
+ --ssh-config-host
+ --ssh-host
+ """))
+ sys.exit(1)
+
+
+paramiko = Paramiko()
diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py
new file mode 100644
index 0000000..3090530
--- /dev/null
+++ b/mycli/packages/parseutils.py
@@ -0,0 +1,266 @@
+import re
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+cleanup_regex = {
+ # This matches only alphanumerics and underscores.
+ 'alphanum_underscore': re.compile(r'(\w+)$'),
+ # This matches everything except spaces, parens, colon, and comma
+ 'many_punctuations': re.compile(r'([^():,\s]+)$'),
+ # This matches everything except spaces, parens, colon, comma, and period
+ 'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
+ # This matches everything except a space.
+ 'all_punctuations': re.compile(r'([^\s]+)$'),
+}
+
+
+def last_word(text, include='alphanum_underscore'):
+ r"""
+ Find the last word in a sentence.
+
+ >>> last_word('abc')
+ 'abc'
+ >>> last_word(' abc')
+ 'abc'
+ >>> last_word('')
+ ''
+ >>> last_word(' ')
+ ''
+ >>> last_word('abc ')
+ ''
+ >>> last_word('abc def')
+ 'def'
+ >>> last_word('abc def ')
+ ''
+ >>> last_word('abc def;')
+ ''
+ >>> last_word('bac $def')
+ 'def'
+ >>> last_word('bac $def', include='most_punctuations')
+ '$def'
+ >>> last_word('bac \def', include='most_punctuations')
+ '\\\\def'
+ >>> last_word('bac \def;', include='most_punctuations')
+ '\\\\def;'
+ >>> last_word('bac::def', include='most_punctuations')
+ 'def'
+ """
+
+ if not text: # Empty string
+ return ''
+
+ if text[-1].isspace():
+ return ''
+ else:
+ regex = cleanup_regex[include]
+ matches = regex.search(text)
+ if matches:
+ return matches.group(0)
+ else:
+ return ''
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+ if not parsed.is_group:
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
+ 'UPDATE', 'CREATE', 'DELETE'):
+ return True
+ return False
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+ tbl_prefix_seen = False
+ for item in parsed.tokens:
+ if tbl_prefix_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item, stop_at_punctuation):
+ yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ return
+ # Multiple JOINs in the same query won't work properly since
+ # "ON" is a keyword and will trigger the next elif condition.
+ # So instead of stooping the loop when finding an "ON" skip it
+ # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi'
+ elif item.ttype is Keyword and item.value.upper() == 'ON':
+ tbl_prefix_seen = False
+ continue
+ # An incomplete nested select won't be recognized correctly as a
+ # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+ # the second FROM to trigger this elif condition resulting in a
+ # StopIteration. So we need to ignore the keyword if the keyword
+ # FROM.
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN and its variants
+ # INNER JOIN, FULL OUTER JOIN, etc.
+ elif item.ttype is Keyword and (
+ not item.value.upper() == 'FROM') and (
+ not item.value.upper().endswith('JOIN')):
+ return
+ else:
+ yield item
+ elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
+ item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
+ tbl_prefix_seen = True
+ # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+ # So this check here is necessary.
+ elif isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ if (identifier.ttype is Keyword and
+ identifier.value.upper() == 'FROM'):
+ tbl_prefix_seen = True
+ break
+
+def extract_table_identifiers(token_stream):
+ """yields tuples of (schema_name, table_name, table_alias)"""
+
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ # Sometimes Keywords (such as FROM ) are classified as
+ # identifiers which don't have the get_real_name() method.
+ try:
+ schema_name = identifier.get_parent_name()
+ real_name = identifier.get_real_name()
+ except AttributeError:
+ continue
+ if real_name:
+ yield (schema_name, real_name, identifier.get_alias())
+ elif isinstance(item, Identifier):
+ real_name = item.get_real_name()
+ schema_name = item.get_parent_name()
+
+ if real_name:
+ yield (schema_name, real_name, item.get_alias())
+ else:
+ name = item.get_name()
+ yield (None, name, item.get_alias() or name)
+ elif isinstance(item, Function):
+ yield (None, item.get_name(), item.get_name())
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+ """Extract the table names from an SQL statement.
+
+ Returns a list of (schema, table, alias) tuples
+
+ """
+ parsed = sqlparse.parse(sql)
+ if not parsed:
+ return []
+
+ # INSERT statements must stop looking for tables at the sign of first
+ # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+ # abc is the table name, but if we don't stop at the first lparen, then
+ # we'll identify abc, col1 and col2 as table names.
+ insert_stmt = parsed[0].token_first().value.lower() == 'insert'
+ stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+ return list(extract_table_identifiers(stream))
+
+def find_prev_keyword(sql):
+ """ Find the last sql keyword in an SQL statement
+
+ Returns the value of the last keyword, and the text of the query with
+ everything after the last keyword stripped
+ """
+ if not sql.strip():
+ return None, ''
+
+ parsed = sqlparse.parse(sql)[0]
+ flattened = list(parsed.flatten())
+
+ logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
+
+ for t in reversed(flattened):
+ if t.value == '(' or (t.is_keyword and (
+ t.value.upper() not in logical_operators)):
+ # Find the location of token t in the original parsed statement
+ # We can't use parsed.token_index(t) because t may be a child token
+ # inside a TokenList, in which case token_index thows an error
+ # Minimal example:
+ # p = sqlparse.parse('select * from foo where bar')
+ # t = list(p.flatten())[-3] # The "Where" token
+ # p.token_index(t) # Throws ValueError: not in list
+ idx = flattened.index(t)
+
+ # Combine the string values of all tokens in the original list
+ # up to and including the target keyword token t, to produce a
+ # query string with everything after the keyword token removed
+ text = ''.join(tok.value for tok in flattened[:idx+1])
+ return t, text
+
+ return None, ''
+
+
+def query_starts_with(query, prefixes):
+ """Check if the query starts with any item from *prefixes*."""
+ prefixes = [prefix.lower() for prefix in prefixes]
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+ return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+ """Check if any queries start with any item from *prefixes*."""
+ for query in sqlparse.split(queries):
+ if query and query_starts_with(query, prefixes) is True:
+ return True
+ return False
+
+
+def query_has_where_clause(query):
+ """Check if the query contains a where-clause."""
+ return any(
+ isinstance(token, sqlparse.sql.Where)
+ for token_list in sqlparse.parse(query)
+ for token in token_list
+ )
+
+
+def is_destructive(queries):
+ """Returns if any of the queries in *queries* is destructive."""
+ keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
+ for query in sqlparse.split(queries):
+ if query:
+ if query_starts_with(query, keywords) is True:
+ return True
+ elif query_starts_with(
+ query, ['update']
+ ) is True and not query_has_where_clause(query):
+ return True
+
+ return False
+
+
+if __name__ == '__main__':
+ sql = 'select * from (select t. from tabl t'
+ print (extract_tables(sql))
+
+
+def is_dropping_database(queries, dbname):
+ """Determine if the query is dropping a specific database."""
+ result = False
+ if dbname is None:
+ return False
+
+ def normalize_db_name(db):
+ return db.lower().strip('`"')
+
+ dbname = normalize_db_name(dbname)
+
+ for query in sqlparse.parse(queries):
+ keywords = [t for t in query.tokens if t.is_keyword]
+ if len(keywords) < 2:
+ continue
+ if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in (
+ "database",
+ "schema",
+ ):
+ database_token = next(
+ (t for t in query.tokens if isinstance(t, Identifier)), None
+ )
+ if database_token is not None and normalize_db_name(database_token.get_name()) == dbname:
+ result = keywords[0].normalized == "DROP"
+ return result
diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py
new file mode 100644
index 0000000..fb1e431
--- /dev/null
+++ b/mycli/packages/prompt_utils.py
@@ -0,0 +1,54 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
+class ConfirmBoolParamType(click.ParamType):
+ name = 'confirmation'
+
+ def convert(self, value, param, ctx):
+ if isinstance(value, bool):
+ return bool(value)
+ value = value.lower()
+ if value in ('yes', 'y'):
+ return True
+ elif value in ('no', 'n'):
+ return False
+ self.fail('%s is not a valid boolean' % value, param, ctx)
+
+ def __repr__(self):
+ return 'BOOL'
+
+
+BOOLEAN_TYPE = ConfirmBoolParamType()
+
+
+def confirm_destructive_query(queries):
+ """Check if the query is destructive and prompts the user to confirm.
+
+ Returns:
+ * None if the query is non-destructive or we can't prompt the user.
+ * True if the query is destructive and the user wants to proceed.
+ * False if the query is destructive and the user doesn't want to proceed.
+
+ """
+ prompt_text = ("You're about to run a destructive command.\n"
+ "Do you want to proceed? (y/n)")
+ if is_destructive(queries) and sys.stdin.isatty():
+ return prompt(prompt_text, type=BOOLEAN_TYPE)
+
+
+def confirm(*args, **kwargs):
+ """Prompt for confirmation (yes/no) and handle any abort exceptions."""
+ try:
+ return click.confirm(*args, **kwargs)
+ except click.Abort:
+ return False
+
+
+def prompt(*args, **kwargs):
+ """Prompt the user for input and handle any abort exceptions."""
+ try:
+ return click.prompt(*args, **kwargs)
+ except click.Abort:
+ return False
diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py
new file mode 100644
index 0000000..92bcca6
--- /dev/null
+++ b/mycli/packages/special/__init__.py
@@ -0,0 +1,10 @@
+__all__ = []
+
+def export(defn):
+ """Decorator to explicitly mark functions that are exposed in a lib."""
+ globals()[defn.__name__] = defn
+ __all__.append(defn.__name__)
+ return defn
+
+from . import dbcommands
+from . import iocommands
diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py
new file mode 100644
index 0000000..5c29c55
--- /dev/null
+++ b/mycli/packages/special/dbcommands.py
@@ -0,0 +1,162 @@
+import logging
+import os
+import platform
+from mycli import __version__
+from mycli.packages.special import iocommands
+from mycli.packages.special.utils import format_uptime
+from .main import special_command, RAW_QUERY, PARSED_QUERY
+from pymysql import ProgrammingError
+
+log = logging.getLogger(__name__)
+
+
+@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.',
+ arg_type=PARSED_QUERY, case_sensitive=True)
+def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False):
+ if arg:
+ query = 'SHOW FIELDS FROM {0}'.format(arg)
+ else:
+ query = 'SHOW TABLES'
+ log.debug(query)
+ cur.execute(query)
+ tables = cur.fetchall()
+ status = ''
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ else:
+ return [(None, None, None, '')]
+
+ if verbose and arg:
+ query = 'SHOW CREATE TABLE {0}'.format(arg)
+ log.debug(query)
+ cur.execute(query)
+ status = cur.fetchone()[1]
+
+ return [(None, tables, headers, status)]
+
+
+@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True)
+def list_databases(cur, **_):
+ query = 'SHOW DATABASES'
+ log.debug(query)
+ cur.execute(query)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, '')]
+ else:
+ return [(None, None, None, '')]
+
+
+@special_command('status', '\\s', 'Get status information from the server.',
+ arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True)
+def status(cur, **_):
+ query = 'SHOW GLOBAL STATUS;'
+ log.debug(query)
+ try:
+ cur.execute(query)
+ except ProgrammingError:
+ # Fallback in case query fail, as it does with Mysql 4
+ query = 'SHOW STATUS;'
+ log.debug(query)
+ cur.execute(query)
+ status = dict(cur.fetchall())
+
+ query = 'SHOW GLOBAL VARIABLES;'
+ log.debug(query)
+ cur.execute(query)
+ variables = dict(cur.fetchall())
+
+ # prepare in case keys are bytes, as with Python 3 and Mysql 4
+ if (isinstance(list(variables)[0], bytes) and
+ isinstance(list(status)[0], bytes)):
+ variables = {k.decode('utf-8'): v.decode('utf-8') for k, v
+ in variables.items()}
+ status = {k.decode('utf-8'): v.decode('utf-8') for k, v
+ in status.items()}
+
+ # Create output buffers.
+ title = []
+ output = []
+ footer = []
+
+ title.append('--------------')
+
+ # Output the mycli client information.
+ implementation = platform.python_implementation()
+ version = platform.python_version()
+ client_info = []
+ client_info.append('mycli {0},'.format(__version__))
+ client_info.append('running on {0} {1}'.format(implementation, version))
+ title.append(' '.join(client_info) + '\n')
+
+ # Build the output that will be displayed as a table.
+ output.append(('Connection id:', cur.connection.thread_id()))
+
+ query = 'SELECT DATABASE(), USER();'
+ log.debug(query)
+ cur.execute(query)
+ db, user = cur.fetchone()
+ if db is None:
+ db = ''
+
+ output.append(('Current database:', db))
+ output.append(('Current user:', user))
+
+ if iocommands.is_pager_enabled():
+ if 'PAGER' in os.environ:
+ pager = os.environ['PAGER']
+ else:
+ pager = 'System default'
+ else:
+ pager = 'stdout'
+ output.append(('Current pager:', pager))
+
+ output.append(('Server version:', '{0} {1}'.format(
+ variables['version'], variables['version_comment'])))
+ output.append(('Protocol version:', variables['protocol_version']))
+
+ if 'unix' in cur.connection.host_info.lower():
+ host_info = cur.connection.host_info
+ else:
+ host_info = '{0} via TCP/IP'.format(cur.connection.host)
+
+ output.append(('Connection:', host_info))
+
+ query = ('SELECT @@character_set_server, @@character_set_database, '
+ '@@character_set_client, @@character_set_connection LIMIT 1;')
+ log.debug(query)
+ cur.execute(query)
+ charset = cur.fetchone()
+ output.append(('Server characterset:', charset[0]))
+ output.append(('Db characterset:', charset[1]))
+ output.append(('Client characterset:', charset[2]))
+ output.append(('Conn. characterset:', charset[3]))
+
+ if 'TCP/IP' in host_info:
+ output.append(('TCP port:', cur.connection.port))
+ else:
+ output.append(('UNIX socket:', variables['socket']))
+
+ if 'Uptime' in status:
+ output.append(('Uptime:', format_uptime(status['Uptime'])))
+
+ if 'Threads_connected' in status:
+ # Print the current server statistics.
+ stats = []
+ stats.append('Connections: {0}'.format(status['Threads_connected']))
+ if 'Queries' in status:
+ stats.append('Queries: {0}'.format(status['Queries']))
+ stats.append('Slow queries: {0}'.format(status['Slow_queries']))
+ stats.append('Opens: {0}'.format(status['Opened_tables']))
+ if 'Flush_commands' in status:
+ stats.append('Flush tables: {0}'.format(status['Flush_commands']))
+ stats.append('Open tables: {0}'.format(status['Open_tables']))
+ if 'Queries' in status:
+ queries_per_second = int(status['Queries']) / int(status['Uptime'])
+ stats.append('Queries per second avg: {:.3f}'.format(
+ queries_per_second))
+ stats = ' '.join(stats)
+ footer.append('\n' + stats)
+
+ footer.append('--------------')
+ return [('\n'.join(title), output, '', '\n'.join(footer))]
diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py
new file mode 100644
index 0000000..994b134
--- /dev/null
+++ b/mycli/packages/special/delimitercommand.py
@@ -0,0 +1,80 @@
+import re
+import sqlparse
+
+
+class DelimiterCommand(object):
+ def __init__(self):
+ self._delimiter = ';'
+
+ def _split(self, sql):
+ """Temporary workaround until sqlparse.split() learns about custom
+ delimiters."""
+
+ placeholder = "\ufffc" # unicode object replacement character
+
+ if self._delimiter == ';':
+ return sqlparse.split(sql)
+
+ # We must find a string that original sql does not contain.
+ # Most likely, our placeholder is enough, but if not, keep looking
+ while placeholder in sql:
+ placeholder += placeholder[0]
+ sql = sql.replace(';', placeholder)
+ sql = sql.replace(self._delimiter, ';')
+
+ split = sqlparse.split(sql)
+
+ return [
+ stmt.replace(';', self._delimiter).replace(placeholder, ';')
+ for stmt in split
+ ]
+
+ def queries_iter(self, input):
+ """Iterate over queries in the input string."""
+
+ queries = self._split(input)
+ while queries:
+ for sql in queries:
+ delimiter = self._delimiter
+ sql = queries.pop(0)
+ if sql.endswith(delimiter):
+ trailing_delimiter = True
+ sql = sql.strip(delimiter)
+ else:
+ trailing_delimiter = False
+
+ yield sql
+
+ # if the delimiter was changed by the last command,
+ # re-split everything, and if we previously stripped
+ # the delimiter, append it to the end
+ if self._delimiter != delimiter:
+ combined_statement = ' '.join([sql] + queries)
+ if trailing_delimiter:
+ combined_statement += delimiter
+ queries = self._split(combined_statement)[1:]
+
+ def set(self, arg, **_):
+ """Change delimiter.
+
+ Since `arg` is everything that follows the DELIMITER token
+ after sqlparse (it may include other statements separated by
+ the new delimiter), we want to set the delimiter to the first
+ word of it.
+
+ """
+ match = arg and re.search(r'[^\s]+', arg)
+ if not match:
+ message = 'Missing required argument, delimiter'
+ return [(None, None, None, message)]
+
+ delimiter = match.group()
+ if delimiter.lower() == 'delimiter':
+ return [(None, None, None, 'Invalid delimiter "delimiter"')]
+
+ self._delimiter = delimiter
+ return [(None, None, None, "Changed delimiter to {}".format(delimiter))]
+
+ @property
+ def current(self):
+ return self._delimiter
diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py
new file mode 100644
index 0000000..0b91400
--- /dev/null
+++ b/mycli/packages/special/favoritequeries.py
@@ -0,0 +1,63 @@
+class FavoriteQueries(object):
+
+ section_name = 'favorite_queries'
+
+ usage = '''
+Favorite Queries are a way to save frequently used queries
+with a short name.
+Examples:
+
+ # Save a new favorite query.
+ > \\fs simple select * from abc where a is not Null;
+
+ # List all favorite queries.
+ > \\f
+ ╒════════╤═══════════════════════════════════════╕
+ │ Name │ Query │
+ ╞════════╪═══════════════════════════════════════╡
+ │ simple │ SELECT * FROM abc where a is not NULL │
+ ╘════════╧═══════════════════════════════════════╛
+
+ # Run a favorite query.
+ > \\f simple
+ ╒════════╤════════╕
+ │ a │ b │
+ ╞════════╪════════╡
+ │ 日本語 │ 日本語 │
+ ╘════════╧════════╛
+
+ # Delete a favorite query.
+ > \\fd simple
+ simple: Deleted
+'''
+
+ # Class-level variable, for convenience to use as a singleton.
+ instance = None
+
+ def __init__(self, config):
+ self.config = config
+
+ @classmethod
+ def from_config(cls, config):
+ return FavoriteQueries(config)
+
+ def list(self):
+ return self.config.get(self.section_name, [])
+
+ def get(self, name):
+ return self.config.get(self.section_name, {}).get(name, None)
+
+ def save(self, name, query):
+ self.config.encoding = 'utf-8'
+ if self.section_name not in self.config:
+ self.config[self.section_name] = {}
+ self.config[self.section_name][name] = query
+ self.config.write()
+
+ def delete(self, name):
+ try:
+ del self.config[self.section_name][name]
+ except KeyError:
+ return '%s: Not Found.' % name
+ self.config.write()
+ return '%s: Deleted' % name
diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py
new file mode 100644
index 0000000..01f3c7b
--- /dev/null
+++ b/mycli/packages/special/iocommands.py
@@ -0,0 +1,543 @@
+import os
+import re
+import locale
+import logging
+import subprocess
+import shlex
+from io import open
+from time import sleep
+
+import click
+import pyperclip
+import sqlparse
+
+from . import export
+from .main import special_command, NO_QUERY, PARSED_QUERY
+from .favoritequeries import FavoriteQueries
+from .delimitercommand import DelimiterCommand
+from .utils import handle_cd_command
+from mycli.packages.prompt_utils import confirm_destructive_query
+
+TIMING_ENABLED = False
+use_expanded_output = False
+PAGER_ENABLED = True
+tee_file = None
+once_file = None
+written_to_once_file = False
+pipe_once_process = None
+written_to_pipe_once_process = False
+delimiter_command = DelimiterCommand()
+
+
+@export
+def set_timing_enabled(val):
+ global TIMING_ENABLED
+ TIMING_ENABLED = val
+
+@export
+def set_pager_enabled(val):
+ global PAGER_ENABLED
+ PAGER_ENABLED = val
+
+
+@export
+def is_pager_enabled():
+ return PAGER_ENABLED
+
+@export
+@special_command('pager', '\\P [command]',
+ 'Set PAGER. Print the query results via PAGER.',
+ arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True)
+def set_pager(arg, **_):
+ if arg:
+ os.environ['PAGER'] = arg
+ msg = 'PAGER set to %s.' % arg
+ set_pager_enabled(True)
+ else:
+ if 'PAGER' in os.environ:
+ msg = 'PAGER set to %s.' % os.environ['PAGER']
+ else:
+ # This uses click's default per echo_via_pager.
+ msg = 'Pager enabled.'
+ set_pager_enabled(True)
+
+ return [(None, None, None, msg)]
+
+@export
+@special_command('nopager', '\\n', 'Disable pager, print to stdout.',
+ arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True)
+def disable_pager():
+ set_pager_enabled(False)
+ return [(None, None, None, 'Pager disabled.')]
+
+@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True)
+def toggle_timing():
+ global TIMING_ENABLED
+ TIMING_ENABLED = not TIMING_ENABLED
+ message = "Timing is "
+ message += "on." if TIMING_ENABLED else "off."
+ return [(None, None, None, message)]
+
+@export
+def is_timing_enabled():
+ return TIMING_ENABLED
+
+@export
+def set_expanded_output(val):
+ global use_expanded_output
+ use_expanded_output = val
+
+@export
+def is_expanded_output():
+ return use_expanded_output
+
+_logger = logging.getLogger(__name__)
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command?
+ :param command: string
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+ return command.strip().endswith('\\e') or command.strip().startswith('\\e')
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith('\\e'):
+ command, _, filename = sql.partition(' ')
+ return filename.strip() or None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile(r'(^\\e|\\e$)')
+ while pattern.search(sql):
+ sql = pattern.sub('', sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None):
+ """Open external editor, wait for the user to type in their query, return
+ the query.
+
+ :return: list with one tuple, query as first element.
+
+ """
+
+ message = None
+ filename = filename.strip().split(' ', 1)[0] if filename else None
+
+ sql = sql or ''
+ MARKER = '# Type your query above this line.\n'
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER),
+ filename=filename, extension='.sql')
+
+ if filename:
+ try:
+ with open(filename) as f:
+ query = f.read()
+ except IOError:
+ message = 'Error reading file: %s.' % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip('\n')
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+@export
+def clip_command(command):
+ """Is this a clip command?
+
+ :param command: string
+
+ """
+ # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check
+ # for both conditions.
+ return command.strip().endswith('\\clip') or command.strip().startswith('\\clip')
+
+
+@export
+def get_clip_query(sql):
+ """Get the query part of a clip command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\clip') is that it strips characters,
+ # not a substring. So it'll strip "c" in the end of the sql also!
+ pattern = re.compile(r'(^\\clip|\\clip$)')
+ while pattern.search(sql):
+ sql = pattern.sub('', sql)
+
+ return sql
+
+
+@export
+def copy_query_to_clipboard(sql=None):
+ """Send query to the clipboard."""
+
+ sql = sql or ''
+ message = None
+
+ try:
+ pyperclip.copy(u'{sql}'.format(sql=sql))
+ except RuntimeError as e:
+ message = 'Error clipping query: %s.' % e.strerror
+
+ return message
+
+
+@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True)
+def execute_favorite_query(cur, arg, **_):
+ """Returns (title, rows, headers, status)"""
+ if arg == '':
+ for result in list_favorite_queries():
+ yield result
+
+ """Parse out favorite name and optional substitution parameters"""
+ name, _, arg_str = arg.partition(' ')
+ args = shlex.split(arg_str)
+
+ query = FavoriteQueries.instance.get(name)
+ if query is None:
+ message = "No favorite query: %s" % (name)
+ yield (None, None, None, message)
+ else:
+ query, arg_error = subst_favorite_query_args(query, args)
+ if arg_error:
+ yield (None, None, None, arg_error)
+ else:
+ for sql in sqlparse.split(query):
+ sql = sql.rstrip(';')
+ title = '> %s' % (sql)
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+
+def list_favorite_queries():
+ """List of all favorite queries.
+ Returns (title, rows, headers, status)"""
+
+ headers = ["Name", "Query"]
+ rows = [(r, FavoriteQueries.instance.get(r))
+ for r in FavoriteQueries.instance.list()]
+
+ if not rows:
+ status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage
+ else:
+ status = ''
+ return [('', rows, headers, status)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1...$N) in query."""
+ for idx, val in enumerate(args):
+ subst_var = '$' + str(idx + 1)
+ if subst_var not in query:
+ return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query]
+
+ query = query.replace(subst_var, val)
+
+ match = re.search(r'\$\d+', query)
+ if match:
+ return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query]
+
+ return [query, None]
+
+@special_command('\\fs', '\\fs name query', 'Save a favorite query.')
+def save_favorite_query(arg, **_):
+ """Save a new favorite query.
+ Returns (title, rows, headers, status)"""
+
+ usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ name, _, query = arg.partition(' ')
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None,
+ usage + 'Err: Both name and query are required.')]
+
+ FavoriteQueries.instance.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.')
+def delete_favorite_query(arg, **_):
+ """Delete an existing favorite query."""
+ usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage
+ if not arg:
+ return [(None, None, None, usage)]
+
+ status = FavoriteQueries.instance.delete(arg)
+
+ return [(None, None, None, status)]
+
+
+@special_command('system', 'system [command]',
+ 'Execute a system shell commmand.')
+def execute_system_command(arg, **_):
+ """Execute a system shell command."""
+ usage = "Syntax: system [command].\n"
+
+ if not arg:
+ return [(None, None, None, usage)]
+
+ try:
+ command = arg.strip()
+ if command.startswith('cd'):
+ ok, error_message = handle_cd_command(arg)
+ if not ok:
+ return [(None, None, None, error_message)]
+ return [(None, None, None, '')]
+
+ args = arg.split(' ')
+ process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ output, error = process.communicate()
+ response = output if not error else error
+
+ # Python 3 returns bytes. This needs to be decoded to a string.
+ if isinstance(response, bytes):
+ encoding = locale.getpreferredencoding(False)
+ response = response.decode(encoding)
+
+ return [(None, None, None, response)]
+ except OSError as e:
+ return [(None, None, None, 'OSError: %s' % e.strerror)]
+
+
+def parseargfile(arg):
+ if arg.startswith('-o '):
+ mode = "w"
+ filename = arg[3:]
+ else:
+ mode = 'a'
+ filename = arg
+
+ if not filename:
+ raise TypeError('You must provide a filename.')
+
+ return {'file': os.path.expanduser(filename), 'mode': mode}
+
+
+@special_command('tee', 'tee [-o] filename',
+ 'Append all results to an output file (overwrite using -o).')
+def set_tee(arg, **_):
+ global tee_file
+
+ try:
+ tee_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror))
+
+ return [(None, None, None, "")]
+
+@export
+def close_tee():
+ global tee_file
+ if tee_file:
+ tee_file.close()
+ tee_file = None
+
+
+@special_command('notee', 'notee', 'Stop writing results to an output file.')
+def no_tee(arg, **_):
+ close_tee()
+ return [(None, None, None, "")]
+
+@export
+def write_tee(output):
+ global tee_file
+ if tee_file:
+ click.echo(output, file=tee_file, nl=False)
+ click.echo(u'\n', file=tee_file, nl=False)
+ tee_file.flush()
+
+
+@special_command('\\once', '\\o [-o] filename',
+ 'Append next result to an output file (overwrite using -o).',
+ aliases=('\\o', ))
+def set_once(arg, **_):
+ global once_file, written_to_once_file
+
+ try:
+ once_file = open(**parseargfile(arg))
+ except (IOError, OSError) as e:
+ raise OSError("Cannot write to file '{}': {}".format(
+ e.filename, e.strerror))
+ written_to_once_file = False
+
+ return [(None, None, None, "")]
+
+
+@export
+def write_once(output):
+ global once_file, written_to_once_file
+ if output and once_file:
+ click.echo(output, file=once_file, nl=False)
+ click.echo(u"\n", file=once_file, nl=False)
+ once_file.flush()
+ written_to_once_file = True
+
+
+@export
+def unset_once_if_written():
+ """Unset the once file, if it has been written to."""
+ global once_file, written_to_once_file
+ if written_to_once_file and once_file:
+ once_file.close()
+ once_file = None
+
+
+@special_command('\\pipe_once', '\\| command',
+ 'Send next result to a subprocess.',
+ aliases=('\\|', ))
+def set_pipe_once(arg, **_):
+ global pipe_once_process, written_to_pipe_once_process
+ pipe_once_cmd = shlex.split(arg)
+ if len(pipe_once_cmd) == 0:
+ raise OSError("pipe_once requires a command")
+ written_to_pipe_once_process = False
+ pipe_once_process = subprocess.Popen(pipe_once_cmd,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ bufsize=1,
+ encoding='UTF-8',
+ universal_newlines=True)
+ return [(None, None, None, "")]
+
+
+@export
+def write_pipe_once(output):
+ global pipe_once_process, written_to_pipe_once_process
+ if output and pipe_once_process:
+ try:
+ click.echo(output, file=pipe_once_process.stdin, nl=False)
+ click.echo(u"\n", file=pipe_once_process.stdin, nl=False)
+ except (IOError, OSError) as e:
+ pipe_once_process.terminate()
+ raise OSError(
+ "Failed writing to pipe_once subprocess: {}".format(e.strerror))
+ written_to_pipe_once_process = True
+
+
+@export
+def unset_pipe_once_if_written():
+ """Unset the pipe_once cmd, if it has been written to."""
+ global pipe_once_process, written_to_pipe_once_process
+ if written_to_pipe_once_process:
+ (stdout_data, stderr_data) = pipe_once_process.communicate()
+ if len(stdout_data) > 0:
+ print(stdout_data.rstrip(u"\n"))
+ if len(stderr_data) > 0:
+ print(stderr_data.rstrip(u"\n"))
+ pipe_once_process = None
+ written_to_pipe_once_process = False
+
+
+@special_command(
+ 'watch',
+ 'watch [seconds] [-c] query',
+ 'Executes the query every [seconds] seconds (by default 5).'
+)
+def watch_query(arg, **kwargs):
+ usage = """Syntax: watch [seconds] [-c] query.
+ * seconds: The interval at the query will be repeated, in seconds.
+ By default 5.
+ * -c: Clears the screen between every iteration.
+"""
+ if not arg:
+ yield (None, None, None, usage)
+ return
+ seconds = 5
+ clear_screen = False
+ statement = None
+ while statement is None:
+ arg = arg.strip()
+ if not arg:
+ # Oops, we parsed all the arguments without finding a statement
+ yield (None, None, None, usage)
+ return
+ (current_arg, _, arg) = arg.partition(' ')
+ try:
+ seconds = float(current_arg)
+ continue
+ except ValueError:
+ pass
+ if current_arg == '-c':
+ clear_screen = True
+ continue
+ statement = '{0!s} {1!s}'.format(current_arg, arg)
+ destructive_prompt = confirm_destructive_query(statement)
+ if destructive_prompt is False:
+ click.secho("Wise choice!")
+ return
+ elif destructive_prompt is True:
+ click.secho("Your call!")
+ cur = kwargs['cur']
+ sql_list = [
+ (sql.rstrip(';'), "> {0!s}".format(sql))
+ for sql in sqlparse.split(statement)
+ ]
+ old_pager_enabled = is_pager_enabled()
+ while True:
+ if clear_screen:
+ click.clear()
+ try:
+ # Somewhere in the code the pager its activated after every yield,
+ # so we disable it in every iteration
+ set_pager_enabled(False)
+ for (sql, title) in sql_list:
+ cur.execute(sql)
+ if cur.description:
+ headers = [x[0] for x in cur.description]
+ yield (title, cur, headers, None)
+ else:
+ yield (title, None, None, None)
+ sleep(seconds)
+ except KeyboardInterrupt:
+ # This prints the Ctrl-C character in its own line, which prevents
+ # to print a line with the cursor positioned behind the prompt
+ click.secho("", nl=True)
+ return
+ finally:
+ set_pager_enabled(old_pager_enabled)
+
+
+@export
+@special_command('delimiter', None, 'Change SQL delimiter.')
+def set_delimiter(arg, **_):
+ return delimiter_command.set(arg)
+
+
+@export
+def get_current_delimiter():
+ return delimiter_command.current
+
+
+@export
+def split_queries(input):
+ for query in delimiter_command.queries_iter(input):
+ yield query
diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py
new file mode 100644
index 0000000..ab04f30
--- /dev/null
+++ b/mycli/packages/special/main.py
@@ -0,0 +1,120 @@
+import logging
+from collections import namedtuple
+
+from . import export
+
+log = logging.getLogger(__name__)
+
+NO_QUERY = 0
+PARSED_QUERY = 1
+RAW_QUERY = 2
+
+SpecialCommand = namedtuple('SpecialCommand',
+ ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden',
+ 'case_sensitive'])
+
+COMMANDS = {}
+
+@export
+class CommandNotFound(Exception):
+ pass
+
+@export
+def parse_special_command(sql):
+ command, _, arg = sql.partition(' ')
+ verbose = '+' in command
+ command = command.strip().replace('+', '')
+ return (command, verbose, arg.strip())
+
+@export
+def special_command(command, shortcut, description, arg_type=PARSED_QUERY,
+ hidden=False, case_sensitive=False, aliases=()):
+ def wrapper(wrapped):
+ register_special_command(wrapped, command, shortcut, description,
+ arg_type, hidden, case_sensitive, aliases)
+ return wrapped
+ return wrapper
+
+@export
+def register_special_command(handler, command, shortcut, description,
+ arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()):
+ cmd = command.lower() if not case_sensitive else command
+ COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
+ arg_type, hidden, case_sensitive)
+ for alias in aliases:
+ cmd = alias.lower() if not case_sensitive else alias
+ COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description,
+ arg_type, case_sensitive=case_sensitive,
+ hidden=True)
+
+@export
+def execute(cur, sql):
+ """Execute a special command and return the results. If the special command
+ is not supported a KeyError will be raised.
+ """
+ command, verbose, arg = parse_special_command(sql)
+
+ if (command not in COMMANDS) and (command.lower() not in COMMANDS):
+ raise CommandNotFound
+
+ try:
+ special_cmd = COMMANDS[command]
+ except KeyError:
+ special_cmd = COMMANDS[command.lower()]
+ if special_cmd.case_sensitive:
+ raise CommandNotFound('Command not found: %s' % command)
+
+ # "help <SQL KEYWORD> is a special case. We want built-in help, not
+ # mycli help here.
+ if command == 'help' and arg:
+ return show_keyword_help(cur=cur, arg=arg)
+
+ if special_cmd.arg_type == NO_QUERY:
+ return special_cmd.handler()
+ elif special_cmd.arg_type == PARSED_QUERY:
+ return special_cmd.handler(cur=cur, arg=arg, verbose=verbose)
+ elif special_cmd.arg_type == RAW_QUERY:
+ return special_cmd.handler(cur=cur, query=sql)
+
+@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?'))
+def show_help(): # All the parameters are ignored.
+ headers = ['Command', 'Shortcut', 'Description']
+ result = []
+
+ for _, value in sorted(COMMANDS.items()):
+ if not value.hidden:
+ result.append((value.command, value.shortcut, value.description))
+ return [(None, result, headers, None)]
+
+def show_keyword_help(cur, arg):
+ """
+ Call the built-in "show <command>", to display help for an SQL keyword.
+ :param cur: cursor
+ :param arg: string
+ :return: list
+ """
+ keyword = arg.strip('"').strip("'")
+ query = "help '{0}'".format(keyword)
+ log.debug(query)
+ cur.execute(query)
+ if cur.description and cur.rowcount > 0:
+ headers = [x[0] for x in cur.description]
+ return [(None, cur, headers, '')]
+ else:
+ return [(None, None, None, 'No help found for {0}.'.format(keyword))]
+
+
+@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', ))
+@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY)
+def quit(*_args):
+ raise EOFError
+
+
+@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).',
+ arg_type=NO_QUERY, case_sensitive=True)
+@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.',
+ arg_type=NO_QUERY, case_sensitive=True)
+@special_command('\\G', '\\G', 'Display current query results vertically.',
+ arg_type=NO_QUERY, case_sensitive=True)
+def stub():
+ raise NotImplementedError
diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py
new file mode 100644
index 0000000..ef96093
--- /dev/null
+++ b/mycli/packages/special/utils.py
@@ -0,0 +1,46 @@
+import os
+import subprocess
+
+def handle_cd_command(arg):
+ """Handles a `cd` shell command by calling python's os.chdir."""
+ CD_CMD = 'cd'
+ tokens = arg.split(CD_CMD + ' ')
+ directory = tokens[-1] if len(tokens) > 1 else None
+ if not directory:
+ return False, "No folder name was provided."
+ try:
+ os.chdir(directory)
+ subprocess.call(['pwd'])
+ return True, None
+ except OSError as e:
+ return False, e.strerror
+
+def format_uptime(uptime_in_seconds):
+ """Format number of seconds into human-readable string.
+
+ :param uptime_in_seconds: The server uptime in seconds.
+ :returns: A human-readable string representing the uptime.
+
+ >>> uptime = format_uptime('56892')
+ >>> print(uptime)
+ 15 hours 48 min 12 sec
+ """
+
+ m, s = divmod(int(uptime_in_seconds), 60)
+ h, m = divmod(m, 60)
+ d, h = divmod(h, 24)
+
+ uptime_values = []
+
+ for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')):
+ if value == 0 and not uptime_values:
+ # Don't include a value/unit if the unit isn't applicable to
+ # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec.
+ continue
+ elif value == 1 and unit.endswith('s'):
+ # Remove the "s" if the unit is singular.
+ unit = unit[:-1]
+ uptime_values.append('{0} {1}'.format(value, unit))
+
+ uptime = ' '.join(uptime_values)
+ return uptime
diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/mycli/packages/tabular_output/__init__.py
diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py
new file mode 100644
index 0000000..e6587bd
--- /dev/null
+++ b/mycli/packages/tabular_output/sql_format.py
@@ -0,0 +1,62 @@
+"""Format adapter for sql."""
+
+from mycli.packages.parseutils import extract_tables
+
+supported_formats = ('sql-insert', 'sql-update', 'sql-update-1',
+ 'sql-update-2', )
+
+preprocessors = ()
+
+
+def escape_for_sql_statement(value):
+ if isinstance(value, bytes):
+ return f"X'{value.hex()}'"
+ else:
+ return formatter.mycli.sqlexecute.conn.escape(value)
+
+
+def adapter(data, headers, table_format=None, **kwargs):
+ tables = extract_tables(formatter.query)
+ if len(tables) > 0:
+ table = tables[0]
+ if table[0]:
+ table_name = "{}.{}".format(*table[:2])
+ else:
+ table_name = table[1]
+ else:
+ table_name = "`DUAL`"
+ if table_format == 'sql-insert':
+ h = "`, `".join(headers)
+ yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h)
+ prefix = " "
+ for d in data:
+ values = ", ".join(escape_for_sql_statement(v)
+ for i, v in enumerate(d))
+ yield "{}({})".format(prefix, values)
+ if prefix == " ":
+ prefix = ", "
+ yield ";"
+ if table_format.startswith('sql-update'):
+ s = table_format.split('-')
+ keys = 1
+ if len(s) > 2:
+ keys = int(s[-1])
+ for d in data:
+ yield "UPDATE {} SET".format(table_name)
+ prefix = " "
+ for i, v in enumerate(d[keys:], keys):
+ yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v))
+ if prefix == " ":
+ prefix = ", "
+ f = "`{}` = {}"
+ where = (f.format(headers[i], escape_for_sql_statement(
+ d[i])) for i in range(keys))
+ yield "WHERE {};".format(" AND ".join(where))
+
+
+def register_new_formatter(TabularOutputFormatter):
+ global formatter
+ formatter = TabularOutputFormatter
+ for sql_format in supported_formats:
+ TabularOutputFormatter.register_new_formatter(
+ sql_format, adapter, preprocessors, {'table_format': sql_format})
diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py
new file mode 100644
index 0000000..3656aa6
--- /dev/null
+++ b/mycli/sqlcompleter.py
@@ -0,0 +1,435 @@
+import logging
+from re import compile, escape
+from collections import Counter
+
+from prompt_toolkit.completion import Completer, Completion
+
+from .packages.completion_engine import suggest_type
+from .packages.parseutils import last_word
+from .packages.filepaths import parse_path, complete_path, suggest_path
+from .packages.special.favoritequeries import FavoriteQueries
+
+_logger = logging.getLogger(__name__)
+
+
+class SQLCompleter(Completer):
+ keywords = ['ACCESS', 'ADD', 'ALL', 'ALTER TABLE', 'AND', 'ANY', 'AS',
+ 'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN',
+ 'BIGINT', 'BINARY', 'BY', 'CASE', 'CHANGE MASTER TO', 'CHAR',
+ 'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT',
+ 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT',
+ 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT',
+ 'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP',
+ 'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT',
+ 'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION',
+ 'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN',
+ 'INCREMENT', 'INDEX', 'INSERT INTO', 'INT', 'INTEGER',
+ 'INTERVAL', 'INTO', 'IS', 'JOIN', 'KEY', 'LEFT', 'LEVEL',
+ 'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER',
+ 'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER',
+ 'OFFSET', 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER',
+ 'PASSWORD', 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST',
+ 'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET',
+ 'REVOKE', 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT',
+ 'SAVEPOINT', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SHOW',
+ 'SLAVE', 'SMALLINT', 'SMALLINT', 'START', 'STOP', 'TABLE',
+ 'THEN', 'TINYINT', 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE',
+ 'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', 'USE', 'USER',
+ 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', 'WITH']
+
+ functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT',
+ 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID',
+ 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', 'UNIX_TIMESTAMP']
+
+ show_items = []
+
+ change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER',
+ 'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY',
+ 'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE',
+ 'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS',
+ 'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH',
+ 'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER',
+ 'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS']
+
+ users = []
+
+ def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'):
+ super(self.__class__, self).__init__()
+ self.smart_completion = smart_completion
+ self.reserved_words = set()
+ for x in self.keywords:
+ self.reserved_words.update(x.split())
+ self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$")
+
+ self.special_commands = []
+ self.table_formats = supported_formats
+ if keyword_casing not in ('upper', 'lower', 'auto'):
+ keyword_casing = 'auto'
+ self.keyword_casing = keyword_casing
+ self.reset_completions()
+
+ def escape_name(self, name):
+ if name and ((not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)):
+ name = '`%s`' % name
+
+ return name
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_special_commands(self, special_commands):
+ # Special commands are not part of all_completions since they can only
+ # be at the beginning of a line.
+ self.special_commands.extend(special_commands)
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_show_items(self, show_items):
+ for show_item in show_items:
+ self.show_items.extend(show_item)
+ self.all_completions.update(show_item)
+
+ def extend_change_items(self, change_items):
+ for change_item in change_items:
+ self.change_items.extend(change_item)
+ self.all_completions.update(change_item)
+
+ def extend_users(self, users):
+ for user in users:
+ self.users.extend(user)
+ self.all_completions.update(user)
+
+ def extend_schemata(self, schema):
+ if schema is None:
+ return
+ metadata = self.dbmetadata['tables']
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ metadata[schema] = {}
+ self.all_completions.update(schema)
+
+ def extend_relations(self, data, kind):
+ """Extend metadata for tables or views
+
+ :param data: list of (rel_name, ) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'data' is a generator object. It can throw an exception while being
+ # consumed. This could happen if the user has launched the app without
+ # specifying a database name. This exception must be handled to prevent
+ # crashing.
+ try:
+ data = [self.escaped_names(d) for d in data]
+ except Exception:
+ data = []
+
+ # dbmetadata['tables'][$schema_name][$table_name] should be a list of
+ # column names. Default to an asterisk
+ metadata = self.dbmetadata[kind]
+ for relname in data:
+ try:
+ metadata[self.dbname][relname[0]] = ['*']
+ except KeyError:
+ _logger.error('%r %r listed in unrecognized schema %r',
+ kind, relname[0], self.dbname)
+ self.all_completions.add(relname[0])
+
+ def extend_columns(self, column_data, kind):
+ """Extend column metadata
+
+ :param column_data: list of (rel_name, column_name) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'column_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ column_data = [self.escaped_names(d) for d in column_data]
+ except Exception:
+ column_data = []
+
+ metadata = self.dbmetadata[kind]
+ for relname, column in column_data:
+ metadata[self.dbname][relname].append(column)
+ self.all_completions.add(column)
+
+ def extend_functions(self, func_data):
+ # 'func_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ func_data = [self.escaped_names(d) for d in func_data]
+ except Exception:
+ func_data = []
+
+ # dbmetadata['functions'][$schema_name][$function_name] should return
+ # function metadata.
+ metadata = self.dbmetadata['functions']
+
+ for func in func_data:
+ metadata[self.dbname][func[0]] = None
+ self.all_completions.add(func[0])
+
+ def set_dbname(self, dbname):
+ self.dbname = dbname
+
+ def reset_completions(self):
+ self.databases = []
+ self.users = []
+ self.show_items = []
+ self.dbname = ''
+ self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ @staticmethod
+ def find_matches(text, collection, start_only=False, fuzzy=True, casing=None):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ If `start_only` is True, the text will match an available
+ completion only at the beginning. Otherwise, a completion is
+ considered a match if the text appears anywhere within it.
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+ """
+ last = last_word(text, include='most_punctuations')
+ text = last.lower()
+
+ completions = []
+
+ if fuzzy:
+ regex = '.*?'.join(map(escape, text))
+ pat = compile('(%s)' % regex)
+ for item in sorted(collection):
+ r = pat.search(item.lower())
+ if r:
+ completions.append((len(r.group()), r.start(), item))
+ else:
+ match_end_limit = len(text) if start_only else None
+ for item in sorted(collection):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ completions.append((len(text), match_point, item))
+
+ if casing == 'auto':
+ casing = 'lower' if last and last[-1].islower() else 'upper'
+
+ def apply_case(kw):
+ if casing == 'upper':
+ return kw.upper()
+ return kw.lower()
+
+ return (Completion(z if casing is None else apply_case(z), -len(text))
+ for x, y, z in sorted(completions))
+
+ def get_completions(self, document, complete_event, smart_completion=None):
+ word_before_cursor = document.get_word_before_cursor(WORD=True)
+ if smart_completion is None:
+ smart_completion = self.smart_completion
+
+ # If smart_completion is off then match any word that starts with
+ # 'word_before_cursor'.
+ if not smart_completion:
+ return self.find_matches(word_before_cursor, self.all_completions,
+ start_only=True, fuzzy=False)
+
+ completions = []
+ suggestions = suggest_type(document.text, document.text_before_cursor)
+
+ for suggestion in suggestions:
+
+ _logger.debug('Suggestion type: %r', suggestion['type'])
+
+ if suggestion['type'] == 'column':
+ tables = suggestion['tables']
+ _logger.debug("Completion column scope: %r", tables)
+ scoped_cols = self.populate_scoped_cols(tables)
+ if suggestion.get('drop_unique'):
+ # drop_unique is used for 'tb11 JOIN tbl2 USING (...'
+ # which should suggest only columns that appear in more than
+ # one table
+ scoped_cols = [
+ col for (col, count) in Counter(scoped_cols).items()
+ if count > 1 and col != '*'
+ ]
+
+ cols = self.find_matches(word_before_cursor, scoped_cols)
+ completions.extend(cols)
+
+ elif suggestion['type'] == 'function':
+ # suggest user-defined functions using substring matching
+ funcs = self.populate_schema_objects(suggestion['schema'],
+ 'functions')
+ user_funcs = self.find_matches(word_before_cursor, funcs)
+ completions.extend(user_funcs)
+
+ # suggest hardcoded functions using startswith matching only if
+ # there is no schema qualifier. If a schema qualifier is
+ # present it probably denotes a table.
+ # eg: SELECT * FROM users u WHERE u.
+ if not suggestion['schema']:
+ predefined_funcs = self.find_matches(word_before_cursor,
+ self.functions,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing)
+ completions.extend(predefined_funcs)
+
+ elif suggestion['type'] == 'table':
+ tables = self.populate_schema_objects(suggestion['schema'],
+ 'tables')
+ tables = self.find_matches(word_before_cursor, tables)
+ completions.extend(tables)
+
+ elif suggestion['type'] == 'view':
+ views = self.populate_schema_objects(suggestion['schema'],
+ 'views')
+ views = self.find_matches(word_before_cursor, views)
+ completions.extend(views)
+
+ elif suggestion['type'] == 'alias':
+ aliases = suggestion['aliases']
+ aliases = self.find_matches(word_before_cursor, aliases)
+ completions.extend(aliases)
+
+ elif suggestion['type'] == 'database':
+ dbs = self.find_matches(word_before_cursor, self.databases)
+ completions.extend(dbs)
+
+ elif suggestion['type'] == 'keyword':
+ keywords = self.find_matches(word_before_cursor, self.keywords,
+ start_only=True,
+ fuzzy=False,
+ casing=self.keyword_casing)
+ completions.extend(keywords)
+
+ elif suggestion['type'] == 'show':
+ show_items = self.find_matches(word_before_cursor,
+ self.show_items,
+ start_only=False,
+ fuzzy=True,
+ casing=self.keyword_casing)
+ completions.extend(show_items)
+
+ elif suggestion['type'] == 'change':
+ change_items = self.find_matches(word_before_cursor,
+ self.change_items,
+ start_only=False,
+ fuzzy=True)
+ completions.extend(change_items)
+ elif suggestion['type'] == 'user':
+ users = self.find_matches(word_before_cursor, self.users,
+ start_only=False,
+ fuzzy=True)
+ completions.extend(users)
+
+ elif suggestion['type'] == 'special':
+ special = self.find_matches(word_before_cursor,
+ self.special_commands,
+ start_only=True,
+ fuzzy=False)
+ completions.extend(special)
+ elif suggestion['type'] == 'favoritequery':
+ queries = self.find_matches(word_before_cursor,
+ FavoriteQueries.instance.list(),
+ start_only=False, fuzzy=True)
+ completions.extend(queries)
+ elif suggestion['type'] == 'table_format':
+ formats = self.find_matches(word_before_cursor,
+ self.table_formats,
+ start_only=True, fuzzy=False)
+ completions.extend(formats)
+ elif suggestion['type'] == 'file_name':
+ file_names = self.find_files(word_before_cursor)
+ completions.extend(file_names)
+
+ return completions
+
+ def find_files(self, word):
+ """Yield matching directory or file names.
+
+ :param word:
+ :return: iterable
+
+ """
+ base_path, last_path, position = parse_path(word)
+ paths = suggest_path(word)
+ for name in sorted(paths):
+ suggestion = complete_path(name, last_path)
+ if suggestion:
+ yield Completion(suggestion, position)
+
+ def populate_scoped_cols(self, scoped_tbls):
+ """Find all columns in a set of scoped_tables
+ :param scoped_tbls: list of (schema, table, alias) tuples
+ :return: list of column names
+ """
+ columns = []
+ meta = self.dbmetadata
+
+ for tbl in scoped_tbls:
+ # A fully qualified schema.relname reference or default_schema
+ # DO NOT escape schema names.
+ schema = tbl[0] or self.dbname
+ relname = tbl[1]
+ escaped_relname = self.escape_name(tbl[1])
+
+ # We don't know if schema.relname is a table or view. Since
+ # tables and views cannot share the same name, we can check one
+ # at a time
+ try:
+ columns.extend(meta['tables'][schema][relname])
+
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ try:
+ columns.extend(meta['tables'][schema][escaped_relname])
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ pass
+
+ try:
+ columns.extend(meta['views'][schema][relname])
+ except KeyError:
+ pass
+
+ return columns
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns list of tables or functions for a (optional) schema"""
+ metadata = self.dbmetadata[obj_type]
+ schema = schema or self.dbname
+
+ try:
+ objects = metadata[schema].keys()
+ except KeyError:
+ # schema doesn't exist
+ objects = []
+
+ return objects
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
new file mode 100644
index 0000000..c019707
--- /dev/null
+++ b/mycli/sqlexecute.py
@@ -0,0 +1,356 @@
+import enum
+import logging
+import re
+
+import pymysql
+from .packages import special
+from pymysql.constants import FIELD_TYPE
+from pymysql.converters import (convert_datetime,
+ convert_timedelta, convert_date, conversions,
+ decoders)
+try:
+ import paramiko
+except ImportError:
+ from mycli.packages.paramiko_stub import paramiko
+
+_logger = logging.getLogger(__name__)
+
+FIELD_TYPES = decoders.copy()
+FIELD_TYPES.update({
+ FIELD_TYPE.NULL: type(None)
+})
+
+
+ERROR_CODE_ACCESS_DENIED = 1045
+
+
+class ServerSpecies(enum.Enum):
+ MySQL = 'MySQL'
+ MariaDB = 'MariaDB'
+ Percona = 'Percona'
+ TiDB = 'TiDB'
+ Unknown = 'MySQL'
+
+
+class ServerInfo:
+ def __init__(self, species, version_str):
+ self.species = species
+ self.version_str = version_str
+ self.version = self.calc_mysql_version_value(version_str)
+
+ @staticmethod
+ def calc_mysql_version_value(version_str) -> int:
+ if not version_str or not isinstance(version_str, str):
+ return 0
+ try:
+ major, minor, patch = version_str.split('.')
+ except ValueError:
+ return 0
+ else:
+ return int(major) * 10_000 + int(minor) * 100 + int(patch)
+
+ @classmethod
+ def from_version_string(cls, version_string):
+ if not version_string:
+ return cls(ServerSpecies.Unknown, '')
+
+ re_species = (
+ (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-TiDB', ServerSpecies.TiDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
+ ServerSpecies.Percona),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
+ ServerSpecies.MySQL),
+ )
+ for regexp, species in re_species:
+ match = re.search(regexp, version_string)
+ if match is not None:
+ parsed_version = match.group('version')
+ detected_species = species
+ break
+ else:
+ detected_species = ServerSpecies.Unknown
+ parsed_version = ''
+
+ return cls(detected_species, parsed_version)
+
+ def __str__(self):
+ if self.species:
+ return f'{self.species.value} {self.version_str}'
+ else:
+ return self.version_str
+
+
+class SQLExecute(object):
+
+ databases_query = '''SHOW DATABASES'''
+
+ tables_query = '''SHOW TABLES'''
+
+ show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
+
+ users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
+
+ functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
+ WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
+
+ table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
+ where table_schema = '%s'
+ order by table_name,ordinal_position'''
+
+ def __init__(self, database, user, password, host, port, socket, charset,
+ local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
+ ssh_key_filename, init_command=None):
+ self.dbname = database
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ self.socket = socket
+ self.charset = charset
+ self.local_infile = local_infile
+ self.ssl = ssl
+ self.server_info = None
+ self.connection_id = None
+ self.ssh_user = ssh_user
+ self.ssh_host = ssh_host
+ self.ssh_port = ssh_port
+ self.ssh_password = ssh_password
+ self.ssh_key_filename = ssh_key_filename
+ self.init_command = init_command
+ self.connect()
+
+ def connect(self, database=None, user=None, password=None, host=None,
+ port=None, socket=None, charset=None, local_infile=None,
+ ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
+ ssh_password=None, ssh_key_filename=None, init_command=None):
+ db = (database or self.dbname)
+ user = (user or self.user)
+ password = (password or self.password)
+ host = (host or self.host)
+ port = (port or self.port)
+ socket = (socket or self.socket)
+ charset = (charset or self.charset)
+ local_infile = (local_infile or self.local_infile)
+ ssl = (ssl or self.ssl)
+ ssh_user = (ssh_user or self.ssh_user)
+ ssh_host = (ssh_host or self.ssh_host)
+ ssh_port = (ssh_port or self.ssh_port)
+ ssh_password = (ssh_password or self.ssh_password)
+ ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
+ init_command = (init_command or self.init_command)
+ _logger.debug(
+ 'Connection DB Params: \n'
+ '\tdatabase: %r'
+ '\tuser: %r'
+ '\thost: %r'
+ '\tport: %r'
+ '\tsocket: %r'
+ '\tcharset: %r'
+ '\tlocal_infile: %r'
+ '\tssl: %r'
+ '\tssh_user: %r'
+ '\tssh_host: %r'
+ '\tssh_port: %r'
+ '\tssh_password: %r'
+ '\tssh_key_filename: %r'
+ '\tinit_command: %r',
+ db, user, host, port, socket, charset, local_infile, ssl,
+ ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
+ init_command
+ )
+ conv = conversions.copy()
+ conv.update({
+ FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
+ FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
+ FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
+ FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
+ })
+
+ defer_connect = False
+
+ if ssh_host:
+ defer_connect = True
+
+ client_flag = pymysql.constants.CLIENT.INTERACTIVE
+ if init_command and len(list(special.split_queries(init_command))) > 1:
+ client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
+
+ conn = pymysql.connect(
+ database=db, user=user, password=password, host=host, port=port,
+ unix_socket=socket, use_unicode=True, charset=charset,
+ autocommit=True, client_flag=client_flag,
+ local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
+ defer_connect=defer_connect, init_command=init_command
+ )
+
+ if ssh_host:
+ client = paramiko.SSHClient()
+ client.load_system_host_keys()
+ client.set_missing_host_key_policy(paramiko.WarningPolicy())
+ client.connect(
+ ssh_host, ssh_port, ssh_user, ssh_password,
+ key_filename=ssh_key_filename
+ )
+ chan = client.get_transport().open_channel(
+ 'direct-tcpip',
+ (host, port),
+ ('0.0.0.0', 0),
+ )
+ conn.connect(chan)
+
+ if hasattr(self, 'conn'):
+ self.conn.close()
+ self.conn = conn
+ # Update them after the connection is made to ensure that it was a
+ # successful connection.
+ self.dbname = db
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ self.socket = socket
+ self.charset = charset
+ self.ssl = ssl
+ self.init_command = init_command
+ # retrieve connection id
+ self.reset_connection_id()
+ self.server_info = ServerInfo.from_version_string(conn.server_version)
+
+ def run(self, statement):
+ """Execute the sql in the database and return the results. The results
+ are a list of tuples. Each tuple has 4 values
+ (title, rows, headers, status).
+ """
+
+ # Remove spaces and EOL
+ statement = statement.strip()
+ if not statement: # Empty string
+ yield (None, None, None, None)
+
+ # Split the sql into separate queries and run each one.
+ # Unless it's saving a favorite query, in which case we
+ # want to save them all together.
+ if statement.startswith('\\fs'):
+ components = [statement]
+ else:
+ components = special.split_queries(statement)
+
+ for sql in components:
+ # \G is treated specially since we have to set the expanded output.
+ if sql.endswith('\\G'):
+ special.set_expanded_output(True)
+ sql = sql[:-2].strip()
+
+ cur = self.conn.cursor()
+ try: # Special command
+ _logger.debug('Trying a dbspecial command. sql: %r', sql)
+ for result in special.execute(cur, sql):
+ yield result
+ except special.CommandNotFound: # Regular SQL
+ _logger.debug('Regular sql statement. sql: %r', sql)
+ cur.execute(sql)
+ while True:
+ yield self.get_result(cur)
+
+ # PyMySQL returns an extra, empty result set with stored
+ # procedures. We skip it (rowcount is zero and no
+ # description).
+ if not cur.nextset() or (not cur.rowcount and cur.description is None):
+ break
+
+ def get_result(self, cursor):
+ """Get the current result's data from the cursor."""
+ title = headers = None
+
+ # cursor.description is not None for queries that return result sets,
+ # e.g. SELECT or SHOW.
+ if cursor.description is not None:
+ headers = [x[0] for x in cursor.description]
+ status = '{0} row{1} in set'
+ else:
+ _logger.debug('No rows in result.')
+ status = 'Query OK, {0} row{1} affected'
+ status = status.format(cursor.rowcount,
+ '' if cursor.rowcount == 1 else 's')
+
+ return (title, cursor if cursor.description else None, headers, status)
+
+ def tables(self):
+ """Yields table names"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Tables Query. sql: %r', self.tables_query)
+ cur.execute(self.tables_query)
+ for row in cur:
+ yield row
+
+ def table_columns(self):
+ """Yields (table name, column name) pairs"""
+ with self.conn.cursor() as cur:
+ _logger.debug('Columns Query. sql: %r', self.table_columns_query)
+ cur.execute(self.table_columns_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def databases(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Databases Query. sql: %r', self.databases_query)
+ cur.execute(self.databases_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def functions(self):
+ """Yields tuples of (schema_name, function_name)"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Functions Query. sql: %r', self.functions_query)
+ cur.execute(self.functions_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def show_candidates(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Show Query. sql: %r', self.show_candidates_query)
+ try:
+ cur.execute(self.show_candidates_query)
+ except pymysql.DatabaseError as e:
+ _logger.error('No show completions due to %r', e)
+ yield ''
+ else:
+ for row in cur:
+ yield (row[0].split(None, 1)[-1], )
+
+ def users(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Users Query. sql: %r', self.users_query)
+ try:
+ cur.execute(self.users_query)
+ except pymysql.DatabaseError as e:
+ _logger.error('No user completions due to %r', e)
+ yield ''
+ else:
+ for row in cur:
+ yield row
+
+ def get_connection_id(self):
+ if not self.connection_id:
+ self.reset_connection_id()
+ return self.connection_id
+
+ def reset_connection_id(self):
+ # Remember current connection id
+ _logger.debug('Get current connection id')
+ try:
+ res = self.run('select connection_id()')
+ for title, cur, headers, status in res:
+ self.connection_id = cur.fetchone()[0]
+ except Exception as e:
+ # See #1054
+ self.connection_id = -1
+ _logger.error('Failed to get connection id: %s', e)
+ else:
+ _logger.debug('Current connection id: %s', self.connection_id)
+
+ def change_db(self, db):
+ self.conn.select_db(db)
+ self.dbname = db
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..5422131
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --ignore=mycli/packages/paramiko_stub/__init__.py
diff --git a/release.py b/release.py
new file mode 100755
index 0000000..62daa80
--- /dev/null
+++ b/release.py
@@ -0,0 +1,119 @@
+"""A script to publish a release of mycli to PyPI."""
+
+from optparse import OptionParser
+import re
+import subprocess
+import sys
+
+import click
+
+DEBUG = False
+CONFIRM_STEPS = False
+DRY_RUN = False
+
+
+def skip_step():
+ """
+ Asks for user's response whether to run a step. Default is yes.
+ :return: boolean
+ """
+ global CONFIRM_STEPS
+
+ if CONFIRM_STEPS:
+ return not click.confirm('--- Run this step?', default=True)
+ return False
+
+
+def run_step(*args):
+ """
+ Prints out the command and asks if it should be run.
+ If yes (default), runs it.
+ :param args: list of strings (command and args)
+ """
+ global DRY_RUN
+
+ cmd = args
+ print(' '.join(cmd))
+ if skip_step():
+ print('--- Skipping...')
+ elif DRY_RUN:
+ print('--- Pretending to run...')
+ else:
+ subprocess.check_output(cmd)
+
+
+def version(version_file):
+ _version_re = re.compile(
+ r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)')
+
+ with open(version_file) as f:
+ ver = _version_re.search(f.read()).group('version')
+
+ return ver
+
+
+def commit_for_release(version_file, ver):
+ run_step('git', 'reset')
+ run_step('git', 'add', version_file)
+ run_step('git', 'commit', '--message',
+ 'Releasing version {}'.format(ver))
+
+
+def create_git_tag(tag_name):
+ run_step('git', 'tag', tag_name)
+
+
+def create_distribution_files():
+ run_step('python', 'setup.py', 'sdist', 'bdist_wheel')
+
+
+def upload_distribution_files():
+ run_step('twine', 'upload', 'dist/*')
+
+
+def push_to_github():
+ run_step('git', 'push', 'origin', 'main')
+
+
+def push_tags_to_github():
+ run_step('git', 'push', '--tags', 'origin')
+
+
+def checklist(questions):
+ for question in questions:
+ if not click.confirm('--- {}'.format(question), default=False):
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ if DEBUG:
+ subprocess.check_output = lambda x: x
+
+ ver = version('mycli/__init__.py')
+
+ parser = OptionParser()
+ parser.add_option(
+ "-c", "--confirm-steps", action="store_true", dest="confirm_steps",
+ default=False, help=("Confirm every step. If the step is not "
+ "confirmed, it will be skipped.")
+ )
+ parser.add_option(
+ "-d", "--dry-run", action="store_true", dest="dry_run",
+ default=False, help="Print out, but not actually run any steps."
+ )
+
+ popts, pargs = parser.parse_args()
+ CONFIRM_STEPS = popts.confirm_steps
+ DRY_RUN = popts.dry_run
+
+ print('Releasing Version:', ver)
+
+ if not click.confirm('Are you sure?', default=False):
+ sys.exit(1)
+
+ commit_for_release('mycli/__init__.py', ver)
+ create_git_tag('v{}'.format(ver))
+ create_distribution_files()
+ push_to_github()
+ push_tags_to_github()
+ upload_distribution_files()
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..955a9f5
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,17 @@
+pytest>=3.3.0
+pytest-cov>=2.4.0
+tox
+twine>=1.12.1
+behave>=1.2.4
+pexpect>=3.3
+coverage>=5.0.4
+codecov>=2.0.9
+autopep8==1.3.3
+colorama>=0.4.1
+git+https://github.com/hayd/pep8radius.git # --error-status option not released
+click>=7.0
+paramiko==2.11.0
+pyperclip>=1.8.1
+importlib_resources>=5.0.0
+pyaes>=1.6.1
+sqlglot>=5.1.3
diff --git a/screenshots/main.gif b/screenshots/main.gif
new file mode 100644
index 0000000..8973195
--- /dev/null
+++ b/screenshots/main.gif
Binary files differ
diff --git a/screenshots/tables.png b/screenshots/tables.png
new file mode 100644
index 0000000..1d6afcf
--- /dev/null
+++ b/screenshots/tables.png
Binary files differ
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..e533c7b
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,18 @@
+[bdist_wheel]
+universal = 1
+
+[tool:pytest]
+addopts = --capture=sys
+ --showlocals
+ --doctest-modules
+ --doctest-ignore-import-errors
+ --ignore=setup.py
+ --ignore=mycli/magic.py
+ --ignore=mycli/packages/parseutils.py
+ --ignore=test/features
+
+[pep8]
+rev = master
+docformatter = True
+diff = True
+error-status = True
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000..2f69672
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python
+
+import ast
+import re
+import subprocess
+import sys
+
+from setuptools import Command, find_packages, setup
+from setuptools.command.test import test as TestCommand
+
+_version_re = re.compile(r'__version__\s+=\s+(.*)')
+
+with open('mycli/__init__.py') as f:
+ version = ast.literal_eval(_version_re.search(
+ f.read()).group(1))
+
+description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.'
+
+install_requirements = [
+ 'click >= 7.0',
+ # Temporary to suppress paramiko Blowfish warning which breaks CI.
+ # Pinning cryptography should not be needed after paramiko 2.11.0.
+ 'cryptography == 36.0.2',
+ # 'Pygments>=1.6,<=2.11.1',
+ 'Pygments>=1.6',
+ 'prompt_toolkit>=3.0.6,<4.0.0',
+ 'PyMySQL >= 0.9.2',
+ 'sqlparse>=0.3.0,<0.5.0',
+ 'sqlglot>=5.1.3',
+ 'configobj >= 5.0.5',
+ 'cli_helpers[styles] >= 2.2.1',
+ 'pyperclip >= 1.8.1',
+ 'pyaes >= 1.6.1'
+]
+
+if sys.version_info.minor < 9:
+ install_requirements.append('importlib_resources >= 5.0.0')
+
+
+class lint(Command):
+ description = 'check code against PEP 8 (and fix violations)'
+
+ user_options = [
+ ('branch=', 'b', 'branch/revision to compare against (e.g. main)'),
+ ('fix', 'f', 'fix the violations in place'),
+ ('error-status', 'e', 'return an error code on failed PEP check'),
+ ]
+
+ def initialize_options(self):
+ """Set the default options."""
+ self.branch = 'main'
+ self.fix = False
+ self.error_status = True
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ cmd = 'pep8radius {}'.format(self.branch)
+ if self.fix:
+ cmd += ' --in-place'
+ if self.error_status:
+ cmd += ' --error-status'
+ sys.exit(subprocess.call(cmd, shell=True))
+
+
+class test(TestCommand):
+
+ user_options = [
+ ('pytest-args=', 'a', 'Arguments to pass to pytest'),
+ ('behave-args=', 'b', 'Arguments to pass to pytest')
+ ]
+
+ def initialize_options(self):
+ TestCommand.initialize_options(self)
+ self.pytest_args = ''
+ self.behave_args = '--no-capture'
+
+ def run_tests(self):
+ unit_test_errno = subprocess.call(
+ 'pytest test/ ' + self.pytest_args,
+ shell=True
+ )
+ cli_errno = subprocess.call(
+ 'behave test/features ' + self.behave_args,
+ shell=True
+ )
+ subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False)
+ sys.exit(unit_test_errno or cli_errno)
+
+
+setup(
+ name='mycli',
+ author='Mycli Core Team',
+ author_email='mycli-dev@googlegroups.com',
+ version=version,
+ url='http://mycli.net',
+ packages=find_packages(),
+ package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']},
+ description=description,
+ long_description=description,
+ install_requires=install_requirements,
+ entry_points={
+ 'console_scripts': ['mycli = mycli.main:cli'],
+ },
+ cmdclass={'lint': lint, 'test': test},
+ python_requires=">=3.7",
+ classifiers=[
+ 'Intended Audience :: Developers',
+ 'License :: OSI Approved :: BSD License',
+ 'Operating System :: Unix',
+ 'Programming Language :: Python',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
+ 'Programming Language :: SQL',
+ 'Topic :: Database',
+ 'Topic :: Database :: Front-Ends',
+ 'Topic :: Software Development',
+ 'Topic :: Software Development :: Libraries :: Python Modules',
+ ],
+ extras_require={
+ 'ssh': ['paramiko'],
+ },
+)
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/__init__.py
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..1325596
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,29 @@
+import pytest
+from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db,
+ db_connection, SSH_USER, SSH_HOST, SSH_PORT)
+import mycli.sqlexecute
+
+
+@pytest.fixture(scope="function")
+def connection():
+ create_db('mycli_test_db')
+ connection = db_connection('mycli_test_db')
+ yield connection
+
+ connection.close()
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return mycli.sqlexecute.SQLExecute(
+ database='mycli_test_db', user=USER,
+ host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET,
+ local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST,
+ ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None
+ )
diff --git a/test/features/__init__.py b/test/features/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/features/__init__.py
diff --git a/test/features/auto_vertical.feature b/test/features/auto_vertical.feature
new file mode 100644
index 0000000..aa95718
--- /dev/null
+++ b/test/features/auto_vertical.feature
@@ -0,0 +1,12 @@
+Feature: auto_vertical mode:
+ on, off
+
+ Scenario: auto_vertical on with small query
+ When we run dbcli with --auto-vertical-output
+ and we execute a small query
+ then we see small results in horizontal format
+
+ Scenario: auto_vertical on with large query
+ When we run dbcli with --auto-vertical-output
+ and we execute a large query
+ then we see large results in vertical format
diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature
new file mode 100644
index 0000000..a12e899
--- /dev/null
+++ b/test/features/basic_commands.feature
@@ -0,0 +1,19 @@
+Feature: run the cli,
+ call the help command,
+ exit the cli
+
+ Scenario: run "\?" command
+ When we send "\?" command
+ then we see help output
+
+ Scenario: run source command
+ When we send source command
+ then we see help output
+
+ Scenario: check our application_name
+ When we run query to check application_name
+ then we see found
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
diff --git a/test/features/connection.feature b/test/features/connection.feature
new file mode 100644
index 0000000..b06935e
--- /dev/null
+++ b/test/features/connection.feature
@@ -0,0 +1,35 @@
+Feature: connect to a database:
+
+ @requires_local_db
+ Scenario: run mycli on localhost without port
+ When we run mycli with arguments "host=localhost" without arguments "port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli on TCP host without port
+ When we run mycli without arguments "port"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ Scenario: run mycli with port but without host
+ When we run mycli without arguments "host"
+ When we query "status"
+ Then status contains "via TCP/IP"
+
+ @requires_local_db
+ Scenario: run mycli without host and port
+ When we run mycli without arguments "host port"
+ When we query "status"
+ Then status contains "via UNIX socket"
+
+ Scenario: run mycli with my.cnf configuration
+ When we create my.cnf file
+ When we run mycli without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+ Scenario: run mycli with mylogin.cnf configuration
+ When we create mylogin.cnf file
+ When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file"
+ Then we are logged in
+
+
diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature
new file mode 100644
index 0000000..f4a7a7f
--- /dev/null
+++ b/test/features/crud_database.feature
@@ -0,0 +1,30 @@
+Feature: manipulate databases:
+ create, drop, connect, disconnect
+
+ Scenario: create and drop temporary database
+ When we create database
+ then we see database created
+ when we drop database
+ then we confirm the destructive warning
+ then we see database dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from test database
+ When we connect to test database
+ then we see database connected
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from quoted test database
+ When we connect to quoted test database
+ then we see database connected
+
+ Scenario: create and drop default database
+ When we create database
+ then we see database created
+ when we connect to tmp database
+ then we see database connected
+ when we drop database
+ then we confirm the destructive warning
+ then we see database dropped and no default database
diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature
new file mode 100644
index 0000000..3384efd
--- /dev/null
+++ b/test/features/crud_table.feature
@@ -0,0 +1,49 @@
+Feature: manipulate tables:
+ create, insert, update, select, delete from, drop
+
+ Scenario: create, insert, select from, update, drop table
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we insert into table
+ then we see record inserted
+ when we update table
+ then we see record updated
+ when we select from table
+ then we see data selected
+ when we delete from table
+ then we confirm the destructive warning
+ then we see record deleted
+ when we drop table
+ then we confirm the destructive warning
+ then we see table dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: select null values
+ When we connect to test database
+ then we see database connected
+ when we select null
+ then we see null selected
+
+ Scenario: confirm destructive query
+ When we query "create table foo(x integer);"
+ and we query "delete from foo;"
+ and we answer the destructive warning with "y"
+ then we see text "Your call!"
+
+ Scenario: decline destructive query
+ When we query "delete from foo;"
+ and we answer the destructive warning with "n"
+ then we see text "Wise choice!"
+
+ Scenario: no destructive warning if disabled in config
+ When we run dbcli with --no-warn
+ and we query "create table blabla(x integer);"
+ and we query "delete from blabla;"
+ Then we see text "Query OK"
+
+ Scenario: confirm destructive query with invalid response
+ When we query "delete from foo;"
+ then we answer the destructive warning with invalid "1" and see text "is not a valid boolean"
diff --git a/test/features/db_utils.py b/test/features/db_utils.py
new file mode 100644
index 0000000..be550e9
--- /dev/null
+++ b/test/features/db_utils.py
@@ -0,0 +1,93 @@
+import pymysql
+
+
+def create_db(hostname='localhost', port=3306, username=None,
+ password=None, dbname=None):
+ """Create test database.
+
+ :param hostname: string
+ :param port: int
+ :param username: string
+ :param password: string
+ :param dbname: string
+ :return:
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ with cn.cursor() as cr:
+ cr.execute('drop database if exists ' + dbname)
+ cr.execute('create database ' + dbname)
+
+ cn.close()
+
+ cn = create_cn(hostname, port, password, username, dbname)
+ return cn
+
+
+def create_cn(hostname, port, password, username, dbname):
+ """Open connection to database.
+
+ :param hostname:
+ :param port:
+ :param password:
+ :param username:
+ :param dbname: string
+ :return: psycopg2.connection
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ db=dbname,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ return cn
+
+
+def drop_db(hostname='localhost', port=3306, username=None,
+ password=None, dbname=None):
+ """Drop database.
+
+ :param hostname: string
+ :param port: int
+ :param username: string
+ :param password: string
+ :param dbname: string
+
+ """
+ cn = pymysql.connect(
+ host=hostname,
+ port=port,
+ user=username,
+ password=password,
+ db=dbname,
+ charset='utf8mb4',
+ cursorclass=pymysql.cursors.DictCursor
+ )
+
+ with cn.cursor() as cr:
+ cr.execute('drop database if exists ' + dbname)
+
+ close_cn(cn)
+
+
+def close_cn(cn=None):
+ """Close connection.
+
+ :param connection: pymysql.connection
+
+ """
+ if cn:
+ cn.close()
diff --git a/test/features/environment.py b/test/features/environment.py
new file mode 100644
index 0000000..1ea0f08
--- /dev/null
+++ b/test/features/environment.py
@@ -0,0 +1,176 @@
+import os
+import shutil
+import sys
+from tempfile import mkstemp
+
+import db_utils as dbutils
+import fixture_utils as fixutils
+import pexpect
+
+from steps.wrappers import run_cli, wait_prompt
+
+test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log')
+
+
+SELF_CONNECTING_FEATURES = (
+ 'test/features/connection.feature',
+)
+
+
+MY_CNF_PATH = os.path.expanduser('~/.my.cnf')
+MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup'
+MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf')
+MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup'
+
+
+def get_db_name_from_context(context):
+ return context.config.userdata.get(
+ 'my_test_db', None
+ ) or "mycli_behave_tests"
+
+
+
+def before_all(context):
+ """Set env parameters."""
+ os.environ['LINES'] = "100"
+ os.environ['COLUMNS'] = "100"
+ os.environ['EDITOR'] = 'ex'
+ os.environ['LC_ALL'] = 'en_US.UTF-8'
+ os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1'
+ os.environ['MYCLI_HISTFILE'] = os.devnull
+
+ test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+ login_path_file = os.path.join(test_dir, 'mylogin.cnf')
+# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+
+ context.package_root = os.path.abspath(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root,
+ '.coveragerc')
+
+ context.exit_sent = False
+
+ vi = '_'.join([str(x) for x in sys.version_info[:3]])
+ db_name = get_db_name_from_context(context)
+ db_name_full = '{0}_{1}'.format(db_name, vi)
+
+ # Store get params from config/environment variables
+ context.conf = {
+ 'host': context.config.userdata.get(
+ 'my_test_host',
+ os.getenv('PYTEST_HOST', 'localhost')
+ ),
+ 'port': context.config.userdata.get(
+ 'my_test_port',
+ int(os.getenv('PYTEST_PORT', '3306'))
+ ),
+ 'user': context.config.userdata.get(
+ 'my_test_user',
+ os.getenv('PYTEST_USER', 'root')
+ ),
+ 'pass': context.config.userdata.get(
+ 'my_test_pass',
+ os.getenv('PYTEST_PASSWORD', None)
+ ),
+ 'cli_command': context.config.userdata.get(
+ 'my_cli_command', None) or
+ sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"',
+ 'dbname': db_name,
+ 'dbname_tmp': db_name_full + '_tmp',
+ 'vi': vi,
+ 'pager_boundary': '---boundary---',
+ }
+
+ _, my_cnf = mkstemp()
+ with open(my_cnf, 'w') as f:
+ f.write(
+ '[client]\n'
+ 'pager={0} {1} {2}\n'.format(
+ sys.executable, os.path.join(context.package_root,
+ 'test/features/wrappager.py'),
+ context.conf['pager_boundary'])
+ )
+ context.conf['defaults-file'] = my_cnf
+ context.conf['myclirc'] = os.path.join(context.package_root, 'test',
+ 'myclirc')
+
+ context.cn = dbutils.create_db(context.conf['host'], context.conf['port'],
+ context.conf['user'],
+ context.conf['pass'],
+ context.conf['dbname'])
+
+ context.fixture_data = fixutils.read_fixture_files()
+
+
+def after_all(context):
+ """Unset env parameters."""
+ dbutils.close_cn(context.cn)
+ dbutils.drop_db(context.conf['host'], context.conf['port'],
+ context.conf['user'], context.conf['pass'],
+ context.conf['dbname'])
+
+ # Restore env vars.
+ #for k, v in context.pgenv.items():
+ # if k in os.environ and v is None:
+ # del os.environ[k]
+ # elif v:
+ # os.environ[k] = v
+
+
+def before_step(context, _):
+ context.atprompt = False
+
+
+def before_scenario(context, arg):
+ with open(test_log_file, 'w') as f:
+ f.write('')
+ if arg.location.filename not in SELF_CONNECTING_FEATURES:
+ run_cli(context)
+ wait_prompt(context)
+
+ if os.path.exists(MY_CNF_PATH):
+ shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_PATH):
+ shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH)
+
+
+def after_scenario(context, _):
+ """Cleans up after each test complete."""
+ with open(test_log_file) as f:
+ for line in f:
+ if 'error' in line.lower():
+ raise RuntimeError(f'Error in log file: {line}')
+
+ if hasattr(context, 'cli') and not context.exit_sent:
+ # Quit nicely.
+ if not context.atprompt:
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ context.cli.expect_exact(
+ '{0}@{1}:{2}>'.format(
+ user, host, dbname
+ ),
+ timeout=5
+ )
+ context.cli.sendcontrol('c')
+ context.cli.sendcontrol('d')
+ context.cli.expect_exact(pexpect.EOF, timeout=5)
+
+ if os.path.exists(MY_CNF_BACKUP_PATH):
+ shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH)
+
+ if os.path.exists(MYLOGIN_CNF_BACKUP_PATH):
+ shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH)
+ elif os.path.exists(MYLOGIN_CNF_PATH):
+ # This file was moved in `before_scenario`.
+ # If it exists now, it has been created during a test
+ os.remove(MYLOGIN_CNF_PATH)
+
+
+# TODO: uncomment to debug a failure
+# def after_step(context, step):
+# if step.status == "failed":
+# import ipdb; ipdb.set_trace()
diff --git a/test/features/fixture_data/help.txt b/test/features/fixture_data/help.txt
new file mode 100644
index 0000000..deb499a
--- /dev/null
+++ b/test/features/fixture_data/help.txt
@@ -0,0 +1,24 @@
++--------------------------+-----------------------------------------------+
+| Command | Description |
+|--------------------------+-----------------------------------------------|
+| \# | Refresh auto-completions. |
+| \? | Show Help. |
+| \c[onnect] database_name | Change to a new database. |
+| \d [pattern] | List or describe tables, views and sequences. |
+| \dT[S+] [pattern] | List data types |
+| \df[+] [pattern] | List functions. |
+| \di[+] [pattern] | List indexes. |
+| \dn[+] [pattern] | List schemas. |
+| \ds[+] [pattern] | List sequences. |
+| \dt[+] [pattern] | List tables. |
+| \du[+] [pattern] | List roles. |
+| \dv[+] [pattern] | List views. |
+| \e [file] | Edit the query with external editor. |
+| \l | List databases. |
+| \n[+] [name] | List or execute named queries. |
+| \nd [name [query]] | Delete a named query. |
+| \ns name query | Save a named query. |
+| \refresh | Refresh auto-completions. |
+| \timing | Toggle timing of commands. |
+| \x | Toggle expanded output. |
++--------------------------+-----------------------------------------------+
diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt
new file mode 100644
index 0000000..2c06d5d
--- /dev/null
+++ b/test/features/fixture_data/help_commands.txt
@@ -0,0 +1,31 @@
++-------------+----------------------------+------------------------------------------------------------+
+| Command | Shortcut | Description |
++-------------+----------------------------+------------------------------------------------------------+
+| \G | \G | Display current query results vertically. |
+| \clip | \clip | Copy query to the system clipboard. |
+| \dt | \dt[+] [table] | List or describe tables. |
+| \e | \e | Edit command with editor (uses $EDITOR). |
+| \f | \f [name [args..]] | List or execute favorite queries. |
+| \fd | \fd [name] | Delete a favorite query. |
+| \fs | \fs name query | Save a favorite query. |
+| \l | \l | List databases. |
+| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
+| \pipe_once | \| command | Send next result to a subprocess. |
+| \timing | \t | Toggle timing of commands. |
+| connect | \r | Reconnect to the database. Optional database argument. |
+| exit | \q | Exit. |
+| help | \? | Show this help. |
+| nopager | \n | Disable pager, print to stdout. |
+| notee | notee | Stop writing results to an output file. |
+| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
+| prompt | \R | Change prompt format. |
+| quit | \q | Quit. |
+| rehash | \# | Refresh auto-completions. |
+| source | \. filename | Execute commands from file. |
+| status | \s | Get status information from the server. |
+| system | system [command] | Execute a system shell commmand. |
+| tableformat | \T | Change the table format used to output results. |
+| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
+| use | \u | Change to a new database. |
+| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
++-------------+----------------------------+------------------------------------------------------------+
diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py
new file mode 100644
index 0000000..f85e0f6
--- /dev/null
+++ b/test/features/fixture_utils.py
@@ -0,0 +1,29 @@
+import os
+import io
+
+
+def read_fixture_lines(filename):
+ """Read lines of text from file.
+
+ :param filename: string name
+ :return: list of strings
+
+ """
+ lines = []
+ for line in open(filename):
+ lines.append(line.strip())
+ return lines
+
+
+def read_fixture_files():
+ """Read all files inside fixture_data directory."""
+ fixture_dict = {}
+
+ current_dir = os.path.dirname(__file__)
+ fixture_dir = os.path.join(current_dir, 'fixture_data/')
+ for filename in os.listdir(fixture_dir):
+ if filename not in ['.', '..']:
+ fullname = os.path.join(fixture_dir, filename)
+ fixture_dict[filename] = read_fixture_lines(fullname)
+
+ return fixture_dict
diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature
new file mode 100644
index 0000000..95366eb
--- /dev/null
+++ b/test/features/iocommands.feature
@@ -0,0 +1,47 @@
+Feature: I/O commands
+
+ Scenario: edit sql in file with external editor
+ When we start external editor providing a file name
+ and we type "select * from abc" in the editor
+ and we exit the editor
+ then we see dbcli prompt
+ and we see "select * from abc" in prompt
+
+ Scenario: tee output from query
+ When we tee output
+ and we wait for prompt
+ and we select "select 123456"
+ and we wait for prompt
+ and we notee output
+ and we wait for prompt
+ then we see 123456 in tee output
+
+ Scenario: set delimiter
+ When we query "delimiter $"
+ then delimiter is set to "$"
+
+ Scenario: set delimiter twice
+ When we query "delimiter $"
+ and we query "delimiter ]]"
+ then delimiter is set to "]]"
+
+ Scenario: set delimiter and query on same line
+ When we query "select 123; delimiter $ select 456 $ delimiter %"
+ then we see result "123"
+ and we see result "456"
+ and delimiter is set to "%"
+
+ Scenario: send output to file
+ When we query "\o /tmp/output1.sql"
+ and we query "select 123"
+ and we query "system cat /tmp/output1.sql"
+ then we see result "123"
+
+ Scenario: send output to file two times
+ When we query "\o /tmp/output1.sql"
+ and we query "select 123"
+ and we query "\o /tmp/output2.sql"
+ and we query "select 456"
+ and we query "system cat /tmp/output2.sql"
+ then we see result "456"
+ \ No newline at end of file
diff --git a/test/features/named_queries.feature b/test/features/named_queries.feature
new file mode 100644
index 0000000..5e681ec
--- /dev/null
+++ b/test/features/named_queries.feature
@@ -0,0 +1,24 @@
+Feature: named queries:
+ save, use and delete named queries
+
+ Scenario: save, use and delete named queries
+ When we connect to test database
+ then we see database connected
+ when we save a named query
+ then we see the named query saved
+ when we use a named query
+ then we see the named query executed
+ when we delete a named query
+ then we see the named query deleted
+
+ Scenario: save, use and delete named queries with parameters
+ When we connect to test database
+ then we see database connected
+ when we save a named query with parameters
+ then we see the named query saved
+ when we use named query with parameters
+ then we see the named query with parameters executed
+ when we use named query with too few parameters
+ then we see the named query with parameters fail with missing parameters
+ when we use named query with too many parameters
+ then we see the named query with parameters fail with extra parameters
diff --git a/test/features/specials.feature b/test/features/specials.feature
new file mode 100644
index 0000000..bb36757
--- /dev/null
+++ b/test/features/specials.feature
@@ -0,0 +1,7 @@
+Feature: Special commands
+
+ @wip
+ Scenario: run refresh command
+ When we refresh completions
+ and we wait for prompt
+ then we see completions refresh started
diff --git a/test/features/steps/__init__.py b/test/features/steps/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/features/steps/__init__.py
diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py
new file mode 100644
index 0000000..e1cb26f
--- /dev/null
+++ b/test/features/steps/auto_vertical.py
@@ -0,0 +1,46 @@
+from textwrap import dedent
+
+from behave import then, when
+
+import wrappers
+from utils import parse_cli_args_to_dict
+
+
+@when('we run dbcli with {arg}')
+def step_run_cli_with_arg(context, arg):
+ wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg))
+
+
+@when('we execute a small query')
+def step_execute_small_query(context):
+ context.cli.sendline('select 1')
+
+
+@when('we execute a large query')
+def step_execute_large_query(context):
+ context.cli.sendline(
+ 'select {}'.format(','.join([str(n) for n in range(1, 50)])))
+
+
+@then('we see small results in horizontal format')
+def step_see_small_results(context):
+ wrappers.expect_pager(context, dedent("""\
+ +---+\r
+ | 1 |\r
+ +---+\r
+ | 1 |\r
+ +---+\r
+ \r
+ """), timeout=5)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then('we see large results in vertical format')
+def step_see_large_results(context):
+ rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)]
+ expected = ('***************************[ 1. row ]'
+ '***************************\r\n' +
+ '{}\r\n'.format('\r\n'.join(rows) + '\r\n'))
+
+ wrappers.expect_pager(context, expected, timeout=10)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py
new file mode 100644
index 0000000..425ef67
--- /dev/null
+++ b/test/features/steps/basic_commands.py
@@ -0,0 +1,100 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+from behave import when
+from textwrap import dedent
+import tempfile
+import wrappers
+
+
+@when('we run dbcli')
+def step_run_cli(context):
+ wrappers.run_cli(context)
+
+
+@when('we wait for prompt')
+def step_wait_prompt(context):
+ wrappers.wait_prompt(context)
+
+
+@when('we send "ctrl + d"')
+def step_ctrl_d(context):
+ """Send Ctrl + D to hopefully exit."""
+ context.cli.sendcontrol('d')
+ context.exit_sent = True
+
+
+@when('we send "\?" command')
+def step_send_help(context):
+ """Send \?
+
+ to see help.
+
+ """
+ context.cli.sendline('\\?')
+ wrappers.expect_exact(
+ context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+
+
+@when(u'we send source command')
+def step_send_source_command(context):
+ with tempfile.NamedTemporaryFile() as f:
+ f.write(b'\?')
+ f.flush()
+ context.cli.sendline('\. {0}'.format(f.name))
+ wrappers.expect_exact(
+ context, context.conf['pager_boundary'] + '\r\n', timeout=5)
+
+
+@when(u'we run query to check application_name')
+def step_check_application_name(context):
+ context.cli.sendline(
+ "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'"
+ )
+
+
+@then(u'we see found')
+def step_see_found(context):
+ wrappers.expect_exact(
+ context,
+ context.conf['pager_boundary'] + '\r' + dedent('''
+ +-------+\r
+ | found |\r
+ +-------+\r
+ | found |\r
+ +-------+\r
+ \r
+ ''') + context.conf['pager_boundary'],
+ timeout=5
+ )
+
+
+@then(u'we confirm the destructive warning')
+def step_confirm_destructive_command(context):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline('y')
+
+
+@when(u'we answer the destructive warning with "{confirmation}"')
+def step_confirm_destructive_command(context, confirmation):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline(confirmation)
+
+
+@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"')
+def step_confirm_destructive_command(context, confirmation, text):
+ """Confirm destructive command."""
+ wrappers.expect_exact(
+ context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2)
+ context.cli.sendline(confirmation)
+ wrappers.expect_exact(context, text, timeout=2)
+ # we must exit the Click loop, or the feature will hang
+ context.cli.sendline('n')
diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py
new file mode 100644
index 0000000..e16dd86
--- /dev/null
+++ b/test/features/steps/connection.py
@@ -0,0 +1,71 @@
+import io
+import os
+import shlex
+
+from behave import when, then
+import pexpect
+
+import wrappers
+from test.features.steps.utils import parse_cli_args_to_dict
+from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context
+from test.utils import HOST, PORT, USER, PASSWORD
+from mycli.config import encrypt_mylogin_cnf
+
+
+TEST_LOGIN_PATH = 'test_login_path'
+
+
+@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"')
+@when('we run mycli without arguments "{excluded_args}"')
+def step_run_cli_without_args(context, excluded_args, exact_args=''):
+ wrappers.run_cli(
+ context,
+ run_args=parse_cli_args_to_dict(exact_args),
+ exclude_args=parse_cli_args_to_dict(excluded_args).keys()
+ )
+
+
+@then('status contains "{expression}"')
+def status_contains(context, expression):
+ wrappers.expect_exact(context, f'{expression}', timeout=5)
+
+ # Normally, the shutdown after scenario waits for the prompt.
+ # But we may have changed the prompt, depending on parameters,
+ # so let's wait for its last character
+ context.cli.expect_exact('>')
+ context.atprompt = True
+
+
+@when('we create my.cnf file')
+def step_create_my_cnf_file(context):
+ my_cnf = (
+ '[client]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MY_CNF_PATH, 'w') as f:
+ f.write(my_cnf)
+
+
+@when('we create mylogin.cnf file')
+def step_create_mylogin_cnf_file(context):
+ os.environ.pop('MYSQL_TEST_LOGIN_FILE', None)
+ mylogin_cnf = (
+ f'[{TEST_LOGIN_PATH}]\n'
+ f'host = {HOST}\n'
+ f'port = {PORT}\n'
+ f'user = {USER}\n'
+ f'password = {PASSWORD}\n'
+ )
+ with open(MYLOGIN_CNF_PATH, 'wb') as f:
+ input_file = io.StringIO(mylogin_cnf)
+ f.write(encrypt_mylogin_cnf(input_file).read())
+
+
+@then('we are logged in')
+def we_are_logged_in(context):
+ db_name = get_db_name_from_context(context)
+ context.cli.expect_exact(f'{db_name}>', timeout=5)
+ context.atprompt = True
diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py
new file mode 100644
index 0000000..841f37d
--- /dev/null
+++ b/test/features/steps/crud_database.py
@@ -0,0 +1,115 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import pexpect
+
+import wrappers
+from behave import when, then
+
+
+@when('we create database')
+def step_db_create(context):
+ """Send create database."""
+ context.cli.sendline('create database {0};'.format(
+ context.conf['dbname_tmp']))
+
+ context.response = {
+ 'database_name': context.conf['dbname_tmp']
+ }
+
+
+@when('we drop database')
+def step_db_drop(context):
+ """Send drop database."""
+ context.cli.sendline('drop database {0};'.format(
+ context.conf['dbname_tmp']))
+
+
+@when('we connect to test database')
+def step_db_connect_test(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname']
+ context.currentdb = db_name
+ context.cli.sendline('use {0};'.format(db_name))
+
+
+@when('we connect to quoted test database')
+def step_db_connect_quoted_tmp(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname']
+ context.currentdb = db_name
+ context.cli.sendline('use `{0}`;'.format(db_name))
+
+
+@when('we connect to tmp database')
+def step_db_connect_tmp(context):
+ """Send connect to database."""
+ db_name = context.conf['dbname_tmp']
+ context.currentdb = db_name
+ context.cli.sendline('use {0}'.format(db_name))
+
+
+@when('we connect to dbserver')
+def step_db_connect_dbserver(context):
+ """Send connect to database."""
+ context.currentdb = 'mysql'
+ context.cli.sendline('use mysql')
+
+
+@then('dbcli exits')
+def step_wait_exit(context):
+ """Make sure the cli exits."""
+ wrappers.expect_exact(context, pexpect.EOF, timeout=5)
+
+
+@then('we see dbcli prompt')
+def step_see_prompt(context):
+ """Wait to see the prompt."""
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname))
+
+
+@then('we see help output')
+def step_see_help(context):
+ for expected_line in context.fixture_data['help_commands.txt']:
+ wrappers.expect_exact(context, expected_line, timeout=1)
+
+
+@then('we see database created')
+def step_see_db_created(context):
+ """Wait to see create database output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see database dropped')
+def step_see_db_dropped(context):
+ """Wait to see drop database output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@then('we see database dropped and no default database')
+def step_see_db_dropped_no_default(context):
+ """Wait to see drop database output."""
+ user = context.conf['user']
+ host = context.conf['host']
+ database = '(none)'
+ context.currentdb = None
+
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+ wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database))
+
+
+@then('we see database connected')
+def step_see_db_connected(context):
+ """Wait to see drop database output."""
+ wrappers.expect_exact(
+ context, 'You are now connected to database "', timeout=2)
+ wrappers.expect_exact(context, '"', timeout=2)
+ wrappers.expect_exact(context, ' as user "{0}"'.format(
+ context.conf['user']), timeout=2)
diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py
new file mode 100644
index 0000000..f715f0c
--- /dev/null
+++ b/test/features/steps/crud_table.py
@@ -0,0 +1,112 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+from textwrap import dedent
+
+
+@when('we create table')
+def step_create_table(context):
+ """Send create table."""
+ context.cli.sendline('create table a(x text);')
+
+
+@when('we insert into table')
+def step_insert_into_table(context):
+ """Send insert into table."""
+ context.cli.sendline('''insert into a(x) values('xxx');''')
+
+
+@when('we update table')
+def step_update_table(context):
+ """Send insert into table."""
+ context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''')
+
+
+@when('we select from table')
+def step_select_from_table(context):
+ """Send select from table."""
+ context.cli.sendline('select * from a;')
+
+
+@when('we delete from table')
+def step_delete_from_table(context):
+ """Send deete from table."""
+ context.cli.sendline('''delete from a where x = 'yyy';''')
+
+
+@when('we drop table')
+def step_drop_table(context):
+ """Send drop table."""
+ context.cli.sendline('drop table a;')
+
+
+@then('we see table created')
+def step_see_table_created(context):
+ """Wait to see create table output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@then('we see record inserted')
+def step_see_record_inserted(context):
+ """Wait to see insert output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see record updated')
+def step_see_record_updated(context):
+ """Wait to see update output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see data selected')
+def step_see_data_selected(context):
+ """Wait to see select output."""
+ wrappers.expect_pager(
+ context, dedent("""\
+ +-----+\r
+ | x |\r
+ +-----+\r
+ | yyy |\r
+ +-----+\r
+ \r
+ """), timeout=2)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then('we see record deleted')
+def step_see_data_deleted(context):
+ """Wait to see delete output."""
+ wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2)
+
+
+@then('we see table dropped')
+def step_see_table_dropped(context):
+ """Wait to see drop output."""
+ wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2)
+
+
+@when('we select null')
+def step_select_null(context):
+ """Send select null."""
+ context.cli.sendline('select null;')
+
+
+@then('we see null selected')
+def step_see_null_selected(context):
+ """Wait to see null output."""
+ wrappers.expect_pager(
+ context, dedent("""\
+ +--------+\r
+ | NULL |\r
+ +--------+\r
+ | <null> |\r
+ +--------+\r
+ \r
+ """), timeout=2)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py
new file mode 100644
index 0000000..bbabf43
--- /dev/null
+++ b/test/features/steps/iocommands.py
@@ -0,0 +1,105 @@
+import os
+import wrappers
+
+from behave import when, then
+from textwrap import dedent
+
+
+@when('we start external editor providing a file name')
+def step_edit_file(context):
+ """Edit file with external editor."""
+ context.editor_file_name = os.path.join(
+ context.package_root, 'test_file_{0}.sql'.format(context.conf['vi']))
+ if os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.cli.sendline('\e {0}'.format(
+ os.path.basename(context.editor_file_name)))
+ wrappers.expect_exact(
+ context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2)
+ wrappers.expect_exact(context, '\r\n:', timeout=2)
+
+
+@when('we type "{query}" in the editor')
+def step_edit_type_sql(context, query):
+ context.cli.sendline('i')
+ context.cli.sendline(query)
+ context.cli.sendline('.')
+ wrappers.expect_exact(context, '\r\n:', timeout=2)
+
+
+@when('we exit the editor')
+def step_edit_quit(context):
+ context.cli.sendline('x')
+ wrappers.expect_exact(context, "written", timeout=2)
+
+
+@then('we see "{query}" in prompt')
+def step_edit_done_sql(context, query):
+ for match in query.split(' '):
+ wrappers.expect_exact(context, match, timeout=5)
+ # Cleanup the command line.
+ context.cli.sendcontrol('c')
+ # Cleanup the edited file.
+ if context.editor_file_name and os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+
+
+@when(u'we tee output')
+def step_tee_ouptut(context):
+ context.tee_file_name = os.path.join(
+ context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi']))
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.cli.sendline('tee {0}'.format(
+ os.path.basename(context.tee_file_name)))
+
+
+@when(u'we select "select {param}"')
+def step_query_select_number(context, param):
+ context.cli.sendline(u'select {}'.format(param))
+ wrappers.expect_pager(context, dedent(u"""\
+ +{dashes}+\r
+ | {param} |\r
+ +{dashes}+\r
+ | {param} |\r
+ +{dashes}+\r
+ \r
+ """.format(param=param, dashes='-' * (len(param) + 2))
+ ), timeout=5)
+ wrappers.expect_exact(context, '1 row in set', timeout=2)
+
+
+@then(u'we see result "{result}"')
+def step_see_result(context, result):
+ wrappers.expect_exact(
+ context,
+ u"| {} |".format(result),
+ timeout=2
+ )
+
+
+@when(u'we query "{query}"')
+def step_query(context, query):
+ context.cli.sendline(query)
+
+
+@when(u'we notee output')
+def step_notee_output(context):
+ context.cli.sendline('notee')
+
+
+@then(u'we see 123456 in tee output')
+def step_see_123456_in_ouput(context):
+ with open(context.tee_file_name) as f:
+ assert '123456' in f.read()
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+
+
+@then(u'delimiter is set to "{delimiter}"')
+def delimiter_is_set(context, delimiter):
+ wrappers.expect_exact(
+ context,
+ u'Changed delimiter to {}'.format(delimiter),
+ timeout=2
+ )
diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py
new file mode 100644
index 0000000..bc1f866
--- /dev/null
+++ b/test/features/steps/named_queries.py
@@ -0,0 +1,90 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+
+
+@when('we save a named query')
+def step_save_named_query(context):
+ """Send \fs command."""
+ context.cli.sendline('\\fs foo SELECT 12345')
+
+
+@when('we use a named query')
+def step_use_named_query(context):
+ """Send \f command."""
+ context.cli.sendline('\\f foo')
+
+
+@when('we delete a named query')
+def step_delete_named_query(context):
+ """Send \fd command."""
+ context.cli.sendline('\\fd foo')
+
+
+@then('we see the named query saved')
+def step_see_named_query_saved(context):
+ """Wait to see query saved."""
+ wrappers.expect_exact(context, 'Saved.', timeout=2)
+
+
+@then('we see the named query executed')
+def step_see_named_query_executed(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(context, 'SELECT 12345', timeout=2)
+
+
+@then('we see the named query deleted')
+def step_see_named_query_deleted(context):
+ """Wait to see query deleted."""
+ wrappers.expect_exact(context, 'foo: Deleted', timeout=2)
+
+
+@when('we save a named query with parameters')
+def step_save_named_query_with_parameters(context):
+ """Send \fs command for query with parameters."""
+ context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"')
+
+
+@when('we use named query with parameters')
+def step_use_named_query_with_parameters(context):
+ """Send \f command with parameters."""
+ context.cli.sendline('\\f foo_args 101 second "third value"')
+
+
+@then('we see the named query with parameters executed')
+def step_see_named_query_with_parameters_executed(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'SELECT 101, "second", "third value"', timeout=2)
+
+
+@when('we use named query with too few parameters')
+def step_use_named_query_with_too_few_parameters(context):
+ """Send \f command with missing parameters."""
+ context.cli.sendline('\\f foo_args 101')
+
+
+@then('we see the named query with parameters fail with missing parameters')
+def step_see_named_query_with_parameters_fail_with_missing_parameters(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'missing substitution for $2 in query:', timeout=2)
+
+
+@when('we use named query with too many parameters')
+def step_use_named_query_with_too_many_parameters(context):
+ """Send \f command with extra parameters."""
+ context.cli.sendline('\\f foo_args 101 102 103 104')
+
+
+@then('we see the named query with parameters fail with extra parameters')
+def step_see_named_query_with_parameters_fail_with_extra_parameters(context):
+ """Wait to see select output."""
+ wrappers.expect_exact(
+ context, 'query does not have substitution parameter $4:', timeout=2)
diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py
new file mode 100644
index 0000000..e8b99e3
--- /dev/null
+++ b/test/features/steps/specials.py
@@ -0,0 +1,27 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+import wrappers
+from behave import when, then
+
+
+@when('we refresh completions')
+def step_refresh_completions(context):
+ """Send refresh command."""
+ context.cli.sendline('rehash')
+
+
+@then('we see text "{text}"')
+def step_see_text(context, text):
+ """Wait to see given text message."""
+ wrappers.expect_exact(context, text, timeout=2)
+
+@then('we see completions refresh started')
+def step_see_refresh_started(context):
+ """Wait to see refresh output."""
+ wrappers.expect_exact(
+ context, 'Auto-completion refresh started in the background.', timeout=2)
diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py
new file mode 100644
index 0000000..1ae63d2
--- /dev/null
+++ b/test/features/steps/utils.py
@@ -0,0 +1,12 @@
+import shlex
+
+
+def parse_cli_args_to_dict(cli_args: str):
+ args_dict = {}
+ for arg in shlex.split(cli_args):
+ if '=' in arg:
+ key, value = arg.split('=')
+ args_dict[key] = value
+ else:
+ args_dict[arg] = None
+ return args_dict
diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py
new file mode 100644
index 0000000..6408f23
--- /dev/null
+++ b/test/features/steps/wrappers.py
@@ -0,0 +1,117 @@
+import re
+import pexpect
+import sys
+import textwrap
+
+
+try:
+ from StringIO import StringIO
+except ImportError:
+ from io import StringIO
+
+
+def expect_exact(context, expected, timeout):
+ timedout = False
+ try:
+ context.cli.expect_exact(expected, timeout=timeout)
+ except pexpect.TIMEOUT:
+ timedout = True
+ if timedout:
+ # Strip color codes out of the output.
+ actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?',
+ '', context.cli.before)
+ raise Exception(
+ textwrap.dedent('''\
+ Expected:
+ ---
+ {0!r}
+ ---
+ Actual:
+ ---
+ {1!r}
+ ---
+ Full log:
+ ---
+ {2!r}
+ ---
+ ''').format(
+ expected,
+ actual,
+ context.logfile.getvalue()
+ )
+ )
+
+
+def expect_pager(context, expected, timeout):
+ expect_exact(context, "{0}\r\n{1}{0}\r\n".format(
+ context.conf['pager_boundary'], expected), timeout=timeout)
+
+
+def run_cli(context, run_args=None, exclude_args=None):
+ """Run the process using pexpect."""
+ run_args = run_args or {}
+ rendered_args = []
+ exclude_args = set(exclude_args) if exclude_args else set()
+
+ conf = dict(**context.conf)
+ conf.update(run_args)
+
+ def add_arg(name, key, value):
+ if name not in exclude_args:
+ if value is not None:
+ rendered_args.extend((key, value))
+ else:
+ rendered_args.append(key)
+
+ if conf.get('host', None):
+ add_arg('host', '-h', conf['host'])
+ if conf.get('user', None):
+ add_arg('user', '-u', conf['user'])
+ if conf.get('pass', None):
+ add_arg('pass', '-p', conf['pass'])
+ if conf.get('port', None):
+ add_arg('port', '-P', str(conf['port']))
+ if conf.get('dbname', None):
+ add_arg('dbname', '-D', conf['dbname'])
+ if conf.get('defaults-file', None):
+ add_arg('defaults_file', '--defaults-file', conf['defaults-file'])
+ if conf.get('myclirc', None):
+ add_arg('myclirc', '--myclirc', conf['myclirc'])
+ if conf.get('login_path'):
+ add_arg('login_path', '--login-path', conf['login_path'])
+
+ for arg_name, arg_value in conf.items():
+ if arg_name.startswith('-'):
+ add_arg(arg_name, arg_name, arg_value)
+
+ try:
+ cli_cmd = context.conf['cli_command']
+ except KeyError:
+ cli_cmd = (
+ '{0!s} -c "'
+ 'import coverage ; '
+ 'coverage.process_startup(); '
+ 'import mycli.main; '
+ 'mycli.main.cli()'
+ '"'
+ ).format(sys.executable)
+
+ cmd_parts = [cli_cmd] + rendered_args
+ cmd = ' '.join(cmd_parts)
+ context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
+ context.logfile = StringIO()
+ context.cli.logfile = context.logfile
+ context.exit_sent = False
+ context.currentdb = context.conf['dbname']
+
+
+def wait_prompt(context, prompt=None):
+ """Make sure prompt is displayed."""
+ if prompt is None:
+ user = context.conf['user']
+ host = context.conf['host']
+ dbname = context.currentdb
+ prompt = '{0}@{1}:{2}>'.format(
+ user, host, dbname),
+ expect_exact(context, prompt, timeout=5)
+ context.atprompt = True
diff --git a/test/features/wrappager.py b/test/features/wrappager.py
new file mode 100755
index 0000000..51d4909
--- /dev/null
+++ b/test/features/wrappager.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+import sys
+
+
+def wrappager(boundary):
+ print(boundary)
+ while 1:
+ buf = sys.stdin.read(2048)
+ if not buf:
+ break
+ sys.stdout.write(buf)
+ print(boundary)
+
+
+if __name__ == "__main__":
+ wrappager(sys.argv[1])
diff --git a/test/myclirc b/test/myclirc
new file mode 100644
index 0000000..261bee6
--- /dev/null
+++ b/test/myclirc
@@ -0,0 +1,12 @@
+# vi: ft=dosini
+
+# This file is loaded after mycli/myclirc and should override only those
+# variables needed for testing.
+# To see what every variable does see mycli/myclirc
+
+[main]
+
+log_file = ~/.mycli.test.log
+log_level = DEBUG
+prompt = '\t \u@\h:\d> '
+less_chatty = True
diff --git a/test/mylogin.cnf b/test/mylogin.cnf
new file mode 100644
index 0000000..1363cc3
--- /dev/null
+++ b/test/mylogin.cnf
Binary files differ
diff --git a/test/test.txt b/test/test.txt
new file mode 100644
index 0000000..8d8b211
--- /dev/null
+++ b/test/test.txt
@@ -0,0 +1 @@
+mycli rocks!
diff --git a/test/test_clistyle.py b/test/test_clistyle.py
new file mode 100644
index 0000000..f82cdf0
--- /dev/null
+++ b/test/test_clistyle.py
@@ -0,0 +1,27 @@
+"""Test the mycli.clistyle module."""
+import pytest
+
+from pygments.style import Style
+from pygments.token import Token
+
+from mycli.clistyle import style_factory
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory():
+ """Test that a Pygments Style class is created."""
+ header = 'bold underline #ansired'
+ cli_style = {'Token.Output.Header': header}
+ style = style_factory('default', cli_style)
+
+ assert isinstance(style(), Style)
+ assert Token.Output.Header in style.styles
+ assert header == style.styles[Token.Output.Header]
+
+
+@pytest.mark.skip(reason="incompatible with new prompt toolkit")
+def test_style_factory_unknown_name():
+ """Test that an unrecognized name will not throw an error."""
+ style = style_factory('foobar', {})
+
+ assert isinstance(style(), Style)
diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py
new file mode 100644
index 0000000..318b632
--- /dev/null
+++ b/test/test_completion_engine.py
@@ -0,0 +1,555 @@
+from mycli.packages.completion_engine import suggest_type
+import pytest
+
+
+def sorted_dicts(dicts):
+ """input is a list of dicts."""
+ return sorted(tuple(x.items()) for x in dicts)
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [('sch', 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM tabl WHERE ',
+ 'SELECT * FROM tabl WHERE (',
+ 'SELECT * FROM tabl WHERE foo = ',
+ 'SELECT * FROM tabl WHERE bar OR ',
+ 'SELECT * FROM tabl WHERE foo = 1 AND ',
+ 'SELECT * FROM tabl WHERE (bar > 10 AND ',
+ 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (',
+ 'SELECT * FROM tabl WHERE 10 < ',
+ 'SELECT * FROM tabl WHERE foo BETWEEN ',
+ 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ',
+])
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM tabl WHERE foo IN (',
+ 'SELECT * FROM tabl WHERE foo IN (bar, ',
+])
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = 'SELECT * FROM tabl WHERE foo = ANY('
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'}])
+
+
+def test_lparen_suggests_cols():
+ suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_operand_inside_function_suggests_cols1():
+ suggestion = suggest_type(
+ 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_operand_inside_function_suggests_cols2():
+ suggestion = suggest_type(
+ 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ')
+ assert suggestion == [
+ {'type': 'column', 'tables': [(None, 'tbl', None)]}]
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type('SELECT ', 'SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': []},
+ {'type': 'column', 'tables': []},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM ',
+ 'INSERT INTO ',
+ 'COPY ',
+ 'UPDATE ',
+ 'DESCRIBE ',
+ 'DESC ',
+ 'EXPLAIN ',
+ 'SELECT * FROM foo JOIN ',
+])
+def test_expression_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM sch.',
+ 'INSERT INTO sch.',
+ 'COPY sch.',
+ 'UPDATE sch.',
+ 'DESCRIBE sch.',
+ 'DESC sch.',
+ 'EXPLAIN sch.',
+ 'SELECT * FROM foo JOIN sch.',
+])
+def test_expression_suggests_qualified_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 'sch'},
+ {'type': 'view', 'schema': 'sch'}])
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 'sch'}])
+
+
+def test_distinct_suggests_cols():
+ suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
+ assert suggestions == [{'type': 'column', 'tables': []}]
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tbl']},
+ {'type': 'column', 'tables': [(None, 'tbl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type('SELECT a, b FROM tbl1, ',
+ 'SELECT a, b FROM tbl1, ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_insert_into_lparen_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
+ assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
+ 'SELECT * FROM tabl WHERE col_n')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['tabl']},
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl', None)]},
+ {'type': 'table', 'schema': 'tabl'},
+ {'type': 'view', 'schema': 'tabl'},
+ {'type': 'function', 'schema': 'tabl'}])
+
+
+def test_dot_suggests_cols_of_an_alias():
+ suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
+ 'SELECT t1.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': 't1'},
+ {'type': 'view', 'schema': 't1'},
+ {'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
+ {'type': 'function', 'schema': 't1'}])
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
+ 'SELECT t1.a, t2.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
+ {'type': 'table', 'schema': 't2'},
+ {'type': 'view', 'schema': 't2'},
+ {'type': 'function', 'schema': 't2'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (',
+ 'SELECT * FROM foo WHERE EXISTS (',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (',
+ 'SELECT 1 AS',
+])
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{'type': 'keyword'}]
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (S',
+ 'SELECT * FROM foo WHERE EXISTS (S',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S',
+])
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == [{'type': 'keyword'}]
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.'
+ suggestions = suggest_type(q, q)
+ assert suggestions == [
+ {'type': 'column', 'tables': [(None, 'foo', 'f')]},
+ {'type': 'table', 'schema': 'f'},
+ {'type': 'view', 'schema': 'f'},
+ {'type': 'function', 'schema': 'f'}]
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT * FROM (SELECT * FROM ',
+ 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ',
+ 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ',
+])
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
+ 'SELECT * FROM (SELECT ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['abc']},
+ {'type': 'column', 'tables': [(None, 'abc', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
+ 'SELECT * FROM (SELECT a, ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'abc', None)]},
+ {'type': 'function', 'schema': []}])
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
+ 'SELECT * FROM (SELECT t.')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'tabl', 't')]},
+ {'type': 'table', 'schema': 't'},
+ {'type': 'view', 'schema': 't'},
+ {'type': 'function', 'schema': 't'}])
+
+
+@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER'])
+@pytest.mark.parametrize('tbl_alias', ['', 'foo'])
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type)
+ suggestion = suggest_type(text, text)
+ assert sorted_dicts(suggestion) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.',
+])
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'abc', 'a')]},
+ {'type': 'table', 'schema': 'a'},
+ {'type': 'view', 'schema': 'a'},
+ {'type': 'function', 'schema': 'a'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.',
+])
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'def', 'd')]},
+ {'type': 'table', 'schema': 'd'},
+ {'type': 'view', 'schema': 'd'},
+ {'type': 'function', 'schema': 'd'}])
+
+
+@pytest.mark.parametrize('sql', [
+ 'select a.x, b.y from abc a join bcd b on ',
+ 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ',
+])
+def test_on_suggests_aliases(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select abc.x, bcd.y from abc join bcd on ',
+ 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ',
+])
+def test_on_suggests_tables(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select a.x, b.y from abc a join bcd b on a.id = ',
+ 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ',
+])
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
+
+
+@pytest.mark.parametrize('sql', [
+ 'select abc.x, bcd.y from abc join bcd on ',
+ 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ',
+])
+def test_on_suggests_tables_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
+
+
+@pytest.mark.parametrize('col_list', ['', 'col1, '])
+def test_join_using_suggests_common_columns(col_list):
+ text = 'select * from abc inner join def using (' + col_list
+ assert suggest_type(text, text) == [
+ {'type': 'column',
+ 'tables': [(None, 'abc', None), (None, 'def', None)],
+ 'drop_unique': True}]
+
+@pytest.mark.parametrize('sql', [
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.',
+ 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.',
+])
+def test_two_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'column', 'tables': [(None, 'ghi', 'g')]},
+ {'type': 'table', 'schema': 'g'},
+ {'type': 'view', 'schema': 'g'},
+ {'type': 'function', 'schema': 'g'}])
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type('select * from a; select * from ',
+ 'select * from a; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select * from a; select from b',
+ 'select * from a; select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['b']},
+ {'type': 'column', 'tables': [(None, 'b', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type('select * from; select * from ',
+ 'select * from; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type('select * from ; select * from b',
+ 'select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select from a; select * from b',
+ 'select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['a']},
+ {'type': 'column', 'tables': [(None, 'a', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type('select * from a; select * from ; select * from c',
+ 'select * from a; select * from ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+ suggestions = suggest_type('select * from a; select from b; select * from c',
+ 'select * from a; select ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'alias', 'aliases': ['b']},
+ {'type': 'column', 'tables': [(None, 'b', None)]},
+ {'type': 'function', 'schema': []},
+ {'type': 'keyword'},
+ ])
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type('create database foo with template ',
+ 'create database foo with template ')
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}])
+
+
+@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t'])
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert sorted_dicts(suggestions) == \
+ sorted_dicts([{'type': 'keyword'}, {'type': 'special'}])
+
+
+def test_specials_not_included_after_initial_token():
+ suggestions = suggest_type('create table foo (dt d',
+ 'create table foo (dt d')
+
+ assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}])
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = 'DROP TABLE schema_name.table_name'
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'type': 'table', 'schema': 'schema_name'}]
+
+
+@pytest.mark.parametrize('text', [',', ' ,', 'sel ,'])
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_cross_join():
+ text = 'select * from v1 cross join v2 JOIN v1.id, '
+ suggestions = suggest_type(text, text)
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+@pytest.mark.parametrize('expression', [
+ 'SELECT 1 AS ',
+ 'SELECT 1 FROM tabl AS ',
+])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+@pytest.mark.parametrize('expression', [
+ '\\. ',
+ 'select 1; \\. ',
+ 'select 1;\\. ',
+ 'select 1 ; \\. ',
+ 'source ',
+ 'truncate table test; source ',
+ 'truncate table test ; source ',
+ 'truncate table test;source ',
+])
+def test_source_is_file(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{'type': 'file_name'}]
+
+
+@pytest.mark.parametrize("expression", [
+ "\\f ",
+])
+def test_favorite_name_suggestion(expression):
+ suggestions = suggest_type(expression, expression)
+ assert suggestions == [{'type': 'favoritequery'}]
+
+
+def test_order_by():
+ text = 'select * from foo order by '
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}]
+
+
+def test_quoted_where():
+ text = "'where i=';"
+ suggestions = suggest_type(text, text)
+ assert suggestions == [{'type': 'keyword'}]
diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py
new file mode 100644
index 0000000..cdc2fb5
--- /dev/null
+++ b/test/test_completion_refresher.py
@@ -0,0 +1,88 @@
+import time
+import pytest
+from unittest.mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from mycli.completion_refresher import CompletionRefresher
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """Refresher object should contain a few handlers.
+
+ :param refresher:
+ :return:
+
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions',
+ 'special_commands', 'show_commands']
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ sqlexecute = Mock()
+
+ with patch.object(refresher, '_bg_refresh') as bg_refresh:
+ actual = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == 'Auto-completion refresh started in the background.'
+ bg_refresh.assert_called_with(sqlexecute, callbacks, {})
+
+
+def test_refresh_called_twice(refresher):
+ """If refresh is called a second time, it should be restarted.
+
+ :param refresher:
+ :return:
+
+ """
+ callbacks = Mock()
+
+ sqlexecute = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == 'Auto-completion refresh started in the background.'
+
+ actual2 = refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == 'Auto-completion refresh restarted.'
+
+
+def test_refresh_with_callbacks(refresher):
+ """Callbacks must be called.
+
+ :param refresher:
+
+ """
+ callbacks = [Mock()]
+ sqlexecute_class = Mock()
+ sqlexecute = Mock()
+
+ with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class):
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(sqlexecute, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert (callbacks[0].call_count == 1)
diff --git a/test/test_config.py b/test/test_config.py
new file mode 100644
index 0000000..7f2b244
--- /dev/null
+++ b/test/test_config.py
@@ -0,0 +1,196 @@
+"""Unit tests for the mycli.config module."""
+from io import BytesIO, StringIO, TextIOWrapper
+import os
+import struct
+import sys
+import tempfile
+import pytest
+
+from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf,
+ read_and_decrypt_mylogin_cnf, read_config_file,
+ str_to_bool, strip_matching_quotes)
+
+LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__),
+ 'mylogin.cnf'))
+
+
+def open_bmylogin_cnf(name):
+ """Open contents of *name* in a BytesIO buffer."""
+ with open(name, 'rb') as f:
+ buf = BytesIO()
+ buf.write(f.read())
+ return buf
+
+def test_read_mylogin_cnf():
+ """Tests that a login path file can be read and decrypted."""
+ mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE)
+
+ assert isinstance(mylogin_cnf, TextIOWrapper)
+
+ contents = mylogin_cnf.read()
+ for word in ('[test]', 'user', 'password', 'host', 'port'):
+ assert word in contents
+
+
+def test_decrypt_blank_mylogin_cnf():
+ """Test that a blank login path file is handled correctly."""
+ mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO())
+ assert mylogin_cnf is None
+
+
+def test_corrupted_login_key():
+ """Test that a corrupted login path key is handled correctly."""
+ buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
+
+ # Skip past the unused bytes
+ buf.seek(4)
+
+ # Write null bytes over half the login key
+ buf.write(b'\0\0\0\0\0\0\0\0\0\0')
+
+ buf.seek(0)
+ mylogin_cnf = read_and_decrypt_mylogin_cnf(buf)
+
+ assert mylogin_cnf is None
+
+
+def test_corrupted_pad():
+ """Tests that a login path file with a corrupted pad is partially read."""
+ buf = open_bmylogin_cnf(LOGIN_PATH_FILE)
+
+ # Skip past the login key
+ buf.seek(24)
+
+ # Skip option group
+ len_buf = buf.read(4)
+ cipher_len, = struct.unpack("<i", len_buf)
+ buf.read(cipher_len)
+
+ # Corrupt the pad for the user line
+ len_buf = buf.read(4)
+ cipher_len, = struct.unpack("<i", len_buf)
+ buf.read(cipher_len - 1)
+ buf.write(b'\0')
+
+ buf.seek(0)
+ mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf))
+ contents = mylogin_cnf.read()
+ for word in ('[test]', 'password', 'host', 'port'):
+ assert word in contents
+ assert 'user' not in contents
+
+
+def test_get_mylogin_cnf_path():
+ """Tests that the path for .mylogin.cnf is detected."""
+ original_env = None
+ if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
+ original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
+ is_windows = sys.platform == 'win32'
+
+ login_cnf_path = get_mylogin_cnf_path()
+
+ if original_env is not None:
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+
+ if login_cnf_path is not None:
+ assert login_cnf_path.endswith('.mylogin.cnf')
+
+ if is_windows is True:
+ assert 'MySQL' in login_cnf_path
+ else:
+ home_dir = os.path.expanduser('~')
+ assert login_cnf_path.startswith(home_dir)
+
+
+def test_alternate_get_mylogin_cnf_path():
+ """Tests that the alternate path for .mylogin.cnf is detected."""
+ original_env = None
+ if 'MYSQL_TEST_LOGIN_FILE' in os.environ:
+ original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE')
+
+ _, temp_path = tempfile.mkstemp()
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path
+
+ login_cnf_path = get_mylogin_cnf_path()
+
+ if original_env is not None:
+ os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env
+
+ assert temp_path == login_cnf_path
+
+
+def test_str_to_bool():
+ """Tests that str_to_bool function converts values correctly."""
+
+ assert str_to_bool(False) is False
+ assert str_to_bool(True) is True
+ assert str_to_bool('False') is False
+ assert str_to_bool('True') is True
+ assert str_to_bool('TRUE') is True
+ assert str_to_bool('1') is True
+ assert str_to_bool('0') is False
+ assert str_to_bool('on') is True
+ assert str_to_bool('off') is False
+ assert str_to_bool('off') is False
+
+ with pytest.raises(ValueError):
+ str_to_bool('foo')
+
+ with pytest.raises(TypeError):
+ str_to_bool(None)
+
+
+def test_read_config_file_list_values_default():
+ """Test that reading a config file uses list_values by default."""
+
+ f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ config = read_config_file(f)
+
+ assert config['main']['weather'] == u"cloudy with a chance of meatballs"
+
+
+def test_read_config_file_list_values_off():
+ """Test that you can disable list_values when reading a config file."""
+
+ f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n")
+ config = read_config_file(f, list_values=False)
+
+ assert config['main']['weather'] == u"'cloudy with a chance of meatballs'"
+
+
+def test_strip_quotes_with_matching_quotes():
+ """Test that a string with matching quotes is unquoted."""
+
+ s = "May the force be with you."
+ assert s == strip_matching_quotes('"{}"'.format(s))
+ assert s == strip_matching_quotes("'{}'".format(s))
+
+
+def test_strip_quotes_with_unmatching_quotes():
+ """Test that a string with unmatching quotes is not unquoted."""
+
+ s = "May the force be with you."
+ assert '"' + s == strip_matching_quotes('"{}'.format(s))
+ assert s + "'" == strip_matching_quotes("{}'".format(s))
+
+
+def test_strip_quotes_with_empty_string():
+ """Test that an empty string is handled during unquoting."""
+
+ assert '' == strip_matching_quotes('')
+
+
+def test_strip_quotes_with_none():
+ """Test that None is handled during unquoting."""
+
+ assert None is strip_matching_quotes(None)
+
+
+def test_strip_quotes_with_quotes():
+ """Test that strings with quotes in them are handled during unquoting."""
+
+ s1 = 'Darth Vader said, "Luke, I am your father."'
+ assert s1 == strip_matching_quotes(s1)
+
+ s2 = '"Darth Vader said, "Luke, I am your father.""'
+ assert s2[1:-1] == strip_matching_quotes(s2)
diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py
new file mode 100644
index 0000000..21e389c
--- /dev/null
+++ b/test/test_dbspecial.py
@@ -0,0 +1,42 @@
+from mycli.packages.completion_engine import suggest_type
+from .test_completion_engine import sorted_dicts
+from mycli.packages.special.utils import format_uptime
+
+
+def test_u_suggests_databases():
+ suggestions = suggest_type('\\u ', '\\u ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'database'}])
+
+
+def test_describe_table():
+ suggestions = suggest_type('\\dt', '\\dt ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_list_or_show_create_tables():
+ suggestions = suggest_type('\\dt+', '\\dt+ ')
+ assert sorted_dicts(suggestions) == sorted_dicts([
+ {'type': 'table', 'schema': []},
+ {'type': 'view', 'schema': []},
+ {'type': 'schema'}])
+
+
+def test_format_uptime():
+ seconds = 59
+ assert '59 sec' == format_uptime(seconds)
+
+ seconds = 120
+ assert '2 min 0 sec' == format_uptime(seconds)
+
+ seconds = 54890
+ assert '15 hours 14 min 50 sec' == format_uptime(seconds)
+
+ seconds = 598244
+ assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds)
+
+ seconds = 522600
+ assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds)
diff --git a/test/test_main.py b/test/test_main.py
new file mode 100644
index 0000000..64cba0a
--- /dev/null
+++ b/test/test_main.py
@@ -0,0 +1,548 @@
+import os
+import shutil
+
+import click
+from click.testing import CliRunner
+
+from mycli.main import MyCli, cli, thanks_picker
+from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
+from mycli.sqlexecute import ServerInfo
+from .utils import USER, HOST, PORT, PASSWORD, dbtest, run
+
+from textwrap import dedent
+from collections import namedtuple
+
+from tempfile import NamedTemporaryFile
+from textwrap import dedent
+
+
+test_dir = os.path.abspath(os.path.dirname(__file__))
+project_dir = os.path.dirname(test_dir)
+default_config_file = os.path.join(project_dir, 'test', 'myclirc')
+login_path_file = os.path.join(test_dir, 'mylogin.cnf')
+
+os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file
+CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT,
+ '--password', PASSWORD, '--myclirc', default_config_file,
+ '--defaults-file', default_config_file,
+ 'mycli_test_db']
+
+
+@dbtest
+def test_execute_arg(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql])
+
+ assert result.exit_code == 0
+ assert 'abc' in result.output
+
+ result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql])
+
+ assert result.exit_code == 0
+ assert 'abc' in result.output
+
+ expected = 'a\nabc\n'
+
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_table(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table'])
+ expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n'
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_execute_arg_with_csv(executor):
+ run(executor, 'create table test (a text)')
+ run(executor, 'insert into test values("abc")')
+
+ sql = 'select * from test;'
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv'])
+ expected = '"a"\n"abc"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+
+ sql = (
+ 'select count(*) from test;\n'
+ 'select * from test limit 1;'
+ )
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS, input=sql)
+
+ assert result.exit_code == 0
+ assert 'count(*)\n3\na\nabc\n' in "".join(result.output)
+
+
+@dbtest
+def test_batch_mode_table(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc'), ('def'), ('ghi')''')
+
+ sql = (
+ 'select count(*) from test;\n'
+ 'select * from test limit 1;'
+ )
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql)
+
+ expected = (dedent("""\
+ +----------+
+ | count(*) |
+ +----------+
+ | 3 |
+ +----------+
+ +-----+
+ | a |
+ +-----+
+ | abc |
+ +-----+"""))
+
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_batch_mode_csv(executor):
+ run(executor, '''create table test(a text, b text)''')
+ run(executor,
+ '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''')
+
+ sql = 'select * from test;'
+
+ runner = CliRunner()
+ result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql)
+
+ expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n'
+
+ assert result.exit_code == 0
+ assert expected in "".join(result.output)
+
+
+def test_thanks_picker_utf8():
+ name = thanks_picker()
+ assert name and isinstance(name, str)
+
+
+def test_help_strings_end_with_periods():
+ """Make sure click options have help text that end with a period."""
+ for param in cli.params:
+ if isinstance(param, click.core.Option):
+ assert hasattr(param, 'help')
+ assert param.help.endswith('.')
+
+
+def test_command_descriptions_end_with_periods():
+ """Make sure that mycli commands' descriptions end with a period."""
+ MyCli()
+ for _, command in SPECIAL_COMMANDS.items():
+ assert command[3].endswith('.')
+
+
+def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager):
+ global clickoutput
+ clickoutput = ""
+ m = MyCli(myclirc=default_config_file)
+
+ class TestOutput():
+ def get_size(self):
+ size = namedtuple('Size', 'rows columns')
+ size.columns, size.rows = terminal_size
+ return size
+
+ class TestExecute():
+ host = 'test'
+ user = 'test'
+ dbname = 'test'
+ server_info = ServerInfo.from_version_string('unknown')
+ port = 0
+
+ def server_type(self):
+ return ['test']
+
+ class PromptBuffer():
+ output = TestOutput()
+
+ m.prompt_app = PromptBuffer()
+ m.sqlexecute = TestExecute()
+ m.explicit_pager = explicit_pager
+
+ def echo_via_pager(s):
+ assert expect_pager
+ global clickoutput
+ clickoutput += "".join(s)
+
+ def secho(s):
+ assert not expect_pager
+ global clickoutput
+ clickoutput += s + "\n"
+
+ monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager)
+ monkeypatch.setattr(click, 'secho', secho)
+ m.output(testdata)
+ if clickoutput.endswith("\n"):
+ clickoutput = clickoutput[:-1]
+ assert clickoutput == "\n".join(testdata)
+
+
+def test_conditional_pager(monkeypatch):
+ testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(
+ " ")
+ # User didn't set pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=True
+ )
+ # User didn't set pager, output fits screen -> no pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False
+ )
+ # User manually configured pager, output doesn't fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True
+ )
+ # User manually configured pager, output fit screen -> pager
+ output(
+ monkeypatch,
+ terminal_size=(20, 20),
+ testdata=testdata,
+ explicit_pager=True,
+ expect_pager=True
+ )
+
+ SPECIAL_COMMANDS['nopager'].handler()
+ output(
+ monkeypatch,
+ terminal_size=(5, 10),
+ testdata=testdata,
+ explicit_pager=False,
+ expect_pager=False
+ )
+ SPECIAL_COMMANDS['pager'].handler('')
+
+
+def test_reserved_space_is_integer():
+ """Make sure that reserved space is returned as an integer."""
+ def stub_terminal_size():
+ return (5, 5)
+
+ old_func = shutil.get_terminal_size
+
+ shutil.get_terminal_size = stub_terminal_size
+ mycli = MyCli()
+ assert isinstance(mycli.get_reserved_space(), int)
+
+ shutil.get_terminal_size = old_func
+
+
+def test_list_dsn():
+ runner = CliRunner()
+ with NamedTemporaryFile(mode="w") as myclirc:
+ myclirc.write(dedent("""\
+ [alias_dsn]
+ test = mysql://test/test
+ """))
+ myclirc.flush()
+ args = ['--list-dsn', '--myclirc', myclirc.name]
+ result = runner.invoke(cli, args=args)
+ assert result.output == "test\n"
+ result = runner.invoke(cli, args=args + ['--verbose'])
+ assert result.output == "test : mysql://test/test\n"
+
+
+def test_prettify_statement():
+ statement = 'SELECT 1'
+ m = MyCli()
+ pretty_statement = m.handle_prettify_binding(statement)
+ assert pretty_statement == 'SELECT\n 1;'
+
+
+def test_unprettify_statement():
+ statement = 'SELECT\n 1'
+ m = MyCli()
+ unpretty_statement = m.handle_unprettify_binding(statement)
+ assert unpretty_statement == 'SELECT 1;'
+
+
+def test_list_ssh_config():
+ runner = CliRunner()
+ with NamedTemporaryFile(mode="w") as ssh_config:
+ ssh_config.write(dedent("""\
+ Host test
+ Hostname test.example.com
+ User joe
+ Port 22222
+ IdentityFile ~/.ssh/gateway
+ """))
+ ssh_config.flush()
+ args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name]
+ result = runner.invoke(cli, args=args)
+ assert "test\n" in result.output
+ result = runner.invoke(cli, args=args + ['--verbose'])
+ assert "test : test.example.com\n" in result.output
+
+
+def test_dsn(monkeypatch):
+ # Setup classes to mock mycli.main.MyCli
+ class Formatter:
+ format_name = None
+
+ class Logger:
+ def debug(self, *args, **args_dict):
+ pass
+
+ def warning(self, *args, **args_dict):
+ pass
+
+ class MockMyCli:
+ config = {'alias_dsn': {}}
+
+ def __init__(self, **args):
+ self.logger = Logger()
+ self.destructive_warning = False
+ self.formatter = Formatter()
+
+ def connect(self, **args):
+ MockMyCli.connect_args = args
+
+ def run_query(self, query, new_line=True):
+ pass
+
+ import mycli.main
+ monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+ runner = CliRunner()
+
+ # When a user supplies a DSN as database argument to mycli,
+ # use these values.
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]
+ )
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "dsn_user" and \
+ MockMyCli.connect_args["passwd"] == "dsn_passwd" and \
+ MockMyCli.connect_args["host"] == "dsn_host" and \
+ MockMyCli.connect_args["port"] == 1 and \
+ MockMyCli.connect_args["database"] == "dsn_database"
+
+ MockMyCli.connect_args = None
+
+ # When a use supplies a DSN as database argument to mycli,
+ # and used command line arguments, use the command line
+ # arguments.
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database",
+ "--user", "arg_user",
+ "--password", "arg_password",
+ "--host", "arg_host",
+ "--port", "3",
+ "--database", "arg_database",
+ ])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "arg_user" and \
+ MockMyCli.connect_args["passwd"] == "arg_password" and \
+ MockMyCli.connect_args["host"] == "arg_host" and \
+ MockMyCli.connect_args["port"] == 3 and \
+ MockMyCli.connect_args["database"] == "arg_database"
+
+ MockMyCli.config = {
+ 'alias_dsn': {
+ 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
+ }
+ }
+ MockMyCli.connect_args = None
+
+ # When a user uses a DSN from the configuration file (alias_dsn),
+ # use these values.
+ result = runner.invoke(cli, args=['--dsn', 'test'])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "alias_dsn_user" and \
+ MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \
+ MockMyCli.connect_args["host"] == "alias_dsn_host" and \
+ MockMyCli.connect_args["port"] == 4 and \
+ MockMyCli.connect_args["database"] == "alias_dsn_database"
+
+ MockMyCli.config = {
+ 'alias_dsn': {
+ 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database'
+ }
+ }
+ MockMyCli.connect_args = None
+
+ # When a user uses a DSN from the configuration file (alias_dsn)
+ # and used command line arguments, use the command line arguments.
+ result = runner.invoke(cli, args=[
+ '--dsn', 'test', '',
+ "--user", "arg_user",
+ "--password", "arg_password",
+ "--host", "arg_host",
+ "--port", "5",
+ "--database", "arg_database",
+ ])
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "arg_user" and \
+ MockMyCli.connect_args["passwd"] == "arg_password" and \
+ MockMyCli.connect_args["host"] == "arg_host" and \
+ MockMyCli.connect_args["port"] == 5 and \
+ MockMyCli.connect_args["database"] == "arg_database"
+
+ # Use a DSN without password
+ result = runner.invoke(mycli.main.cli, args=[
+ "mysql://dsn_user@dsn_host:6/dsn_database"]
+ )
+ assert result.exit_code == 0, result.output + " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["user"] == "dsn_user" and \
+ MockMyCli.connect_args["passwd"] is None and \
+ MockMyCli.connect_args["host"] == "dsn_host" and \
+ MockMyCli.connect_args["port"] == 6 and \
+ MockMyCli.connect_args["database"] == "dsn_database"
+
+
+def test_ssh_config(monkeypatch):
+ # Setup classes to mock mycli.main.MyCli
+ class Formatter:
+ format_name = None
+
+ class Logger:
+ def debug(self, *args, **args_dict):
+ pass
+
+ def warning(self, *args, **args_dict):
+ pass
+
+ class MockMyCli:
+ config = {'alias_dsn': {}}
+
+ def __init__(self, **args):
+ self.logger = Logger()
+ self.destructive_warning = False
+ self.formatter = Formatter()
+
+ def connect(self, **args):
+ MockMyCli.connect_args = args
+
+ def run_query(self, query, new_line=True):
+ pass
+
+ import mycli.main
+ monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli)
+ runner = CliRunner()
+
+ # Setup temporary configuration
+ with NamedTemporaryFile(mode="w") as ssh_config:
+ ssh_config.write(dedent("""\
+ Host test
+ Hostname test.example.com
+ User joe
+ Port 22222
+ IdentityFile ~/.ssh/gateway
+ """))
+ ssh_config.flush()
+
+ # When a user supplies a ssh config.
+ result = runner.invoke(mycli.main.cli, args=[
+ "--ssh-config-path",
+ ssh_config.name,
+ "--ssh-config-host",
+ "test"
+ ])
+ assert result.exit_code == 0, result.output + \
+ " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["ssh_user"] == "joe" and \
+ MockMyCli.connect_args["ssh_host"] == "test.example.com" and \
+ MockMyCli.connect_args["ssh_port"] == 22222 and \
+ MockMyCli.connect_args["ssh_key_filename"] == os.getenv(
+ "HOME") + "/.ssh/gateway"
+
+ # When a user supplies a ssh config host as argument to mycli,
+ # and used command line arguments, use the command line
+ # arguments.
+ result = runner.invoke(mycli.main.cli, args=[
+ "--ssh-config-path",
+ ssh_config.name,
+ "--ssh-config-host",
+ "test",
+ "--ssh-user", "arg_user",
+ "--ssh-host", "arg_host",
+ "--ssh-port", "3",
+ "--ssh-key-filename", "/path/to/key"
+ ])
+ assert result.exit_code == 0, result.output + \
+ " " + str(result.exception)
+ assert \
+ MockMyCli.connect_args["ssh_user"] == "arg_user" and \
+ MockMyCli.connect_args["ssh_host"] == "arg_host" and \
+ MockMyCli.connect_args["ssh_port"] == 3 and \
+ MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key"
+
+
+@dbtest
+def test_init_command_arg(executor):
+ init_command = "set sql_select_limit=1000"
+ sql = 'show variables like "sql_select_limit";'
+ runner = CliRunner()
+ result = runner.invoke(
+ cli, args=CLI_ARGS + ["--init-command", init_command], input=sql
+ )
+
+ expected = "sql_select_limit\t1000\n"
+ assert result.exit_code == 0
+ assert expected in result.output
+
+
+@dbtest
+def test_init_command_multiple_arg(executor):
+ init_command = 'set sql_select_limit=2000; set max_join_size=20000'
+ sql = (
+ 'show variables like "sql_select_limit";\n'
+ 'show variables like "max_join_size"'
+ )
+ runner = CliRunner()
+ result = runner.invoke(
+ cli, args=CLI_ARGS + ['--init-command', init_command], input=sql
+ )
+
+ expected_sql_select_limit = 'sql_select_limit\t2000\n'
+ expected_max_join_size = 'max_join_size\t20000\n'
+
+ assert result.exit_code == 0
+ assert expected_sql_select_limit in result.output
+ assert expected_max_join_size in result.output
diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py
new file mode 100644
index 0000000..32b2abd
--- /dev/null
+++ b/test/test_naive_completion.py
@@ -0,0 +1,63 @@
+import pytest
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+
+
+@pytest.fixture
+def completer():
+ import mycli.sqlcompleter as sqlcompleter
+ return sqlcompleter.SQLCompleter(smart_completion=False)
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ''
+ position = 0
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list(map(Completion, sorted(completer.all_completions)))
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = 'SEL'
+ position = len('SEL')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([Completion(text='SELECT', start_position=-3)])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = 'SELECT MA'
+ position = len('SELECT MA')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='MASTER', start_position=-2),
+ Completion(text='MAX', start_position=-2)])
+
+
+def test_column_name_completion(completer, complete_event):
+ text = 'SELECT FROM users'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list(map(Completion, sorted(completer.all_completions)))
+
+
+def test_special_name_completion(completer, complete_event):
+ text = '\\'
+ position = len('\\')
+ result = set(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ # Special commands will NOT be suggested during naive completion mode.
+ assert result == set()
diff --git a/test/test_parseutils.py b/test/test_parseutils.py
new file mode 100644
index 0000000..920a08d
--- /dev/null
+++ b/test/test_parseutils.py
@@ -0,0 +1,190 @@
+import pytest
+from mycli.packages.parseutils import (
+ extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause,
+ is_dropping_database)
+
+
+def test_empty_string():
+ tables = extract_tables('')
+ assert tables == []
+
+
+def test_simple_select_single_table():
+ tables = extract_tables('select * from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_select_single_table_schema_qualified():
+ tables = extract_tables('select * from abc.def')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables('select * from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables('select * from abc.def, ghi.jkl')
+ assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables('select a,b from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables('select a,b from abc.def')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables('select a,b from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_simple_select_with_cols_multiple_tables_with_schema():
+ tables = extract_tables('select a,b from abc.def, def.ghi')
+ assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables('select a, from abc')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables('select a, from abc, def')
+ assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
+ assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # assert tables == [(None, 'abc', None)]
+ assert tables == [(None, 'abc', 'abc')]
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_simple_update_table():
+ tables = extract_tables('update abc set id = 1')
+ assert tables == [(None, 'abc', None)]
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables('update abc.def set id = 1')
+ assert tables == [('abc', 'def', None)]
+
+
+def test_join_table():
+ tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
+ assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables(
+ 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
+ assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
+
+
+def test_join_as_table():
+ tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
+ assert tables == [(None, 'my_table', 'm')]
+
+
+def test_query_starts_with():
+ query = 'USE test;'
+ assert query_starts_with(query, ('use', )) is True
+
+ query = 'DROP DATABASE test;'
+ assert query_starts_with(query, ('use', )) is False
+
+
+def test_query_starts_with_comment():
+ query = '# comment\nUSE test;'
+ assert query_starts_with(query, ('use', )) is True
+
+
+def test_queries_start_with():
+ sql = (
+ '# comment\n'
+ 'show databases;'
+ 'use foo;'
+ )
+ assert queries_start_with(sql, ('show', 'select')) is True
+ assert queries_start_with(sql, ('use', 'drop')) is True
+ assert queries_start_with(sql, ('delete', 'update')) is False
+
+
+def test_is_destructive():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'drop database foo;'
+ )
+ assert is_destructive(sql) is True
+
+
+def test_is_destructive_update_with_where_clause():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'UPDATE test SET x = 1 WHERE id = 1;'
+ )
+ assert is_destructive(sql) is False
+
+
+def test_is_destructive_update_without_where_clause():
+ sql = (
+ 'use test;\n'
+ 'show databases;\n'
+ 'UPDATE test SET x = 1;'
+ )
+ assert is_destructive(sql) is True
+
+
+@pytest.mark.parametrize(
+ ('sql', 'has_where_clause'),
+ [
+ ('update test set dummy = 1;', False),
+ ('update test set dummy = 1 where id = 1);', True),
+ ],
+)
+def test_query_has_where_clause(sql, has_where_clause):
+ assert query_has_where_clause(sql) is has_where_clause
+
+
+@pytest.mark.parametrize(
+ ('sql', 'dbname', 'is_dropping'),
+ [
+ ('select bar from foo', 'foo', False),
+ ('drop database "foo";', '`foo`', True),
+ ('drop schema foo', 'foo', True),
+ ('drop schema foo', 'bar', False),
+ ('drop database bar', 'foo', False),
+ ('drop database foo', None, False),
+ ('drop database foo; create database foo', 'foo', False),
+ ('drop database foo; create database bar', 'foo', True),
+ ('select bar from foo; drop database bazz', 'foo', False),
+ ('select bar from foo; drop database bazz', 'bazz', True),
+ ('-- dropping database \n '
+ 'drop -- really dropping \n '
+ 'schema abc -- now it is dropped',
+ 'abc',
+ True)
+ ]
+)
+def test_is_dropping_database(sql, dbname, is_dropping):
+ assert is_dropping_database(sql, dbname) == is_dropping
diff --git a/test/test_plan.wiki b/test/test_plan.wiki
new file mode 100644
index 0000000..43e9083
--- /dev/null
+++ b/test/test_plan.wiki
@@ -0,0 +1,38 @@
+= Gross Checks =
+ * [ ] Check connecting to a local database.
+ * [ ] Check connecting to a remote database.
+ * [ ] Check connecting to a database with a user/password.
+ * [ ] Check connecting to a non-existent database.
+ * [ ] Test changing the database.
+
+ == PGExecute ==
+ * [ ] Test successful execution given a cursor.
+ * [ ] Test unsuccessful execution with a syntax error.
+ * [ ] Test a series of executions with the same cursor without failure.
+ * [ ] Test a series of executions with the same cursor with failure.
+ * [ ] Test passing in a special command.
+
+ == Naive Autocompletion ==
+ * [ ] Input empty string, ask for completions - Everything.
+ * [ ] Input partial prefix, ask for completions - Stars with prefix.
+ * [ ] Input fully autocompleted string, ask for completions - Only full match
+ * [ ] Input non-existent prefix, ask for completions - nothing
+ * [ ] Input lowercase prefix - case insensitive completions
+
+ == Smart Autocompletion ==
+ * [ ] Input empty string and check if only keywords are returned.
+ * [ ] Input SELECT prefix and check if only columns and '*' are returned.
+ * [ ] Input SELECT blah - only keywords are returned.
+ * [ ] Input SELECT * FROM - Table names only
+
+ == PGSpecial ==
+ * [ ] Test \d
+ * [ ] Test \d tablename
+ * [ ] Test \d tablena*
+ * [ ] Test \d non-existent-tablename
+ * [ ] Test \d index
+ * [ ] Test \d sequence
+ * [ ] Test \d view
+
+ == Exceptionals ==
+ * [ ] Test the 'use' command to change db.
diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py
new file mode 100644
index 0000000..2373fac
--- /dev/null
+++ b/test/test_prompt_utils.py
@@ -0,0 +1,11 @@
+import click
+
+from mycli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream('stdin')
+ assert stdin.isatty() is False
+
+ sql = 'drop database foo;'
+ assert confirm_destructive_query(sql) is None
diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..e7d460a
--- /dev/null
+++ b/test/test_smart_completion_public_schema_only.py
@@ -0,0 +1,385 @@
+import pytest
+from unittest.mock import patch
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+import mycli.packages.special.main as special
+
+metadata = {
+ 'users': ['id', 'email', 'first_name', 'last_name'],
+ 'orders': ['id', 'ordered_date', 'status'],
+ 'select': ['id', 'insert', 'ABC'],
+ 'réveillé': ['id', 'insert', 'ABC']
+}
+
+
+@pytest.fixture
+def completer():
+
+ import mycli.sqlcompleter as sqlcompleter
+ comp = sqlcompleter.SQLCompleter(smart_completion=True)
+
+ tables, columns = [], []
+
+ for table, cols in metadata.items():
+ tables.append((table,))
+ columns.extend([(table, col) for col in cols])
+
+ comp.set_dbname('test')
+ comp.extend_schemata('test')
+ comp.extend_relations(tables, kind='tables')
+ comp.extend_columns(columns, kind='tables')
+ comp.extend_special_commands(special.COMMANDS)
+
+ return comp
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+ return Mock()
+
+
+def test_special_name_completion(completer, complete_event):
+ text = '\\d'
+ position = len('\\d')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert result == [Completion(text='\\dt', start_position=-2)]
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ''
+ position = 0
+ result = list(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert list(map(Completion, sorted(completer.keywords) +
+ sorted(completer.special_commands))) == result
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = 'SEL'
+ position = len('SEL')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert list(result) == list([Completion(text='SELECT', start_position=-3)])
+
+
+def test_table_completion(completer, complete_event):
+ text = 'SELECT * FROM '
+ position = len(text)
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event)
+ assert list(result) == list([
+ Completion(text='`réveillé`', start_position=0),
+ Completion(text='`select`', start_position=0),
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0),
+ ])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = 'SELECT MA'
+ position = len('SELECT MA')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event)
+ assert list(result) == list([Completion(text='MAX', start_position=-2),
+ Completion(text='MASTER', start_position=-2),
+ ])
+
+
+def test_suggested_column_names(completer, complete_event):
+ """Suggest column and function names when selecting from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT from users'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0),
+ ] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='users', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def test_suggested_column_names_in_function(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT MAX( from users'
+ position = len('SELECT MAX(')
+ result = completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event)
+ assert list(result) == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_column_names_with_table_dot(completer, complete_event):
+ """Suggest column names on table name and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT users. from users'
+ position = len('SELECT users.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT u. from users u'
+ position = len('SELECT u.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_multiple_column_names(completer, complete_event):
+ """Suggest column and function names when selecting multiple columns from
+ table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT id, from users u'
+ position = len('SELECT id, ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='u', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def test_suggested_multiple_column_names_with_alias(completer, complete_event):
+ """Suggest column names on table alias and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT u.id, u. from users u'
+ position = len('SELECT u.id, u.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_multiple_column_names_with_dot(completer, complete_event):
+ """Suggest column names on table names and dot when selecting multiple
+ columns from table.
+
+ :param completer:
+ :param complete_event:
+ :return:
+
+ """
+ text = 'SELECT users.id, users. from users u'
+ position = len('SELECT users.id, users.')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='email', start_position=0),
+ Completion(text='first_name', start_position=0),
+ Completion(text='id', start_position=0),
+ Completion(text='last_name', start_position=0)])
+
+
+def test_suggested_aliases_after_on(completer, complete_event):
+ text = 'SELECT u.name, o.id FROM users u JOIN orders o ON '
+ position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='o', start_position=0),
+ Completion(text='u', start_position=0)])
+
+
+def test_suggested_aliases_after_on_right_side(completer, complete_event):
+ text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = '
+ position = len(
+ 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='o', start_position=0),
+ Completion(text='u', start_position=0)])
+
+
+def test_suggested_tables_after_on(completer, complete_event):
+ text = 'SELECT users.name, orders.id FROM users JOIN orders ON '
+ position = len('SELECT users.name, orders.id FROM users JOIN orders ON ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0)])
+
+
+def test_suggested_tables_after_on_right_side(completer, complete_event):
+ text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = '
+ position = len(
+ 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0)])
+
+
+def test_table_names_after_from(completer, complete_event):
+ text = 'SELECT * FROM '
+ position = len('SELECT * FROM ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='`réveillé`', start_position=0),
+ Completion(text='`select`', start_position=0),
+ Completion(text='orders', start_position=0),
+ Completion(text='users', start_position=0),
+ ])
+
+
+def test_auto_escaped_col_names(completer, complete_event):
+ text = 'SELECT from `select`'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == [
+ Completion(text='*', start_position=0),
+ Completion(text='`ABC`', start_position=0),
+ Completion(text='`insert`', start_position=0),
+ Completion(text='id', start_position=0),
+ ] + \
+ list(map(Completion, completer.functions)) + \
+ [Completion(text='`select`', start_position=0)] + \
+ list(map(Completion, completer.keywords))
+
+
+def test_un_escaped_table_names(completer, complete_event):
+ text = 'SELECT from réveillé'
+ position = len('SELECT ')
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ assert result == list([
+ Completion(text='*', start_position=0),
+ Completion(text='`ABC`', start_position=0),
+ Completion(text='`insert`', start_position=0),
+ Completion(text='id', start_position=0),
+ ] +
+ list(map(Completion, completer.functions)) +
+ [Completion(text='réveillé', start_position=0)] +
+ list(map(Completion, completer.keywords)))
+
+
+def dummy_list_path(dir_name):
+ dirs = {
+ '/': [
+ 'dir1',
+ 'file1.sql',
+ 'file2.sql',
+ ],
+ '/dir1': [
+ 'subdir1',
+ 'subfile1.sql',
+ 'subfile2.sql',
+ ],
+ '/dir1/subdir1': [
+ 'lastfile.sql',
+ ],
+ }
+ return dirs.get(dir_name, [])
+
+
+@patch('mycli.packages.filepaths.list_path', new=dummy_list_path)
+@pytest.mark.parametrize('text,expected', [
+ # ('source ', [('~', 0),
+ # ('/', 0),
+ # ('.', 0),
+ # ('..', 0)]),
+ ('source /', [('dir1', 0),
+ ('file1.sql', 0),
+ ('file2.sql', 0)]),
+ ('source /dir1/', [('subdir1', 0),
+ ('subfile1.sql', 0),
+ ('subfile2.sql', 0)]),
+ ('source /dir1/subdir1/', [('lastfile.sql', 0)]),
+])
+def test_file_name_completion(completer, complete_event, text, expected):
+ position = len(text)
+ result = list(completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event))
+ expected = list((Completion(txt, pos) for txt, pos in expected))
+ assert result == expected
diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py
new file mode 100644
index 0000000..8b6be33
--- /dev/null
+++ b/test/test_special_iocommands.py
@@ -0,0 +1,287 @@
+import os
+import stat
+import tempfile
+from time import time
+from unittest.mock import patch
+
+import pytest
+from pymysql import ProgrammingError
+
+import mycli.packages.special
+
+from .utils import dbtest, db_connection, send_ctrl_c
+
+
+def test_set_get_pager():
+ mycli.packages.special.set_pager_enabled(True)
+ assert mycli.packages.special.is_pager_enabled()
+ mycli.packages.special.set_pager_enabled(False)
+ assert not mycli.packages.special.is_pager_enabled()
+ mycli.packages.special.set_pager('less')
+ assert os.environ['PAGER'] == "less"
+ mycli.packages.special.set_pager(False)
+ assert os.environ['PAGER'] == "less"
+ del os.environ['PAGER']
+ mycli.packages.special.set_pager(False)
+ mycli.packages.special.disable_pager()
+ assert not mycli.packages.special.is_pager_enabled()
+
+
+def test_set_get_timing():
+ mycli.packages.special.set_timing_enabled(True)
+ assert mycli.packages.special.is_timing_enabled()
+ mycli.packages.special.set_timing_enabled(False)
+ assert not mycli.packages.special.is_timing_enabled()
+
+
+def test_set_get_expanded_output():
+ mycli.packages.special.set_expanded_output(True)
+ assert mycli.packages.special.is_expanded_output()
+ mycli.packages.special.set_expanded_output(False)
+ assert not mycli.packages.special.is_expanded_output()
+
+
+def test_editor_command():
+ assert mycli.packages.special.editor_command(r'hello\e')
+ assert mycli.packages.special.editor_command(r'\ehello')
+ assert not mycli.packages.special.editor_command(r'hello')
+
+ assert mycli.packages.special.get_filename(r'\e filename') == "filename"
+
+ os.environ['EDITOR'] = 'true'
+ os.environ['VISUAL'] = 'true'
+ mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1"
+
+
+def test_tee_command():
+ mycli.packages.special.write_tee(u"hello world") # write without file set
+ with tempfile.NamedTemporaryFile() as f:
+ mycli.packages.special.execute(None, u"tee " + f.name)
+ mycli.packages.special.write_tee(u"hello world")
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"tee -o " + f.name)
+ mycli.packages.special.write_tee(u"hello world")
+ f.seek(0)
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"notee")
+ mycli.packages.special.write_tee(u"hello world")
+ f.seek(0)
+ assert f.read() == b"hello world\n"
+
+
+def test_tee_command_error():
+ with pytest.raises(TypeError):
+ mycli.packages.special.execute(None, 'tee')
+
+ with pytest.raises(OSError):
+ with tempfile.NamedTemporaryFile() as f:
+ os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH)
+ mycli.packages.special.execute(None, 'tee {}'.format(f.name))
+
+
+@dbtest
+def test_favorite_query():
+ with db_connection().cursor() as cur:
+ query = u'select "✔"'
+ mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query))
+ assert next(mycli.packages.special.execute(
+ cur, u'\\f check'))[0] == "> " + query
+
+
+def test_once_command():
+ with pytest.raises(TypeError):
+ mycli.packages.special.execute(None, u"\\once")
+
+ with pytest.raises(OSError):
+ mycli.packages.special.execute(None, u"\\once /proc/access-denied")
+
+ mycli.packages.special.write_once(u"hello world") # write without file set
+ with tempfile.NamedTemporaryFile() as f:
+ mycli.packages.special.execute(None, u"\\once " + f.name)
+ mycli.packages.special.write_once(u"hello world")
+ assert f.read() == b"hello world\n"
+
+ mycli.packages.special.execute(None, u"\\once -o " + f.name)
+ mycli.packages.special.write_once(u"hello world line 1")
+ mycli.packages.special.write_once(u"hello world line 2")
+ f.seek(0)
+ assert f.read() == b"hello world line 1\nhello world line 2\n"
+
+
+def test_pipe_once_command():
+ with pytest.raises(IOError):
+ mycli.packages.special.execute(None, u"\\pipe_once")
+
+ with pytest.raises(OSError):
+ mycli.packages.special.execute(
+ None, u"\\pipe_once /proc/access-denied")
+
+ mycli.packages.special.execute(None, u"\\pipe_once wc")
+ mycli.packages.special.write_once(u"hello world")
+ mycli.packages.special.unset_pipe_once_if_written()
+ # how to assert on wc output?
+
+
+def test_parseargfile():
+ """Test that parseargfile expands the user directory."""
+ expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
+ 'mode': 'a'}
+ assert expected == mycli.packages.special.iocommands.parseargfile(
+ '~/filename')
+
+ expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'),
+ 'mode': 'w'}
+ assert expected == mycli.packages.special.iocommands.parseargfile(
+ '-o ~/filename')
+
+
+def test_parseargfile_no_file():
+ """Test that parseargfile raises a TypeError if there is no filename."""
+ with pytest.raises(TypeError):
+ mycli.packages.special.iocommands.parseargfile('')
+
+ with pytest.raises(TypeError):
+ mycli.packages.special.iocommands.parseargfile('-o ')
+
+
+@dbtest
+def test_watch_query_iteration():
+ """Test that a single iteration of the result of `watch_query` executes
+ the desired query and returns the given results."""
+ expected_value = "1"
+ query = "SELECT {0!s}".format(expected_value)
+ expected_title = '> {0!s}'.format(query)
+ with db_connection().cursor() as cur:
+ result = next(mycli.packages.special.iocommands.watch_query(
+ arg=query, cur=cur
+ ))
+ assert result[0] == expected_title
+ assert result[2][0] == expected_value
+
+
+@dbtest
+def test_watch_query_full():
+ """Test that `watch_query`:
+
+ * Returns the expected results.
+ * Executes the defined times inside the given interval, in this case with
+ a 0.3 seconds wait, it should execute 4 times inside a 1 seconds
+ interval.
+ * Stops at Ctrl-C
+
+ """
+ watch_seconds = 0.3
+ wait_interval = 1
+ expected_value = "1"
+ query = "SELECT {0!s}".format(expected_value)
+ expected_title = '> {0!s}'.format(query)
+ expected_results = 4
+ ctrl_c_process = send_ctrl_c(wait_interval)
+ with db_connection().cursor() as cur:
+ results = list(
+ result for result in mycli.packages.special.iocommands.watch_query(
+ arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur
+ )
+ )
+ ctrl_c_process.join(1)
+ assert len(results) == expected_results
+ for result in results:
+ assert result[0] == expected_title
+ assert result[2][0] == expected_value
+
+
+@dbtest
+@patch('click.clear')
+def test_watch_query_clear(clear_mock):
+ """Test that the screen is cleared with the -c flag of `watch` command
+ before execute the query."""
+ with db_connection().cursor() as cur:
+ watch_gen = mycli.packages.special.iocommands.watch_query(
+ arg='0.1 -c select 1;', cur=cur
+ )
+ assert not clear_mock.called
+ next(watch_gen)
+ assert clear_mock.called
+ clear_mock.reset_mock()
+ next(watch_gen)
+ assert clear_mock.called
+ clear_mock.reset_mock()
+
+
+@dbtest
+def test_watch_query_bad_arguments():
+ """Test different incorrect combinations of arguments for `watch`
+ command."""
+ watch_query = mycli.packages.special.iocommands.watch_query
+ with db_connection().cursor() as cur:
+ with pytest.raises(ProgrammingError):
+ next(watch_query('a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('-a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('1 -a select 1;', cur=cur))
+ with pytest.raises(ProgrammingError):
+ next(watch_query('-c -a select 1;', cur=cur))
+
+
+@dbtest
+@patch('click.clear')
+def test_watch_query_interval_clear(clear_mock):
+ """Test `watch` command with interval and clear flag."""
+ def test_asserts(gen):
+ clear_mock.reset_mock()
+ start = time()
+ next(gen)
+ assert clear_mock.called
+ next(gen)
+ exec_time = time() - start
+ assert exec_time > seconds and exec_time < (seconds + seconds)
+
+ seconds = 1.0
+ watch_query = mycli.packages.special.iocommands.watch_query
+ with db_connection().cursor() as cur:
+ test_asserts(watch_query('{0!s} -c select 1;'.format(seconds),
+ cur=cur))
+ test_asserts(watch_query('-c {0!s} select 1;'.format(seconds),
+ cur=cur))
+
+
+def test_split_sql_by_delimiter():
+ for delimiter_str in (';', '$', '😀'):
+ mycli.packages.special.set_delimiter(delimiter_str)
+ sql_input = "select 1{} select \ufffc2".format(delimiter_str)
+ queries = (
+ "select 1",
+ "select \ufffc2"
+ )
+ for query, parsed_query in zip(
+ queries, mycli.packages.special.split_queries(sql_input)):
+ assert(query == parsed_query)
+
+
+def test_switch_delimiter_within_query():
+ mycli.packages.special.set_delimiter(';')
+ sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$"
+ queries = (
+ "select 1",
+ "delimiter $$ select 2 $$ select 3 $$",
+ "select 2",
+ "select 3"
+ )
+ for query, parsed_query in zip(
+ queries,
+ mycli.packages.special.split_queries(sql_input)):
+ assert(query == parsed_query)
+
+
+def test_set_delimiter():
+
+ for delim in ('foo', 'bar'):
+ mycli.packages.special.set_delimiter(delim)
+ assert mycli.packages.special.get_current_delimiter() == delim
+
+
+def teardown_function():
+ mycli.packages.special.set_delimiter(';')
diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py
new file mode 100644
index 0000000..38ca5ef
--- /dev/null
+++ b/test/test_sqlexecute.py
@@ -0,0 +1,295 @@
+import os
+
+import pytest
+import pymysql
+
+from mycli.sqlexecute import ServerInfo, ServerSpecies
+from .utils import run, dbtest, set_expanded_output, is_expanded_output
+
+
+def assert_result_equal(result, title=None, rows=None, headers=None,
+ status=None, auto_status=True, assert_contains=False):
+ """Assert that an sqlexecute.run() result matches the expected values."""
+ if status is None and auto_status and rows:
+ status = '{} row{} in set'.format(
+ len(rows), 's' if len(rows) > 1 else '')
+ fields = {'title': title, 'rows': rows, 'headers': headers,
+ 'status': status}
+
+ if assert_contains:
+ # Do a loose match on the results using the *in* operator.
+ for key, field in fields.items():
+ if field:
+ assert field in result[0][key]
+ else:
+ # Do an exact match on the fields.
+ assert result == [fields]
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc')''')
+ results = run(executor, '''select * from test''')
+
+ assert_result_equal(results, headers=['a'], rows=[('abc',)])
+
+
+@dbtest
+def test_bools(executor):
+ run(executor, '''create table test(a boolean)''')
+ run(executor, '''insert into test values(True)''')
+ results = run(executor, '''select * from test''')
+
+ assert_result_equal(results, headers=['a'], rows=[(1,)])
+
+
+@dbtest
+def test_binary(executor):
+ run(executor, '''create table bt(geom linestring NOT NULL)''')
+ run(executor, "INSERT INTO bt VALUES "
+ "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));")
+ results = run(executor, '''select * from bt''')
+
+ geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n'
+ b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9'
+ b'\xac\xdeC@')
+
+ assert_result_equal(results, headers=['geom'], rows=[(geom,)])
+
+
+@dbtest
+def test_table_and_columns_query(executor):
+ run(executor, "create table a(x text, y text)")
+ run(executor, "create table b(z text)")
+
+ assert set(executor.tables()) == set([('a',), ('b',)])
+ assert set(executor.table_columns()) == set(
+ [('a', 'x'), ('a', 'y'), ('b', 'z')])
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert 'mycli_test_db' in databases
+
+
+@dbtest
+def test_invalid_syntax(executor):
+ with pytest.raises(pymysql.ProgrammingError) as excinfo:
+ run(executor, 'invalid syntax!')
+ assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+
+
+@dbtest
+def test_invalid_column_name(executor):
+ with pytest.raises(pymysql.err.OperationalError) as excinfo:
+ run(executor, 'select invalid command')
+ assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value)
+
+
+@dbtest
+def test_unicode_support_in_output(executor):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, u"insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ results = run(executor, u"select * from unicodechars")
+ assert_result_equal(results, headers=['t'], rows=[(u'é',)])
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ results = run(executor, "select 'foo'; select 'bar'")
+
+ expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)],
+ 'status': '1 row in set'},
+ {'title': None, 'headers': ['bar'], 'rows': [('bar',)],
+ 'status': '1 row in set'}]
+ assert expected == results
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor):
+ with pytest.raises(pymysql.ProgrammingError) as excinfo:
+ run(executor, "select 'foo'; invalid syntax")
+ assert 'You have an error in your SQL syntax;' in str(excinfo.value)
+
+
+@dbtest
+def test_favorite_query(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor, "\\fs test-a select * from test where a like 'a%'")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-a")
+ assert_result_equal(results,
+ title="> select * from test where a like 'a%'",
+ headers=['a'], rows=[('abc',)], auto_status=False)
+
+ results = run(executor, "\\fd test-a")
+ assert_result_equal(results, status='test-a: Deleted')
+
+
+@dbtest
+def test_favorite_query_multiple_statement(executor):
+ set_expanded_output(False)
+ run(executor, "create table test(a text)")
+ run(executor, "insert into test values('abc')")
+ run(executor, "insert into test values('def')")
+
+ results = run(executor,
+ "\\fs test-ad select * from test where a like 'a%'; "
+ "select * from test where a like 'd%'")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-ad")
+ expected = [{'title': "> select * from test where a like 'a%'",
+ 'headers': ['a'], 'rows': [('abc',)], 'status': None},
+ {'title': "> select * from test where a like 'd%'",
+ 'headers': ['a'], 'rows': [('def',)], 'status': None}]
+ assert expected == results
+
+ results = run(executor, "\\fd test-ad")
+ assert_result_equal(results, status='test-ad: Deleted')
+
+
+@dbtest
+def test_favorite_query_expanded_output(executor):
+ set_expanded_output(False)
+ run(executor, '''create table test(a text)''')
+ run(executor, '''insert into test values('abc')''')
+
+ results = run(executor, "\\fs test-ae select * from test")
+ assert_result_equal(results, status='Saved.')
+
+ results = run(executor, "\\f test-ae \\G")
+ assert is_expanded_output() is True
+ assert_result_equal(results, title='> select * from test',
+ headers=['a'], rows=[('abc',)], auto_status=False)
+
+ set_expanded_output(False)
+
+ results = run(executor, "\\fd test-ae")
+ assert_result_equal(results, status='test-ae: Deleted')
+
+
+@dbtest
+def test_special_command(executor):
+ results = run(executor, '\\?')
+ assert_result_equal(results, rows=('quit', '\\q', 'Quit.'),
+ headers='Command', assert_contains=True,
+ auto_status=False)
+
+
+@dbtest
+def test_cd_command_without_a_folder_name(executor):
+ results = run(executor, 'system cd')
+ assert_result_equal(results, status='No folder name was provided.')
+
+
+@dbtest
+def test_system_command_not_found(executor):
+ results = run(executor, 'system xyz')
+ assert_result_equal(results, status='OSError: No such file or directory',
+ assert_contains=True)
+
+
+@dbtest
+def test_system_command_output(executor):
+ test_dir = os.path.abspath(os.path.dirname(__file__))
+ test_file_path = os.path.join(test_dir, 'test.txt')
+ results = run(executor, 'system cat {0}'.format(test_file_path))
+ assert_result_equal(results, status='mycli rocks!\n')
+
+
+@dbtest
+def test_cd_command_current_dir(executor):
+ test_path = os.path.abspath(os.path.dirname(__file__))
+ run(executor, 'system cd {0}'.format(test_path))
+ assert os.getcwd() == test_path
+
+
+@dbtest
+def test_unicode_support(executor):
+ results = run(executor, u"SELECT '日本語' AS japanese;")
+ assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)])
+
+
+@dbtest
+def test_timestamp_null(executor):
+ run(executor, '''create table ts_null(a timestamp null)''')
+ run(executor, '''insert into ts_null values(null)''')
+ results = run(executor, '''select * from ts_null''')
+ assert_result_equal(results, headers=['a'],
+ rows=[(None,)])
+
+
+@dbtest
+def test_datetime_null(executor):
+ run(executor, '''create table dt_null(a datetime null)''')
+ run(executor, '''insert into dt_null values(null)''')
+ results = run(executor, '''select * from dt_null''')
+ assert_result_equal(results, headers=['a'],
+ rows=[(None,)])
+
+
+@dbtest
+def test_date_null(executor):
+ run(executor, '''create table date_null(a date null)''')
+ run(executor, '''insert into date_null values(null)''')
+ results = run(executor, '''select * from date_null''')
+ assert_result_equal(results, headers=['a'], rows=[(None,)])
+
+
+@dbtest
+def test_time_null(executor):
+ run(executor, '''create table time_null(a time null)''')
+ run(executor, '''insert into time_null values(null)''')
+ results = run(executor, '''select * from time_null''')
+ assert_result_equal(results, headers=['a'], rows=[(None,)])
+
+
+@dbtest
+def test_multiple_results(executor):
+ query = '''CREATE PROCEDURE dmtest()
+ BEGIN
+ SELECT 1;
+ SELECT 2;
+ END'''
+ executor.conn.cursor().execute(query)
+
+ results = run(executor, 'call dmtest;')
+ expected = [
+ {'title': None, 'rows': [(1,)], 'headers': ['1'],
+ 'status': '1 row in set'},
+ {'title': None, 'rows': [(2,)], 'headers': ['2'],
+ 'status': '1 row in set'}
+ ]
+ assert results == expected
+
+
+@pytest.mark.parametrize(
+ 'version_string, species, parsed_version_string, version',
+ (
+ ('5.7.25-TiDB-v6.1.0','TiDB', '5.7.25', 50725),
+ ('5.7.32-35', 'Percona', '5.7.32', 50732),
+ ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732),
+ ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508),
+ ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016),
+ ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105),
+ ('unexpected version string', None, '', 0),
+ ('', None, '', 0),
+ (None, None, '', 0),
+ )
+)
+def test_version_parsing(version_string, species, parsed_version_string, version):
+ server_info = ServerInfo.from_version_string(version_string)
+ assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown
+ assert server_info.version_str == parsed_version_string
+ assert server_info.version == version
diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py
new file mode 100644
index 0000000..c20c7de
--- /dev/null
+++ b/test/test_tabular_output.py
@@ -0,0 +1,118 @@
+"""Test the sql output adapter."""
+
+from textwrap import dedent
+
+from mycli.packages.tabular_output import sql_format
+from cli_helpers.tabular_output import TabularOutputFormatter
+
+from .utils import USER, PASSWORD, HOST, PORT, dbtest
+
+import pytest
+from mycli.main import MyCli
+
+from pymysql.constants import FIELD_TYPE
+
+
+@pytest.fixture
+def mycli():
+ cli = MyCli()
+ cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None)
+ return cli
+
+
+@dbtest
+def test_sql_output(mycli):
+ """Test the sql output adapter."""
+ headers = ['letters', 'number', 'optional', 'float', 'binary']
+
+ class FakeCursor(object):
+ def __init__(self):
+ self.data = [
+ ('abc', 1, None, 10.0, b'\xAA'),
+ ('d', 456, '1', 0.5, b'\xAA\xBB')
+ ]
+ self.description = [
+ (None, FIELD_TYPE.VARCHAR),
+ (None, FIELD_TYPE.LONG),
+ (None, FIELD_TYPE.LONG),
+ (None, FIELD_TYPE.FLOAT),
+ (None, FIELD_TYPE.BLOB)
+ ]
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.data:
+ return self.data.pop(0)
+ else:
+ raise StopIteration()
+
+ def description(self):
+ return self.description
+
+ # Test sql-update output format
+ assert list(mycli.change_table_format("sql-update")) == \
+ [(None, None, None, 'Changed table format to sql-update')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ actual = "\n".join(output)
+ assert actual == dedent('''\
+ UPDATE `DUAL` SET
+ `number` = 1
+ , `optional` = NULL
+ , `float` = 10.0e0
+ , `binary` = X'aa'
+ WHERE `letters` = 'abc';
+ UPDATE `DUAL` SET
+ `number` = 456
+ , `optional` = '1'
+ , `float` = 0.5e0
+ , `binary` = X'aabb'
+ WHERE `letters` = 'd';''')
+ # Test sql-update-2 output format
+ assert list(mycli.change_table_format("sql-update-2")) == \
+ [(None, None, None, 'Changed table format to sql-update-2')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ UPDATE `DUAL` SET
+ `optional` = NULL
+ , `float` = 10.0e0
+ , `binary` = X'aa'
+ WHERE `letters` = 'abc' AND `number` = 1;
+ UPDATE `DUAL` SET
+ `optional` = '1'
+ , `float` = 0.5e0
+ , `binary` = X'aabb'
+ WHERE `letters` = 'd' AND `number` = 456;''')
+ # Test sql-insert output format (without table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = ""
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
+ # Test sql-insert output format (with table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = "SELECT * FROM `table`"
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
+ # Test sql-insert output format (with database + table name)
+ assert list(mycli.change_table_format("sql-insert")) == \
+ [(None, None, None, 'Changed table format to sql-insert')]
+ mycli.formatter.query = "SELECT * FROM `database`.`table`"
+ output = mycli.format_output(None, FakeCursor(), headers)
+ assert "\n".join(output) == dedent('''\
+ INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES
+ ('abc', 1, NULL, 10.0e0, X'aa')
+ , ('d', 456, '1', 0.5e0, X'aabb')
+ ;''')
diff --git a/test/utils.py b/test/utils.py
new file mode 100644
index 0000000..ab12248
--- /dev/null
+++ b/test/utils.py
@@ -0,0 +1,94 @@
+import os
+import time
+import signal
+import platform
+import multiprocessing
+
+import pymysql
+import pytest
+
+from mycli.main import special
+
+PASSWORD = os.getenv('PYTEST_PASSWORD')
+USER = os.getenv('PYTEST_USER', 'root')
+HOST = os.getenv('PYTEST_HOST', 'localhost')
+PORT = int(os.getenv('PYTEST_PORT', 3306))
+CHARSET = os.getenv('PYTEST_CHARSET', 'utf8')
+SSH_USER = os.getenv('PYTEST_SSH_USER', None)
+SSH_HOST = os.getenv('PYTEST_SSH_HOST', None)
+SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22)
+
+
+def db_connection(dbname=None):
+ conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname,
+ password=PASSWORD, charset=CHARSET,
+ local_infile=False)
+ conn.autocommit = True
+ return conn
+
+
+try:
+ db_connection()
+ CAN_CONNECT_TO_DB = True
+except:
+ CAN_CONNECT_TO_DB = False
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a mysql instance at localhost accessible by user 'root'")
+
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''')
+ cur.execute('''CREATE DATABASE mycli_test_db''')
+ except:
+ pass
+
+
+def run(executor, sql, rows_as_list=True):
+ """Return string output for the sql to be run."""
+ result = []
+
+ for title, rows, headers, status in executor.run(sql):
+ rows = list(rows) if (rows_as_list and rows) else rows
+ result.append({'title': title, 'rows': rows, 'headers': headers,
+ 'status': status})
+
+ return result
+
+
+def set_expanded_output(is_expanded):
+ """Pass-through for the tests."""
+ return special.set_expanded_output(is_expanded)
+
+
+def is_expanded_output():
+ """Pass-through for the tests."""
+ return special.is_expanded_output()
+
+
+def send_ctrl_c_to_pid(pid, wait_seconds):
+ """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds`
+ seconds."""
+ time.sleep(wait_seconds)
+ system_name = platform.system()
+ if system_name == "Windows":
+ os.kill(pid, signal.CTRL_C_EVENT)
+ else:
+ os.kill(pid, signal.SIGINT)
+
+
+def send_ctrl_c(wait_seconds):
+ """Create a process that sends a Ctrl-C like signal to the current process
+ after `wait_seconds` seconds.
+
+ Returns the `multiprocessing.Process` created.
+
+ """
+ ctrl_c_process = multiprocessing.Process(
+ target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)
+ )
+ ctrl_c_process.start()
+ return ctrl_c_process
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..612e8b7
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,15 @@
+[tox]
+envlist = py36, py37, py38
+
+[testenv]
+deps = pytest
+ mock
+ pexpect
+ behave
+ coverage
+commands = python setup.py test
+passenv = PYTEST_HOST
+ PYTEST_USER
+ PYTEST_PASSWORD
+ PYTEST_PORT
+ PYTEST_CHARSET