diff options
Diffstat (limited to '')
93 files changed, 10386 insertions, 0 deletions
diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..8d3149f --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +parallel = True +source = mycli diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/.git-blame-ignore-revs diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..8d498ab --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,9 @@ +## Description +<!--- Describe your changes in detail. --> + + + +## Checklist +<!--- We appreciate your help and want to give you credit. Please take a moment to put an `x` in the boxes below as you complete them. --> +- [ ] I've added this contribution to the `changelog.md`. +- [ ] I've added my name to the `AUTHORS` file (or it's already there). diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b13429e --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +.idea/ +.vscode/ +/build +/dist +/mycli.egg-info +/src +/test/behave.ini + +.vagrant +*.pyc +*.deb +.cache/ +.coverage +.coverage.* diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..0afb5cc --- /dev/null +++ b/.travis.yml @@ -0,0 +1,33 @@ +language: python +python: + - "3.6" + - "3.7" + - "3.8" + +matrix: + include: + - python: 3.7 + dist: xenial + sudo: true + +install: + - pip install -r requirements-dev.txt + - pip install -e . + - sudo rm -f /etc/mysql/conf.d/performance-schema.cnf + - sudo service mysql restart + +script: + - ./setup.py test --pytest-args="--cov-report= --cov=mycli" + - coverage combine + - coverage report + - ./setup.py lint --branch=$TRAVIS_BRANCH + +after_success: + - codecov + +notifications: + webhooks: + urls: + - YOUR_WEBHOOK_URL + on_success: change # options: [always|never|change] default: always + on_failure: always # options: [always|never|change] default: always diff --git a/AUTHORS.rst b/AUTHORS.rst new file mode 100644 index 0000000..995327f --- /dev/null +++ b/AUTHORS.rst @@ -0,0 +1,3 @@ +Check out our `AUTHORS`_. + +.. _AUTHORS: mycli/AUTHORS diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..124b19a --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,137 @@ +# Development Guide + +This is a guide for developers who would like to contribute to this project. + +If you're interested in contributing to mycli, thank you. We'd love your help! +You'll always get credit for your work. + +## GitHub Workflow + +1. [Fork the repository](https://github.com/dbcli/mycli) on GitHub. + +2. Clone your fork locally: + ```bash + $ git clone <url-for-your-fork> + ``` + +3. Add the official repository (`upstream`) as a remote repository: + ```bash + $ git remote add upstream git@github.com:dbcli/mycli.git + ``` + +4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs) + for development: + + ```bash + $ cd mycli + $ pip install virtualenv + $ virtualenv mycli_dev + ``` + + We've just created a virtual environment that we'll use to install all the dependencies + and tools we need to work on mycli. Whenever you want to work on mycli, you + need to activate the virtual environment: + + ```bash + $ source mycli_dev/bin/activate + ``` + + When you're done working, you can deactivate the virtual environment: + + ```bash + $ deactivate + ``` + +5. Install the dependencies and development tools: + + ```bash + $ pip install -r requirements-dev.txt + $ pip install --editable . + ``` + +6. Create a branch for your bugfix or feature based off the `master` branch: + + ```bash + $ git checkout -b <name-of-bugfix-or-feature> master + ``` + +7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: + + ```bash + $ git pull upstream master + ``` + +8. When your work is ready for the mycli team to review it, push your branch to your fork: + + ```bash + $ git push origin <name-of-bugfix-or-feature> + ``` + +9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) + on GitHub. + + +## Running the Tests + +While you work on mycli, it's important to run the tests to make sure your code +hasn't broken any existing functionality. To run the tests, just type in: + +```bash +$ ./setup.py test +``` + +Mycli supports Python 2.7 and 3.4+. You can test against multiple versions of +Python by running tox: + +```bash +$ tox +``` + + +### Test Database Credentials + +The tests require a database connection to work. You can tell the tests which +credentials to use by setting the applicable environment variables: + +```bash +$ export PYTEST_HOST=localhost +$ export PYTEST_USER=user +$ export PYTEST_PASSWORD=myclirocks +$ export PYTEST_PORT=3306 +$ export PYTEST_CHARSET=utf8 +``` + +The default values are `localhost`, `root`, no password, `3306`, and `utf8`. +You only need to set the values that differ from the defaults. + + +### CLI Tests + +Some CLI tests expect the program `ex` to be a symbolic link to `vim`. + +In some systems (e.g. Arch Linux) `ex` is a symbolic link to `vi`, which will +change the output and therefore make some tests fail. + +You can check this by running: +```bash +$ readlink -f $(which ex) +``` + + +## Coding Style + +Mycli requires code submissions to adhere to +[PEP 8](https://www.python.org/dev/peps/pep-0008/). +It's easy to check the style of your code, just run: + +```bash +$ ./setup.py lint +``` + +If you see any PEP 8 style issues, you can automatically fix them by running: + +```bash +$ ./setup.py lint --fix +``` + +Be sure to commit and push any PEP 8 fixes. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..7b4904e --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,34 @@ +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +------------------------------------------------------------------------------- + +This program also bundles with it python-tabulate +(https://pypi.python.org/pypi/tabulate) library. This library is licensed under +MIT License. + +------------------------------------------------------------------------------- diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..04f4d9a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +include LICENSE.txt *.md *.rst requirements-dev.txt screenshots/* +include tasks.py .coveragerc tox.ini +recursive-include test *.cnf +recursive-include test *.feature +recursive-include test *.py +recursive-include test *.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..efe804d --- /dev/null +++ b/README.md @@ -0,0 +1,201 @@ +# mycli + +[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli) +[![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli) +[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +A command line client for MySQL that can do auto-completion and syntax highlighting. + +HomePage: [http://mycli.net](http://mycli.net) +Documentation: [http://mycli.net/docs](http://mycli.net/docs) + +![Completion](screenshots/tables.png) +![CompletionGif](screenshots/main.gif) + +Postgres Equivalent: [http://pgcli.com](http://pgcli.com) + +Quick Start +----------- + +If you already know how to install python packages, then you can install it via pip: + +You might need sudo on linux. + +``` +$ pip install -U mycli +``` + +or + +``` +$ brew update && brew install mycli # Only on macOS +``` + +or + +``` +$ sudo apt-get install mycli # Only on debian or ubuntu +``` + +### Usage + + $ mycli --help + Usage: mycli [OPTIONS] [DATABASE] + + A MySQL terminal client with auto-completion and syntax highlighting. + + Examples: + - mycli my_database + - mycli -u my_user -h my_host.com my_database + - mycli mysql://my_user@my_host.com:3306/my_database + + Options: + -h, --host TEXT Host address of the database. + -P, --port INTEGER Port number to use for connection. Honors + $MYSQL_TCP_PORT. + -u, --user TEXT User name to connect to the database. + -S, --socket TEXT The socket file to use for connection. + -p, --password TEXT Password to connect to the database. + --pass TEXT Password to connect to the database. + --ssh-user TEXT User name to connect to ssh server. + --ssh-host TEXT Host name to connect to ssh server. + --ssh-port INTEGER Port to connect to ssh server. + --ssh-password TEXT Password to connect to ssh server. + --ssh-key-filename TEXT Private key filename (identify file) for the + ssh connection. + --ssh-config-path TEXT Path to ssh configuation. + --ssh-config-host TEXT Host for ssh server in ssh configuations (requires paramiko). + --ssl-ca PATH CA file in PEM format. + --ssl-capath TEXT CA directory. + --ssl-cert PATH X509 cert in PEM format. + --ssl-key PATH X509 key in PEM format. + --ssl-cipher TEXT SSL cipher to use. + --ssl-verify-server-cert Verify server's "Common Name" in its cert + against hostname used when connecting. This + option is disabled by default. + -V, --version Output mycli's version. + -v, --verbose Verbose output. + -D, --database TEXT Database to use. + -d, --dsn TEXT Use DSN configured into the [alias_dsn] + section of myclirc file. + --list-dsn list of DSN configured into the [alias_dsn] + section of myclirc file. + --list-ssh-config list ssh configurations in the ssh config (requires paramiko). + -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). + -l, --logfile FILENAME Log every query and its results to a file. + --defaults-group-suffix TEXT Read MySQL config groups with the specified + suffix. + --defaults-file PATH Only read MySQL options from the given file. + --myclirc PATH Location of myclirc file. + --auto-vertical-output Automatically switch to vertical output mode + if the result is wider than the terminal + width. + -t, --table Display batch output in table format. + --csv Display batch output in CSV format. + --warn / --no-warn Warn before running a destructive query. + --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. + --login-path TEXT Read this path from the login file. + -e, --execute TEXT Execute command and quit. + --help Show this message and exit. + +Features +-------- + +`mycli` is written using [prompt_toolkit](https://github.com/jonathanslenders/python-prompt-toolkit/). + +* Auto-completion as you type for SQL keywords as well as tables, views and + columns in the database. +* Syntax highlighting using Pygments. +* Smart-completion (enabled by default) will suggest context-sensitive completion. + - `SELECT * FROM <tab>` will only show table names. + - `SELECT * FROM users WHERE <tab>` will only show column names. +* Support for multiline queries. +* Favorite queries with optional positional parameters. Save a query using + `\fs alias query` and execute it with `\f alias` whenever you need. +* Timing of sql statments and table rendering. +* Config file is automatically created at ``~/.myclirc`` at first launch. +* Log every query and its results to a file (disabled by default). +* Pretty prints tabular data (with colors!) +* Support for SSL connections + +Contributions: +-------------- + +If you're interested in contributing to this project, first of all I would like +to extend my heartfelt gratitude. I've written a small doc to describe how to +get this running in a development setup. + +https://github.com/dbcli/mycli/blob/master/CONTRIBUTING.md + +Please feel free to reach out to me if you need help. + +My email: amjith.r@gmail.com + +Twitter: [@amjithr](http://twitter.com/amjithr) + +## Detailed Install Instructions: + +### Fedora + +Fedora has a package available for mycli, install it using dnf: + +``` +$ sudo dnf install mycli +``` + +### RHEL, Centos + +I haven't built an RPM package for mycli for RHEL or Centos yet. So please use `pip` to install `mycli`. You can install pip on your system using: + +``` +$ sudo yum install python-pip +``` + +Once that is installed, you can install mycli as follows: + +``` +$ sudo pip install mycli +``` + +### Windows + +Follow the instructions on this blogpost: https://www.codewall.co.uk/installing-using-mycli-on-windows/ + +### Cygwin + +1. Make sure the following Cygwin packages are installed: +`python3`, `python3-pip`. +2. Install mycli: `pip3 install mycli` + +### Thanks: + +This project was funded through kickstarter. My thanks to the [backers](http://mycli.net/sponsors) who supported the project. + +A special thanks to [Jonathan Slenders](https://twitter.com/jonathan_s) for +creating [Python Prompt Toolkit](http://github.com/jonathanslenders/python-prompt-toolkit), +which is quite literally the backbone library, that made this app possible. +Jonathan has also provided valuable feedback and support during the development +of this app. + +[Click](http://click.pocoo.org/) is used for command line option parsing +and printing error messages. + +Thanks to [PyMysql](https://github.com/PyMySQL/PyMySQL) for a pure python adapter to MySQL database. + + +### Compatibility + +Mycli is tested on macOS and Linux. + +**Mycli is not tested on Windows**, but the libraries used in this app are Windows-compatible. +This means it should work without any modifications. If you're unable to run it +on Windows, please [file a bug](https://github.com/dbcli/mycli/issues/new). + +### Configuration and Usage + +For more information on using and configuring mycli, [check out our documentation](http://mycli.net/docs). + +Common topics include: +- [Configuring mycli](http://mycli.net/config) +- [Using/Disabling the pager](http://mycli.net/pager) +- [Syntax colors](http://mycli.net/syntax) diff --git a/SPONSORS.rst b/SPONSORS.rst new file mode 100644 index 0000000..173555c --- /dev/null +++ b/SPONSORS.rst @@ -0,0 +1,3 @@ +Check out our `SPONSORS`_. + +.. _SPONSORS: mycli/SPONSORS diff --git a/changelog.md b/changelog.md new file mode 100644 index 0000000..a4fea35 --- /dev/null +++ b/changelog.md @@ -0,0 +1,770 @@ +1.22.2 +====== + +Bug Fixes: +---------- +* Make the `pwd` module optional. + +1.22.1 +====== + +Bug Fixes: +---------- +* Fix the breaking change introduced in PyMySQL 0.10.0. (Thanks: [Amjith]). + +Features: +--------- +* Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file. +* Add an option `--list-ssh-config` to list ssh configurations. +* Add an option `--ssh-config-path` to choose ssh configuration path. + + +1.21.1 +====== + + +Bug Fixes: +---------- + +* Fix broken auto-completion for favorite queries (Thanks: [Amjith]). +* Fix undefined variable exception when running with --no-warn (Thanks: [Georgy Frolov]) + +1.21.0 +====== + +Features: +--------- +* Added DSN alias name as a format specifier to the prompt (Thanks: [Georgy Frolov]). +* Mark `update` without `where`-clause as destructive query (Thanks: [Klaus Wünschel]). +* Added DELIMITER command (Thanks: [Georgy Frolov]) +* Added clearer error message when failing to connect to the default socket. +* Extend main.is_dropping_database check with create after delete statement. +* Search `${XDG_CONFIG_HOME}/mycli/myclirc` after `${HOME}/.myclirc` and before `/etc/myclirc` (Thanks: [Takeshi D. Itoh]) + +Bug Fixes: +---------- + +* Allow \o command more than once per session (Thanks: [Georgy Frolov]) +* Fixed crash when the query dropping the current database starts with a comment (Thanks: [Georgy Frolov]) + +Internal: +--------- +* deprecate python versions 2.7, 3.4, 3.5; support python 3.8 + +1.20.1 +====== + +Bug Fixes: +---------- + +* Fix an error when using login paths with an explicit database name (Thanks: [Thomas Roten]). + +1.20.0 +====== + +Features: +---------- +* Auto find alias dsn when `://` not in `database` (Thanks: [QiaoHou Peng]). +* Mention URL encoding as escaping technique for special characters in connection DSN (Thanks: [Aljosha Papsch]). +* Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: [Amjith Ramanujam]). +* Use a generator to stream the output to the pager (Thanks: [Dick Marinus]). + +Bug Fixes: +---------- + +* Fix the missing completion for special commands (Thanks: [Amjith Ramanujam]). +* Fix favorites queries being loaded/stored only from/in default config file and not --myclirc (Thanks: [Matheus Rosa]) +* Fix automatic vertical output with native syntax style (Thanks: [Thomas Roten]). +* Update `cli_helpers` version, this will remove quotes from batch output like the official client (Thanks: [Dick Marinus]) +* Update `setup.py` to no longer require `sqlparse` to be less than 0.3.0 as that just came out and there are no notable changes. ([VVelox]) +* workaround for ConfigObj parsing strings containing "," as lists (Thanks: [Mike Palandra]) + +Internal: +--------- +* fix unhashable FormattedText from prompt toolkit in unit tests (Thanks: [Dick Marinus]). + +1.19.0 +====== + +Internal: +--------- + +* Add Python 3.7 trove classifier (Thanks: [Thomas Roten]). +* Fix pytest in Fedora mock (Thanks: [Dick Marinus]). +* Require `prompt_toolkit>=2.0.6` (Thanks: [Dick Marinus]). + +Features: +--------- + +* Add Token.Prompt/Continuation (Thanks: [Dick Marinus]). +* Don't reconnect when switching databases using use (Thanks: [Angelo Lupo]). +* Handle MemoryErrors while trying to pipe in large files and exit gracefully with an error (Thanks: [Amjith Ramanujam]) + +Bug Fixes: +---------- + +* Enable Ctrl-Z to suspend the app (Thanks: [Amjith Ramanujam]). + +1.18.2 +====== + +Bug Fixes: +---------- + +* Fixes database reconnecting feature (Thanks: [Yang Zou]). + +Internal: +--------- + +* Update Twine version to 1.12.1 (Thanks: [Thomas Roten]). +* Fix warnings for running tests on Python 3.7 (Thanks: [Dick Marinus]). +* Clean up and add behave logging (Thanks: [Dick Marinus]). + +1.18.1 +====== + +Features: +--------- + +* Add Keywords: TINYINT, SMALLINT, MEDIUMINT, INT, BIGINT (Thanks: [QiaoHou Peng]). + +Internal: +--------- + +* Update prompt toolkit (Thanks: [Jonathan Slenders], [Irina Truong], [Dick Marinus]). + +1.18.0 +====== + +Features: +--------- + +* Display server version in welcome message (Thanks: [Irina Truong]). +* Set `program_name` connection attribute (Thanks: [Dick Marinus]). +* Use `return` to terminate a generator for better Python 3.7 support (Thanks: [Zhongyang Guan]). +* Add `SAVEPOINT` to SQLCompleter (Thanks: [Huachao Mao]). +* Connect using a SSH transport (Thanks: [Dick Marinus]). +* Add `FROM_UNIXTIME` and `UNIX_TIMESTAMP` to SQLCompleter (Thanks: [QiaoHou Peng]) +* Search `${PWD}/.myclirc`, then `${HOME}/.myclirc`, lastly `/etc/myclirc` (Thanks: [QiaoHao Peng]) + +Bug Fixes: +---------- + +* When DSN is used, allow overrides from mycli arguments (Thanks: [Dick Marinus]). +* A DSN without password should be allowed (Thanks: [Dick Marinus]) + +Bug Fixes: +---------- + +* Convert `sql_format` to unicode strings for py27 compatibility (Thanks: [Dick Marinus]). +* Fixes mycli compatibility with pbr (Thanks: [Thomas Roten]). +* Don't align decimals for `sql_format` (Thanks: [Dick Marinus]). + +Internal: +--------- + +* Use fileinput (Thanks: [Dick Marinus]). +* Enable tests for Python 3.7 (Thanks: [Thomas Roten]). +* Remove `*.swp` from gitignore (Thanks: [Dick Marinus]). + +1.17.0: +======= + +Features: +---------- + +* Add `CONCAT` to SQLCompleter and remove unused code (Thanks: [caitinggui]) +* Do not quit when aborting a confirmation prompt (Thanks: [Thomas Roten]). +* Add option list-dsn (Thanks: [Frederic Aoustin]). +* Add verbose option for list-dsn, add tests and clean up code (Thanks: [Dick Marinus]). + +Bug Fixes: +---------- + +* Add enable_pager to the config file (Thanks: [Frederic Aoustin]). +* Mark `test_sql_output` as a dbtest (Thanks: [Dick Marinus]). +* Don't crash if the log/history file directories don't exist (Thanks: [Thomas Roten]). +* Unquote dsn username and password (Thanks: [Dick Marinus]). +* Output `Password:` prompt to stderr (Thanks: [ushuz]). +* Mark `alter` as a destructive query (Thanks: [Dick Marinus]). +* Quote CSV fields (Thanks: [Thomas Roten]). +* Fix `thanks_picker` (Thanks: [Dick Marinus]). + +Internal: +--------- + +* Refactor Destructive Warning behave tests (Thanks: [Dick Marinus]). + + +1.16.0: +======= + +Features: +--------- + +* Add DSN aliases to the config file (Thanks: [Frederic Aoustin]). + +Bug Fixes: +---------- + +* Do not try to connect to a unix socket on Windows (Thanks: [Thomas Roten]). + +1.15.0: +======= + +Features: +--------- + +* Add sql-update/insert output format. (Thanks: [Dick Marinus]). +* Also complete aliases in WHERE. (Thanks: [Dick Marinus]). + +1.14.0: +======= + +Features: +--------- + +* Add `watch [seconds] query` command to repeat a query every [seconds] seconds (by default 5). (Thanks: [David Caro](https://github.com/Terseus)) +* Default to unix socket connection if host and port are unspecified. This simplifies authentication on some systems and matches mysql behaviour. +* Add support for positional parameters to favorite queries. (Thanks: [Scrappy Soft](https://github.com/scrappysoft)) + +Bug Fixes: +---------- + +* Fix source command for script in current working directory. (Thanks: [Dick Marinus]). +* Fix issue where the `tee` command did not work on Python 2.7 (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Drop support for Python 3.3 (Thanks: [Thomas Roten]). + +* Make tests more compatible between different build environments. (Thanks: [David Caro]) +* Merge `_on_completions_refreshed` and `_swap_completer_objects` functions (Thanks: [Dick Marinus]). + +1.13.1: +======= + +Bug Fixes: +---------- + +* Fix keyword completion suggestion for `SHOW` (Thanks: [Thomas Roten]). +* Prevent mycli from crashing when failing to read login path file (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Make tests ignore user config files (Thanks: [Thomas Roten]). + +1.13.0: +======= + +Features: +--------- + +* Add file name completion for source command (issue #500). (Thanks: [Irina Truong]). + +Bug Fixes: +---------- + +* Fix UnicodeEncodeError when editing sql command in external editor (Thanks: Klaus Wünschel). +* Fix MySQL4 version comment retrieval (Thanks: [François Pietka]) +* Fix error that occurred when outputting JSON and NULL data (Thanks: [Thomas Roten]). + +1.12.1: +======= + +Bug Fixes: +---------- + +* Prevent missing MySQL help database from causing errors in completions (Thanks: [Thomas Roten]). +* Fix mycli from crashing with small terminal windows under Python 2 (Thanks: [Thomas Roten]). +* Prevent an error from displaying when you drop the current database (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Use less memory when formatting results for display (Thanks: [Dick Marinus]). +* Preliminary work for a future change in outputting results that uses less memory (Thanks: [Dick Marinus]). + +1.12.0: +======= + +Features: +--------- + +* Add fish-style auto-suggestion from history. (Thanks: [Amjith Ramanujam]) + + +1.11.0: +======= + +Features: +--------- + +* Handle reserved space for completion menu better in small windows. (Thanks: [Thomas Roten]). +* Display current vi mode in toolbar. (Thanks: [Thomas Roten]). +* Opening an external editor will edit the last-run query. (Thanks: [Thomas Roten]). +* Output once special command. (Thanks: [Dick Marinus]). +* Add special command to show create table statement. (Thanks: [Ryan Smith]) +* Display all result sets returned by stored procedures (Thanks: [Thomas Roten]). +* Add current time to prompt options (Thanks: [Thomas Roten]). +* Output status text in a more intuitive way (Thanks: [Thomas Roten]). +* Add colored/styled headers and odd/even rows (Thanks: [Thomas Roten]). +* Keyword completion casing (upper/lower/auto) (Thanks: [Irina Truong]). + +Bug Fixes: +---------- + +* Fixed incorrect timekeeping when running queries from a file. (Thanks: [Thomas Roten]). +* Do not display time and empty line for blank queries (Thanks: [Thomas Roten]). +* Fixed issue where quit command would sometimes not work (Thanks: [Thomas Roten]). +* Remove shebang from main.py (Thanks: [Dick Marinus]). +* Only use pager if output doesn't fit. (Thanks: [Dick Marinus]). +* Support tilde user directory for output file names (Thanks: [Thomas Roten]). +* Auto vertical output is a little bit better at its calculations (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Rename tests/ to test/. (Thanks: [Dick Marinus]). +* Move AUTHORS and SPONSORS to mycli directory. (Thanks: [Terje Røsten] []). +* Switch from pycryptodome to cryptography (Thanks: [Thomas Roten]). +* Add pager wrapper for behave tests (Thanks: [Dick Marinus]). +* Behave test source command (Thanks: [Dick Marinus]). +* Test using behave the tee command (Thanks: [Dick Marinus]). +* Behave fix clean up. (Thanks: [Dick Marinus]). +* Remove output formatter code in favor of CLI Helpers dependency (Thanks: [Thomas Roten]). +* Better handle common before/after scenarios in behave. (Thanks: [Dick Marinus]) +* Added a regression test for sqlparse >= 0.2.3 (Thanks: [Dick Marinus]). +* Reverted removal of temporary hack for sqlparse (Thanks: [Dick Marinus]). +* Add setup.py commands to simplify development tasks (Thanks: [Thomas Roten]). +* Add behave tests to tox (Thanks: [Dick Marinus]). +* Add missing @dbtest to tests (Thanks: [Dick Marinus]). +* Standardizes punctuation/grammar for help strings (Thanks: [Thomas Roten]). + +1.10.0: +======= + +Features: +--------- + +* Add ability to specify alternative myclirc file. (Thanks: [Dick Marinus]). +* Add new display formats for pretty printing query results. (Thanks: [Amjith + Ramanujam], [Dick Marinus], [Thomas Roten]). +* Add logic to shorten the default prompt if it becomes too long once generated. (Thanks: [John Sterling]). + +Bug Fixes: +---------- + +* Fix external editor bug (issue #377). (Thanks: [Irina Truong]). +* Fixed bug so that favorite queries can include unicode characters. (Thanks: + [Thomas Roten]). +* Fix requirements and remove old compatibility code (Thanks: [Dick Marinus]) +* Fix bug where mycli would not start due to the thanks/credit intro text. + (Thanks: [Thomas Roten]). +* Use pymysql default conversions (issue #375). (Thanks: [Dick Marinus]). + +Internal Changes: +----------------- + +* Upload mycli distributions in a safer manner (using twine). (Thanks: [Thomas + Roten]). +* Test mycli using pexpect/python-behave (Thanks: [Dick Marinus]). +* Run pep8 checks in travis (Thanks: [Irina Truong]). +* Remove temporary hack for sqlparse (Thanks: [Dick Marinus]). + +1.9.0: +====== + +Features: +--------- + +* Add tee/notee commands for outputing results to a file. (Thanks: [Dick Marinus]). +* Add date, port, and whitespace options to prompt configuration. (Thanks: [Matheus Rosa]). +* Allow user to specify LESS pager flags. (Thanks: [John Sterling]). +* Add support for auto-reconnect. (Thanks: [Jialong Liu]). +* Add CSV batch output. (Thanks: [Matheus Rosa]). +* Add `auto_vertical_output` config to myclirc. (Thanks: [Matheus Rosa]). +* Improve Fedora install instructions. (Thanks: [Dick Marinus]). + +Bug Fixes: +---------- + +* Fix crashes occuring from commands starting with #. (Thanks: [Zhidong]). +* Fix broken PyMySQL link in README. (Thanks: [Daniël van Eeden]). +* Add various missing keywords for highlighting and autocompletion. (Thanks: [zer09]). +* Add the missing REGEXP keyword for highlighting and autocompletion. (Thanks: [cxbig]). +* Fix duplicate username entries in completion list. (Thanks: [John Sterling]). +* Remove extra spaces in TSV table format output. (Thanks: [Dick Marinus]). +* Kill running query when interrupted via Ctrl-C. (Thanks: [chainkite]). +* Read the `smart_completion` config from myclirc. (Thanks: [Thomas Roten]). + +Internal Changes: +----------------- + +* Improve handling of test database credentials. (Thanks: [Dick Marinus]). +* Add Python 3.6 to test environments and PyPI metadata. (Thanks: [Thomas Roten]). +* Drop Python 2.6 support. (Thanks: [Thomas Roten]). +* Swap pycrypto dependency for pycryptodome. (Thanks: [Michał Górny]). +* Bump sqlparse version so pgcli and mycli can be installed together. (Thanks: [darikg]). + +1.8.1: +====== + +Bug Fixes: +---------- +* Remove duplicate listing of DISTINCT keyword. (Thanks: [Amjith Ramanujam]). +* Add an try/except for AS keyword crash. (Thanks: [Amjith Ramanujam]). +* Support python-sqlparse 0.2. (Thanks: [Dick Marinus]). +* Fallback to the raw object for invalid time values. (Thanks: [Amjith Ramanujam]). +* Reset the show items when completion is refreshed. (Thanks: [Amjith Ramanujam]). + +Internal Changes: +----------------- +* Make the dependency of sqlparse slightly more liberal. (Thanks: [Amjith Ramanujam]). + +1.8.0: +====== + +Features: +--------- + +* Add support for --execute/-e commandline arg. (Thanks: [Matheus Rosa]). +* Add `less_chatty` config option to skip the intro messages. (Thanks: [Scrappy Soft]). +* Support `MYCLI_HISTFILE` environment variable to specify where to write the history file. (Thanks: [Scrappy Soft]). +* Add `prompt_continuation` config option to allow configuring the continuation prompt for multi-line queries. (Thanks: [Scrappy Soft]). +* Display login-path instead of host in prompt. (Thanks: [Irina Truong]). + +Bug Fixes: +---------- + +* Pin sqlparse to version 0.1.19 since the new version is breaking completion. (Thanks: [Amjith Ramanujam]). +* Remove unsupported keywords. (Thanks: [Matheus Rosa]). +* Fix completion suggestion inside functions with operands. (Thanks: [Irina Truong]). + +1.7.0: +====== + +Features: +--------- + +* Add stdin batch mode. (Thanks: [Thomas Roten]). +* Add warn/no-warn command-line options. (Thanks: [Thomas Roten]). +* Upgrade sqlparse dependency to 0.1.19. (Thanks: [Amjith Ramanujam]). +* Update features list in README.md. (Thanks: [Matheus Rosa]). +* Remove extra \n in features list in README.md. (Thanks: [Matheus Rosa]). + +Bug Fixes: +---------- + +* Enable history search via <C-r>. (Thanks: [Amjith Ramanujam]). + +Internal Changes: +----------------- + +* Upgrade `prompt_toolkit` to 1.0.0. (Thanks: [Jonathan Slenders]) + +1.6.0: +====== + +Features: +--------- + +* Change continuation prompt for multi-line mode to match default mysql. +* Add `status` command to match mysql's `status` command. (Thanks: [Thomas Roten]). +* Add SSL support for `mycli`. (Thanks: [Artem Bezsmertnyi]). +* Add auto-completion and highlight support for OFFSET keyword. (Thanks: [Matheus Rosa]). +* Add support for `MYSQL_TEST_LOGIN_FILE` env variable to specify alternate login file. (Thanks: [Thomas Roten]). +* Add support for `--auto-vertical-output` to automatically switch to vertical output if the output doesn't fit in the table format. +* Add support for system-wide config. Now /etc/myclirc will be honored. (Thanks: [Thomas Roten]). +* Add support for `nopager` and `\n` to turn off the pager. (Thanks: [Thomas Roten]). +* Add support for `--local-infile` command-line option. (Thanks: [Thomas Roten]). + +Bug Fixes: +---------- + +* Remove -S from `less` option which was clobbering the scroll back in history. (Thanks: [Thomas Roten]). +* Make system command work with Python 3. (Thanks: [Thomas Roten]). +* Support \G terminator for \f queries. (Thanks: [Terseus]). + +Internal Changes: +----------------- + +* Upgrade `prompt_toolkit` to 0.60. +* Add Python 3.5 to test environments. (Thanks: [Thomas Roten]). +* Remove license meta-data. (Thanks: [Thomas Roten]). +* Skip binary tests if PyMySQL version does not support it. (Thanks: [Thomas Roten]). +* Refactor pager handling. (Thanks: [Thomas Roten]) +* Capture warnings to log file. (Thanks: [Mikhail Borisov]). +* Make `syntax_style` a tiny bit more intuitive. (Thanks: [Phil Cohen]). + +1.5.2: +====== + +Bug Fixes: +---------- + +* Protect against port number being None when no port is specified in command line. + +1.5.1: +====== + +Bug Fixes: +---------- + +* Cast the value of port read from my.cnf to int. + +1.5.0: +====== + +Features: +--------- + +* Make a config option to enable `audit_log`. (Thanks: [Matheus Rosa]). +* Add support for reading .mylogin.cnf to get user credentials. (Thanks: [Thomas Roten]). + This feature is only available when `pycrypto` package is installed. +* Register the special command `prompt` with the `\R` as alias. (Thanks: [Matheus Rosa]). + Users can now change the mysql prompt at runtime using `prompt` command. + eg: + ``` + mycli> prompt \u@\h> + Changed prompt format to \u@\h> + Time: 0.001s + amjith@localhost> + ``` +* Perform completion refresh in a background thread. Now mycli can handle + databases with thousands of tables without blocking. +* Add support for `system` command. (Thanks: [Matheus Rosa]). + Users can now run a system command from within mycli as follows: + ``` + amjith@localhost:(none)>system cat tmp.sql + select 1; + select * from django_migrations; + ``` +* Caught and hexed binary fields in MySQL. (Thanks: [Daniel West]). + Geometric fields stored in a database will be displayed as hexed strings. +* Treat enter key as tab when the suggestion menu is open. (Thanks: [Matheus Rosa]) +* Add "delete" and "truncate" as destructive commands. (Thanks: [Martijn Engler]). +* Change \dt syntax to add an optional table name. (Thanks: [Shoma Suzuki]). + `\dt [tablename]` will describe the columns in a table. +* Add TRANSACTION related keywords. +* Treat DESC and EXPLAIN as DESCRIBE. (Thanks: [spacewander]). + +Bug Fixes: +---------- + +* Fix the removal of whitespace from table output. +* Add ability to make suggestions for compound join clauses. (Thanks: [Matheus Rosa]). +* Fix the incorrect reporting of command time. +* Add type validation for port argument. (Thanks [Matheus Rosa]) + +Internal Changes: +----------------- +* Make pycrypto optional and only install it in \*nix systems. (Thanks: [Irina Truong]). +* Add badge for PyPI version to README. (Thanks: [Shoma Suzuki]). +* Updated release script with a --dry-run and --confirm-steps option. (Thanks: [Irina Truong]). +* Adds support for PyMySQL 0.6.2 and above. This is useful for debian package builders. (Thanks: [Thomas Roten]). +* Disable click warning. + +1.4.0: +====== + +Features: +--------- + +* Add `source` command. This allows running sql statement from a file. + + eg: + ``` + mycli> source filename.sql + ``` + +* Added a config option to make the warning before destructive commands optional. (Thanks: [Daniel West](https://github.com/danieljwest)) + + In the config file ~/.myclirc set `destructive_warning = False` which will + disable the warning before running `DROP` commands. + +* Add completion support for CHANGE TO and other master/slave commands. This is + still preliminary and it will be enhanced in the future. + +* Add custom styles to color the menus and toolbars. + +* Upgrade `prompt_toolkit` to 0.46. (Thanks: [Jonathan Slenders]) + + Multi-line queries are automatically indented. + +Bug Fixes: +---------- + +* Fix keyword completion after the `WHERE` clause. +* Add `\g` and `\G` as valid query terminators. Previously in multi-line mode + ending a query with a `\G` wouldn't run the query. This is now fixed. + +1.3.0: +====== + +Features: +--------- +* Add a new special command (\T) to change the table format on the fly. (Thanks: [Jonathan Bruno](https://github.com/brewneaux)) + eg: + ``` + mycli> \T tsv + ``` +* Add `--defaults-group-suffix` to the command line. This lets the user specify + a group to use in the my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet)) + + In the my.cnf file a user can specify credentials for different databases and + invoke mycli with the group name to use the appropriate credentials. + eg: + ``` + # my.cnf + [client] + user = 'root' + socket = '/tmp/mysql.sock' + pager = 'less -RXSF' + database = 'account' + + [clientamjith] + user = 'amjith' + database = 'user_management' + + $ mycli --defaults-group-suffix=amjith # uses the [clientamjith] section in my.cnf + ``` + +* Add `--defaults-file` option to the command line. This allows specifying a + `my.cnf` to use at launch. This also makes it play nice with mysql sandbox. + +* Make `-p` and `--password` take the password in commandline. This makes mycli + a drop in replacement for mysql. + +1.2.0: +====== + +Features: +--------- + +* Add support for wider completion menus in the config file. + + Add `wider_completion_menu = True` in the config file (~/.myclirc) to enable this feature. + +Bug Fixes: +--------- + +* Prevent Ctrl-C from quitting mycli while the pager is active. +* Refresh auto-completions after the database is changed via a CONNECT command. + +Internal Changes: +----------------- + +* Upgrade `prompt_toolkit` dependency version to 0.45. +* Added Travis CI to run the tests automatically. + +1.1.1: +====== + +Bug Fixes: +---------- + +* Change dictonary comprehension used in mycnf reader to list comprehension to make it compatible with Python 2.6. + + +1.1.0: +====== + +Features: +--------- + +* Fuzzy completion is now case-insensitive. (Thanks: [bjarnagin](https://github.com/bjarnagin)) +* Added new-line (`\n`) to the list of special characters to use in prompt. (Thanks: [brewneaux](https://github.com/brewneaux)) +* Honor the `pager` setting in my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet)) + +Bug Fixes: +---------- + +* Fix a crashing bug in completion engine for cross joins. +* Make `<null>` value consistent between tabular and vertical output. + +Internal Changes: +----------------- + +* Changed pymysql version to be greater than 0.6.6. +* Upgrade `prompt_toolkit` version to 0.42. (Thanks: [Yasuhiro Matsumoto](https://github.com/mattn)) +* Removed the explicit dependency on six. + +2015/06/10: +=========== + +Features: +--------- + +* Customizable prompt. (Thanks [Steve Robbins](https://github.com/steverobbins)) +* Make `\G` formatting to behave more like mysql. + +Bug Fixes: +---------- + +* Formatting issue in \G for really long column values. + + +2015/06/07: +=========== + +Features: +--------- + +* Upgrade `prompt_toolkit` to 0.38. This improves the performance of pasting long queries. +* Add support for reading my.cnf files. +* Add editor command \e. +* Replace ConfigParser with ConfigObj. +* Add \dt to show all tables. +* Add fuzzy completion for table names and column names. +* Automatically reconnect when connection is lost to the database. + +Bug Fixes: +---------- + +* Fix a bug with reconnect failure. +* Fix the issue with `use` command not changing the prompt. +* Fix the issue where `\\r` shortcut was not recognized. + + +2015/05/24 +========== + +Features: +--------- + +* Add support for connecting via socket. +* Add completion for SQL functions. +* Add completion support for SHOW statements. +* Made the timing of sql statements human friendly. +* Automatically prompt for a password if needed. + +Bug Fixes: +---------- +* Fixed the installation issues with PyMySQL dependency on case-sensitive file systems. + +[Daniel West]: http://github.com/danieljwest +[Irina Truong]: https://github.com/j-bennet +[Amjith Ramanujam]: https://blog.amjith.com +[Kacper Kwapisz]: https://github.com/KKKas +[Martijn Engler]: https://github.com/martijnengler +[Matheus Rosa]: https://github.com/mdsrosa +[Shoma Suzuki]: https://github.com/shoma +[spacewander]: https://github.com/spacewander +[Thomas Roten]: https://github.com/tsroten +[Artem Bezsmertnyi]: https://github.com/mrdeathless +[Mikhail Borisov]: https://github.com/borman +[Casper Langemeijer]: Casper Langemeijer +[Lennart Weller]: https://github.com/lhw +[Phil Cohen]: https://github.com/phlipper +[Terseus]: https://github.com/Terseus +[William GARCIA]: https://github.com/willgarcia +[Jonathan Slenders]: https://github.com/jonathanslenders +[Casper Langemeijer]: https://github.com/langemeijer +[Scrappy Soft]: https://github.com/scrappysoft +[Dick Marinus]: https://github.com/meeuw +[François Pietka]: https://github.com/fpietka +[Frederic Aoustin]: https://github.com/fraoustin +[Georgy Frolov]: https://github.com/pasenor diff --git a/mycli/AUTHORS b/mycli/AUTHORS new file mode 100644 index 0000000..b3636d9 --- /dev/null +++ b/mycli/AUTHORS @@ -0,0 +1,79 @@ +Project Lead: +------------- + * Thomas Roten + + +Core Developers: +---------------- + + * Irina Truong + * Matheus Rosa + * Darik Gamble + * Dick Marinus + * Amjith Ramanujam + +Contributors: +------------- + + * Steve Robbins + * Shoma Suzuki + * Daniel West + * Scrappy Soft + * Daniel Black + * Jonathan Bruno + * Casper Langemeijer + * Jonathan Slenders + * Artem Bezsmertnyi + * Mikhail Borisov + * Heath Naylor + * Phil Cohen + * spacewander + * Adam Chainz + * Johannes Hoff + * Kacper Kwapisz + * Lennart Weller + * Martijn Engler + * Terseus + * Tyler Kuipers + * William GARCIA + * Yasuhiro Matsumoto + * bjarnagin + * jbruno + * mrdeathless + * Abirami P + * John Sterling + * Jialong Liu + * Zhidong + * Daniël van Eeden + * zer09 + * cxbig + * chainkite + * Michał Górny + * Terje Røsten + * Ryan Smith + * Klaus Wünschel + * François Pietka + * Colin Caine + * Frederic Aoustin + * caitinggui + * ushuz + * Zhaolong Zhu + * Zhongyang Guan + * Huachao Mao + * QiaoHou Peng + * Yang Zou + * Angelo Lupo + * Aljosha Papsch + * Zane C. Bowers-Hadley + * Mike Palandra + * Georgy Frolov + * Jonathan Lloyd + * Nathan Huang + * Jakub Boukal + * Takeshi D. Itoh + * laixintao + +Creator: +-------- + +Amjith Ramanujam diff --git a/mycli/SPONSORS b/mycli/SPONSORS new file mode 100644 index 0000000..81b0904 --- /dev/null +++ b/mycli/SPONSORS @@ -0,0 +1,31 @@ +Many thanks to the following Kickstarter backers. + +* Tech Blue Software +* jweiland.net + +# Silver Sponsors + +* Whitane Tech +* Open Query Pty Ltd +* Prathap Ramamurthy +* Lincoln Loop + +# Sponsors + +* Nathan Taggart +* Iryna Cherniavska +* Sudaraka Wijesinghe +* www.mysqlfanboy.com +* Steve Robbins +* Norbert Spichtig +* orpharion bestheneme +* Daniel Black +* Anonymous +* Magnus udd +* Anonymous +* Lewis Peckover +* Cyrille Tabary +* Heath Naylor +* Ted Pennings +* Chris Anderton +* Jonathan Slenders diff --git a/mycli/__init__.py b/mycli/__init__.py new file mode 100644 index 0000000..53bfe2e --- /dev/null +++ b/mycli/__init__.py @@ -0,0 +1 @@ +__version__ = '1.22.2' diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py new file mode 100644 index 0000000..c9d29d1 --- /dev/null +++ b/mycli/clibuffer.py @@ -0,0 +1,54 @@ +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition +from prompt_toolkit.application import get_app +from .packages.parseutils import is_open_quote +from .packages import special + + +def cli_is_multiline(mycli): + @Condition + def cond(): + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document + + if not mycli.multi_line: + return False + else: + return not _multiline_exception(doc.text) + return cond + + +def _multiline_exception(text): + orig = text + text = text.strip() + + # Multi-statement favorite query is a special case. Because there will + # be a semicolon separating statements, we can't consider semicolon an + # EOL. Let's consider an empty line an EOL instead. + if text.startswith('\\fs'): + return orig.endswith('\n') + + return ( + # Special Command + text.startswith('\\') or + + # Delimiter declaration + text.lower().startswith('delimiter') or + + # Ended with the current delimiter (usually a semi-column) + text.endswith(special.get_current_delimiter()) or + + text.endswith('\\g') or + text.endswith('\\G') or + + # Exit doesn't need semi-column` + (text == 'exit') or + + # Quit doesn't need semi-column + (text == 'quit') or + + # To all teh vim fans out there + (text == ':q') or + + # just a plain enter without any text + (text == '') + ) diff --git a/mycli/clistyle.py b/mycli/clistyle.py new file mode 100644 index 0000000..c94f793 --- /dev/null +++ b/mycli/clistyle.py @@ -0,0 +1,118 @@ +import logging + +import pygments.styles +from pygments.token import string_to_tokentype, Token +from pygments.style import Style as PygmentsStyle +from pygments.util import ClassNotFound +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles import merge_styles, Style + +logger = logging.getLogger(__name__) + +# map Pygments tokens (ptk 1.0) to class names (ptk 2.0). +TOKEN_TO_PROMPT_STYLE = { + Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current', + Token.Menu.Completions.Completion: 'completion-menu.completion', + Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current', + Token.Menu.Completions.Meta: 'completion-menu.meta.completion', + Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta', + Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess + Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess + Token.SelectedText: 'selected', + Token.SearchMatch: 'search', + Token.SearchMatch.Current: 'search.current', + Token.Toolbar: 'bottom-toolbar', + Token.Toolbar.Off: 'bottom-toolbar.off', + Token.Toolbar.On: 'bottom-toolbar.on', + Token.Toolbar.Search: 'search-toolbar', + Token.Toolbar.Search.Text: 'search-toolbar.text', + Token.Toolbar.System: 'system-toolbar', + Token.Toolbar.Arg: 'arg-toolbar', + Token.Toolbar.Arg.Text: 'arg-toolbar.text', + Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid', + Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed', + Token.Output.Header: 'output.header', + Token.Output.OddRow: 'output.odd-row', + Token.Output.EvenRow: 'output.even-row', + Token.Prompt: 'prompt', + Token.Continuation: 'continuation', +} + +# reverse dict for cli_helpers, because they still expect Pygments tokens. +PROMPT_STYLE_TO_TOKEN = { + v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() +} + + +def parse_pygments_style(token_name, style_object, style_dict): + """Parse token type and style string. + + :param token_name: str name of Pygments token. Example: "Token.String" + :param style_object: pygments.style.Style instance to use as base + :param style_dict: dict of token names and their styles, customized to this cli + + """ + token_type = string_to_tokentype(token_name) + try: + other_token_type = string_to_tokentype(style_dict[token_name]) + return token_type, style_object.styles[other_token_type] + except AttributeError as err: + return token_type, style_dict[token_name] + + +def style_factory(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name) + except ClassNotFound: + style = pygments.styles.get_style_by_name('native') + + prompt_styles = [] + # prompt-toolkit used pygments tokens for styling before, switched to style + # names in 2.0. Convert old token types to new style names, for backwards compatibility. + for token in cli_style: + if token.startswith('Token.'): + # treat as pygments token (1.0) + token_type, style_value = parse_pygments_style( + token, style, cli_style) + if token_type in TOKEN_TO_PROMPT_STYLE: + prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] + prompt_styles.append((prompt_style, style_value)) + else: + # we don't want to support tokens anymore + logger.error('Unhandled style / class name: %s', token) + else: + # treat as prompt style name (2.0). See default style names here: + # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py + prompt_styles.append((token, cli_style[token])) + + override_style = Style([('bottom-toolbar', 'noreverse')]) + return merge_styles([ + style_from_pygments_cls(style), + override_style, + Style(prompt_styles) + ]) + + +def style_factory_output(name, cli_style): + try: + style = pygments.styles.get_style_by_name(name).styles + except ClassNotFound: + style = pygments.styles.get_style_by_name('native').styles + + for token in cli_style: + if token.startswith('Token.'): + token_type, style_value = parse_pygments_style( + token, style, cli_style) + style.update({token_type: style_value}) + elif token in PROMPT_STYLE_TO_TOKEN: + token_type = PROMPT_STYLE_TO_TOKEN[token] + style.update({token_type: cli_style[token]}) + else: + # TODO: cli helpers will have to switch to ptk.Style + logger.error('Unhandled style / class name: %s', token) + + class OutputStyle(PygmentsStyle): + default_style = "" + styles = style + + return OutputStyle diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py new file mode 100644 index 0000000..e03e182 --- /dev/null +++ b/mycli/clitoolbar.py @@ -0,0 +1,52 @@ +from prompt_toolkit.key_binding.vi_state import InputMode +from prompt_toolkit.application import get_app +from prompt_toolkit.enums import EditingMode +from .packages import special + + +def create_toolbar_tokens_func(mycli, show_fish_help): + """Return a function that generates the toolbar tokens.""" + def get_toolbar_tokens(): + result = [] + result.append(('class:bottom-toolbar', ' ')) + + if mycli.multi_line: + delimiter = special.get_current_delimiter() + result.append( + ( + 'class:bottom-toolbar', + ' ({} [{}] will end the line) '.format( + 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter) + )) + + if mycli.multi_line: + result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) + else: + result.append(('class:bottom-toolbar.off', + '[F3] Multiline: OFF ')) + if mycli.prompt_app.editing_mode == EditingMode.VI: + result.append(( + 'class:botton-toolbar.on', + 'Vi-mode ({})'.format(_get_vi_mode()) + )) + + if show_fish_help(): + result.append( + ('class:bottom-toolbar', ' Right-arrow to complete suggestion')) + + if mycli.completion_refresher.is_refreshing(): + result.append( + ('class:bottom-toolbar', ' Refreshing completions...')) + + return result + return get_toolbar_tokens + + +def _get_vi_mode(): + """Get the current vi mode for display.""" + return { + InputMode.INSERT: 'I', + InputMode.NAVIGATION: 'N', + InputMode.REPLACE: 'R', + InputMode.INSERT_MULTIPLE: 'M', + }[get_app().vi_state.input_mode] diff --git a/mycli/compat.py b/mycli/compat.py new file mode 100644 index 0000000..2ebfe07 --- /dev/null +++ b/mycli/compat.py @@ -0,0 +1,6 @@ +"""Platform and Python version compatibility support.""" + +import sys + + +WIN = sys.platform in ('win32', 'cygwin') diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py new file mode 100644 index 0000000..e6c8dd0 --- /dev/null +++ b/mycli/completion_refresher.py @@ -0,0 +1,123 @@ +import threading +from .packages.special.main import COMMANDS +from collections import OrderedDict + +from .sqlcompleter import SQLCompleter +from .sqlexecute import SQLExecute + +class CompletionRefresher(object): + + refreshers = OrderedDict() + + def __init__(self): + self._completer_thread = None + self._restart_refresh = threading.Event() + + def refresh(self, executor, callbacks, completer_options=None): + """Creates a SQLCompleter object and populates it with the relevant + completion suggestions in a background thread. + + executor - SQLExecute object, used to extract the credentials to connect + to the database. + callbacks - A function or a list of functions to call after the thread + has completed the refresh. The newly created completion + object will be passed in as an argument to each callback. + completer_options - dict of options to pass to SQLCompleter. + + """ + if completer_options is None: + completer_options = {} + + if self.is_refreshing(): + self._restart_refresh.set() + return [(None, None, None, 'Auto-completion refresh restarted.')] + else: + self._completer_thread = threading.Thread( + target=self._bg_refresh, + args=(executor, callbacks, completer_options), + name='completion_refresh') + self._completer_thread.setDaemon(True) + self._completer_thread.start() + return [(None, None, None, + 'Auto-completion refresh started in the background.')] + + def is_refreshing(self): + return self._completer_thread and self._completer_thread.is_alive() + + def _bg_refresh(self, sqlexecute, callbacks, completer_options): + completer = SQLCompleter(**completer_options) + + # Create a new pgexecute method to popoulate the completions. + e = sqlexecute + executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, + e.socket, e.charset, e.local_infile, e.ssl, + e.ssh_user, e.ssh_host, e.ssh_port, + e.ssh_password, e.ssh_key_filename) + + # If callbacks is a single function then push it into a list. + if callable(callbacks): + callbacks = [callbacks] + + while 1: + for refresher in self.refreshers.values(): + refresher(completer, executor) + if self._restart_refresh.is_set(): + self._restart_refresh.clear() + break + else: + # Break out of while loop if the for loop finishes natually + # without hitting the break statement. + break + + # Start over the refresh from the beginning if the for loop hit the + # break statement. + continue + + for callback in callbacks: + callback(completer) + +def refresher(name, refreshers=CompletionRefresher.refreshers): + """Decorator to add the decorated function to the dictionary of + refreshers. Any function decorated with a @refresher will be executed as + part of the completion refresh routine.""" + def wrapper(wrapped): + refreshers[name] = wrapped + return wrapped + return wrapper + +@refresher('databases') +def refresh_databases(completer, executor): + completer.extend_database_names(executor.databases()) + +@refresher('schemata') +def refresh_schemata(completer, executor): + # schemata - In MySQL Schema is the same as database. But for mycli + # schemata will be the name of the current database. + completer.extend_schemata(executor.dbname) + completer.set_dbname(executor.dbname) + +@refresher('tables') +def refresh_tables(completer, executor): + completer.extend_relations(executor.tables(), kind='tables') + completer.extend_columns(executor.table_columns(), kind='tables') + +@refresher('users') +def refresh_users(completer, executor): + completer.extend_users(executor.users()) + +# @refresher('views') +# def refresh_views(completer, executor): +# completer.extend_relations(executor.views(), kind='views') +# completer.extend_columns(executor.view_columns(), kind='views') + +@refresher('functions') +def refresh_functions(completer, executor): + completer.extend_functions(executor.functions()) + +@refresher('special_commands') +def refresh_special(completer, executor): + completer.extend_special_commands(COMMANDS.keys()) + +@refresher('show_commands') +def refresh_show_commands(completer, executor): + completer.extend_show_items(executor.show_candidates()) diff --git a/mycli/config.py b/mycli/config.py new file mode 100644 index 0000000..e0f2d1f --- /dev/null +++ b/mycli/config.py @@ -0,0 +1,286 @@ +import io +import shutil +from copy import copy +from io import BytesIO, TextIOWrapper +import logging +import os +from os.path import exists +import struct +import sys +from typing import Union + +from configobj import ConfigObj, ConfigObjError +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.backends import default_backend + +try: + basestring +except NameError: + basestring = str + + +logger = logging.getLogger(__name__) + + +def log(logger, level, message): + """Logs message to stderr if logging isn't initialized.""" + + if logger.parent.name != 'root': + logger.log(level, message) + else: + print(message, file=sys.stderr) + + +def read_config_file(f, list_values=True): + """Read a config file. + + *list_values* set to `True` is the default behavior of ConfigObj. + Disabling it causes values to not be parsed for lists, + (e.g. 'a,b,c' -> ['a', 'b', 'c']. Additionally, the config values are + not unquoted. We are disabling list_values when reading MySQL config files + so we can correctly interpret commas in passwords. + + """ + + if isinstance(f, basestring): + f = os.path.expanduser(f) + + try: + config = ConfigObj(f, interpolation=False, encoding='utf8', + list_values=list_values) + except ConfigObjError as e: + log(logger, logging.ERROR, "Unable to parse line {0} of config file " + "'{1}'.".format(e.line_number, f)) + log(logger, logging.ERROR, "Using successfully parsed config values.") + return e.config + except (IOError, OSError) as e: + log(logger, logging.WARNING, "You don't have permission to read " + "config file '{0}'.".format(e.filename)) + return None + + return config + + +def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: + """Get a list of configuration files that are included into config_path + with !includedir directive. + + "Normal" configs should be passed as file paths. The only exception + is .mylogin which is decoded into a stream. However, it never + contains include directives and so will be ignored by this + function. + + """ + if not isinstance(config_file, str) or not os.path.isfile(config_file): + return [] + included_configs = [] + + try: + with open(config_file) as f: + include_directives = filter( + lambda s: s.startswith('!includedir'), + f + ) + dirs = map(lambda s: s.strip().split()[-1], include_directives) + dirs = filter(os.path.isdir, dirs) + for dir in dirs: + for filename in os.listdir(dir): + if filename.endswith('.cnf'): + included_configs.append(os.path.join(dir, filename)) + except (PermissionError, UnicodeDecodeError): + pass + return included_configs + + +def read_config_files(files, list_values=True): + """Read and merge a list of config files.""" + + config = ConfigObj(list_values=list_values) + _files = copy(files) + while _files: + _file = _files.pop(0) + _config = read_config_file(_file, list_values=list_values) + + # expand includes only if we were able to parse config + # (otherwise we'll just encounter the same errors again) + if config is not None: + _files = get_included_configs(_file) + _files + if bool(_config) is True: + config.merge(_config) + config.filename = _config.filename + + return config + + +def write_default_config(source, destination, overwrite=False): + destination = os.path.expanduser(destination) + if not overwrite and exists(destination): + return + + shutil.copyfile(source, destination) + + +def get_mylogin_cnf_path(): + """Return the path to the login path file or None if it doesn't exist.""" + mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE') + + if mylogin_cnf_path is None: + app_data = os.getenv('APPDATA') + default_dir = os.path.join(app_data, 'MySQL') if app_data else '~' + mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf') + + mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path) + + if exists(mylogin_cnf_path): + logger.debug("Found login path file at '{0}'".format(mylogin_cnf_path)) + return mylogin_cnf_path + return None + + +def open_mylogin_cnf(name): + """Open a readable version of .mylogin.cnf. + + Returns the file contents as a TextIOWrapper object. + + :param str name: The pathname of the file to be opened. + :return: the login path file or None + """ + + try: + with open(name, 'rb') as f: + plaintext = read_and_decrypt_mylogin_cnf(f) + except (OSError, IOError, ValueError): + logger.error('Unable to open login path file.') + return None + + if not isinstance(plaintext, BytesIO): + logger.error('Unable to read login path file.') + return None + + return TextIOWrapper(plaintext) + + +def read_and_decrypt_mylogin_cnf(f): + """Read and decrypt the contents of .mylogin.cnf. + + This decryption algorithm mimics the code in MySQL's + mysql_config_editor.cc. + + The login key is 20-bytes of random non-printable ASCII. + It is written to the actual login path file. It is used + to generate the real key used in the AES cipher. + + :param f: an I/O object opened in binary mode + :return: the decrypted login path file + :rtype: io.BytesIO or None + """ + + # Number of bytes used to store the length of ciphertext. + MAX_CIPHER_STORE_LEN = 4 + + LOGIN_KEY_LEN = 20 + + # Move past the unused buffer. + buf = f.read(4) + + if not buf or len(buf) != 4: + logger.error('Login path file is blank or incomplete.') + return None + + # Read the login key. + key = f.read(LOGIN_KEY_LEN) + + # Generate the real key. + rkey = [0] * 16 + for i in range(LOGIN_KEY_LEN): + try: + rkey[i % 16] ^= ord(key[i:i+1]) + except TypeError: + # ord() was unable to get the value of the byte. + logger.error('Unable to generate login path AES key.') + return None + rkey = struct.pack('16B', *rkey) + + # Create a decryptor object using the key. + decryptor = _get_decryptor(rkey) + + # Create a bytes buffer to hold the plaintext. + plaintext = BytesIO() + + while True: + # Read the length of the ciphertext. + len_buf = f.read(MAX_CIPHER_STORE_LEN) + if len(len_buf) < MAX_CIPHER_STORE_LEN: + break + cipher_len, = struct.unpack("<i", len_buf) + + # Read cipher_len bytes from the file and decrypt. + cipher = f.read(cipher_len) + plain = _remove_pad(decryptor.update(cipher)) + if plain is False: + continue + plaintext.write(plain) + + if plaintext.tell() == 0: + logger.error('No data successfully decrypted from login path file.') + return None + + plaintext.seek(0) + return plaintext + + +def str_to_bool(s): + """Convert a string value to its corresponding boolean value.""" + if isinstance(s, bool): + return s + elif not isinstance(s, basestring): + raise TypeError('argument must be a string') + + true_values = ('true', 'on', '1') + false_values = ('false', 'off', '0') + + if s.lower() in true_values: + return True + elif s.lower() in false_values: + return False + else: + raise ValueError('not a recognized boolean value: %s'.format(s)) + + +def strip_matching_quotes(s): + """Remove matching, surrounding quotes from a string. + + This is the same logic that ConfigObj uses when parsing config + values. + + """ + if (isinstance(s, basestring) and len(s) >= 2 and + s[0] == s[-1] and s[0] in ('"', "'")): + s = s[1:-1] + return s + + +def _get_decryptor(key): + """Get the AES decryptor.""" + c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) + return c.decryptor() + + +def _remove_pad(line): + """Remove the pad from the *line*.""" + pad_length = ord(line[-1:]) + try: + # Determine pad length. + pad_length = ord(line[-1:]) + except TypeError: + # ord() was unable to get the value of the byte. + logger.warning('Unable to remove pad.') + return False + + if pad_length > len(line) or len(set(line[-pad_length:])) != 1: + # Pad length should be less than or equal to the length of the + # plaintext. The pad should have a single unique byte. + logger.warning('Invalid pad found in login path file.') + return False + + return line[:-pad_length] diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py new file mode 100644 index 0000000..57b917b --- /dev/null +++ b/mycli/key_bindings.py @@ -0,0 +1,85 @@ +import logging +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.filters import completion_is_selected +from prompt_toolkit.key_binding import KeyBindings + +_logger = logging.getLogger(__name__) + + +def mycli_bindings(mycli): + """Custom key bindings for mycli.""" + kb = KeyBindings() + + @kb.add('f2') + def _(event): + """Enable/Disable SmartCompletion Mode.""" + _logger.debug('Detected F2 key.') + mycli.completer.smart_completion = not mycli.completer.smart_completion + + @kb.add('f3') + def _(event): + """Enable/Disable Multiline Mode.""" + _logger.debug('Detected F3 key.') + mycli.multi_line = not mycli.multi_line + + @kb.add('f4') + def _(event): + """Toggle between Vi and Emacs mode.""" + _logger.debug('Detected F4 key.') + if mycli.key_bindings == "vi": + event.app.editing_mode = EditingMode.EMACS + mycli.key_bindings = "emacs" + else: + event.app.editing_mode = EditingMode.VI + mycli.key_bindings = "vi" + + @kb.add('tab') + def _(event): + """Force autocompletion at cursor.""" + _logger.debug('Detected <Tab> key.') + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=True) + + @kb.add('c-space') + def _(event): + """ + Initialize autocompletion at cursor. + + If the autocompletion menu is not showing, display it with the + appropriate completions for the context. + + If the menu is showing, select the next completion. + """ + _logger.debug('Detected <C-Space> key.') + + b = event.app.current_buffer + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=False) + + @kb.add('enter', filter=completion_is_selected) + def _(event): + """Makes the enter key work as the tab key only when showing the menu. + + In other words, don't execute query when enter is pressed in + the completion dropdown menu, instead close the dropdown menu + (accept current selection). + + """ + _logger.debug('Detected enter key.') + + event.current_buffer.complete_state = None + b = event.app.current_buffer + b.complete_state = None + + @kb.add('escape', 'enter') + def _(event): + """Introduces a line break regardless of multi-line mode or not.""" + _logger.debug('Detected alt-enter key.') + event.app.current_buffer.insert_text('\n') + + return kb diff --git a/mycli/lexer.py b/mycli/lexer.py new file mode 100644 index 0000000..4b14d72 --- /dev/null +++ b/mycli/lexer.py @@ -0,0 +1,12 @@ +from pygments.lexer import inherit +from pygments.lexers.sql import MySqlLexer +from pygments.token import Keyword + + +class MyCliLexer(MySqlLexer): + """Extends MySQL lexer to add keywords.""" + + tokens = { + 'root': [(r'\brepair\b', Keyword), + (r'\boffset\b', Keyword), inherit], + } diff --git a/mycli/magic.py b/mycli/magic.py new file mode 100644 index 0000000..5527f72 --- /dev/null +++ b/mycli/magic.py @@ -0,0 +1,54 @@ +from .main import MyCli +import sql.parse +import sql.connection +import logging + +_logger = logging.getLogger(__name__) + +def load_ipython_extension(ipython): + + # This is called via the ipython command '%load_ext mycli.magic'. + + # First, load the sql magic if it isn't already loaded. + if not ipython.find_line_magic('sql'): + ipython.run_line_magic('load_ext', 'sql') + + # Register our own magic. + ipython.register_magic_function(mycli_line_magic, 'line', 'mycli') + +def mycli_line_magic(line): + _logger.debug('mycli magic called: %r', line) + parsed = sql.parse.parse(line, {}) + conn = sql.connection.Connection.get(parsed['connection']) + + try: + # A corresponding mycli object already exists + mycli = conn._mycli + _logger.debug('Reusing existing mycli') + except AttributeError: + mycli = MyCli() + u = conn.session.engine.url + _logger.debug('New mycli: %r', str(u)) + + mycli.connect(u.database, u.host, u.username, u.port, u.password) + conn._mycli = mycli + + # For convenience, print the connection alias + print('Connected: {}'.format(conn.name)) + + try: + mycli.run_cli() + except SystemExit: + pass + + if not mycli.query_history: + return + + q = mycli.query_history[-1] + if q.mutating: + _logger.debug('Mutating query detected -- ignoring') + return + + if q.successful: + ipython = get_ipython() + return ipython.run_cell_magic('sql', line, q.query) diff --git a/mycli/main.py b/mycli/main.py new file mode 100755 index 0000000..03797a0 --- /dev/null +++ b/mycli/main.py @@ -0,0 +1,1326 @@ +import os +import sys +import traceback +import logging +import threading +import re +import fileinput +from collections import namedtuple +try: + from pwd import getpwuid +except ImportError: + pass +from time import time +from datetime import datetime +from random import choice +from io import open + +from pymysql import OperationalError +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output import preprocessors +from cli_helpers.utils import strip_ansi +import click +import sqlparse +from mycli.packages.parseutils import is_dropping_database +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register +from prompt_toolkit.shortcuts import PromptSession, CompleteStyle +from prompt_toolkit.document import Document +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, + ConditionalProcessor) +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.history import FileHistory +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory + +from .packages.special.main import NO_QUERY +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.tabular_output import sql_format +from .packages import special +from .packages.special.favoritequeries import FavoriteQueries +from .sqlcompleter import SQLCompleter +from .clitoolbar import create_toolbar_tokens_func +from .clistyle import style_factory, style_factory_output +from .sqlexecute import FIELD_TYPES, SQLExecute +from .clibuffer import cli_is_multiline +from .completion_refresher import CompletionRefresher +from .config import (write_default_config, get_mylogin_cnf_path, + open_mylogin_cnf, read_config_files, str_to_bool, + strip_matching_quotes) +from .key_bindings import mycli_bindings +from .lexer import MyCliLexer +from .__init__ import __version__ +from .compat import WIN +from .packages.filepaths import dir_path_exists, guess_socket_location + +import itertools + +click.disable_unicode_literals_warning = True + +try: + from urlparse import urlparse + from urlparse import unquote +except ImportError: + from urllib.parse import urlparse + from urllib.parse import unquote + + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko + +# Query tuples are used for maintaining history +Query = namedtuple('Query', ['query', 'successful', 'mutating']) + +PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +class MyCli(object): + + default_prompt = '\\t \\u@\\h:\\d> ' + max_len_prompt = 45 + defaults_suffix = None + + # In order of being loaded. Files lower in list override earlier ones. + cnf_files = [ + '/etc/my.cnf', + '/etc/mysql/my.cnf', + '/usr/local/etc/my.cnf', + '~/.my.cnf' + ] + + # check XDG_CONFIG_HOME exists and not an empty string + if os.environ.get("XDG_CONFIG_HOME"): + xdg_config_home = os.environ.get("XDG_CONFIG_HOME") + else: + xdg_config_home = "~/.config" + system_config_files = [ + '/etc/myclirc', + os.path.join(xdg_config_home, "mycli", "myclirc") + ] + + default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc') + pwd_config_file = os.path.join(os.getcwd(), ".myclirc") + + def __init__(self, sqlexecute=None, prompt=None, + logfile=None, defaults_suffix=None, defaults_file=None, + login_path=None, auto_vertical_output=False, warn=None, + myclirc="~/.myclirc"): + self.sqlexecute = sqlexecute + self.logfile = logfile + self.defaults_suffix = defaults_suffix + self.login_path = login_path + + # self.cnf_files is a class variable that stores the list of mysql + # config files to read in at launch. + # If defaults_file is specified then override the class variable with + # defaults_file. + if defaults_file: + self.cnf_files = [defaults_file] + + # Load config. + config_files = ([self.default_config_file] + self.system_config_files + + [myclirc] + [self.pwd_config_file]) + c = self.config = read_config_files(config_files) + self.multi_line = c['main'].as_bool('multi_line') + self.key_bindings = c['main']['key_bindings'] + special.set_timing_enabled(c['main'].as_bool('timing')) + + FavoriteQueries.instance = FavoriteQueries.from_config(self.config) + + self.dsn_alias = None + self.formatter = TabularOutputFormatter( + format_name=c['main']['table_format']) + sql_format.register_new_formatter(self.formatter) + self.formatter.mycli = self + self.syntax_style = c['main']['syntax_style'] + self.less_chatty = c['main'].as_bool('less_chatty') + self.cli_style = c['colors'] + self.output_style = style_factory_output( + self.syntax_style, + self.cli_style + ) + self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') + c_dest_warning = c['main'].as_bool('destructive_warning') + self.destructive_warning = c_dest_warning if warn is None else warn + self.login_path_as_host = c['main'].as_bool('login_path_as_host') + + # read from cli argument or user config file + self.auto_vertical_output = auto_vertical_output or \ + c['main'].as_bool('auto_vertical_output') + + # Write user config if system config wasn't the last config loaded. + if c.filename not in self.system_config_files: + write_default_config(self.default_config_file, myclirc) + + # audit log + if self.logfile is None and 'audit_log' in c['main']: + try: + self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') + except (IOError, OSError) as e: + self.echo('Error: Unable to open the audit log file. Your queries will not be logged.', + err=True, fg='red') + self.logfile = False + + self.completion_refresher = CompletionRefresher() + + self.logger = logging.getLogger(__name__) + self.initialize_logging() + + prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] + self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ + self.default_prompt + self.multiline_continuation_char = c['main']['prompt_continuation'] + keyword_casing = c['main'].get('keyword_casing', 'auto') + + self.query_history = [] + + # Initialize completer. + self.smart_completion = c['main'].as_bool('smart_completion') + self.completer = SQLCompleter( + self.smart_completion, + supported_formats=self.formatter.supported_formats, + keyword_casing=keyword_casing) + self._completer_lock = threading.Lock() + + # Register custom special commands. + self.register_special_commands() + + # Load .mylogin.cnf if it exists. + mylogin_cnf_path = get_mylogin_cnf_path() + if mylogin_cnf_path: + mylogin_cnf = open_mylogin_cnf(mylogin_cnf_path) + if mylogin_cnf_path and mylogin_cnf: + # .mylogin.cnf gets read last, even if defaults_file is specified. + self.cnf_files.append(mylogin_cnf) + elif mylogin_cnf_path and not mylogin_cnf: + # There was an error reading the login path file. + print('Error: Unable to read login path file.') + + self.prompt_app = None + + def register_special_commands(self): + special.register_special_command(self.change_db, 'use', + '\\u', 'Change to a new database.', aliases=('\\u',)) + special.register_special_command(self.change_db, 'connect', + '\\r', 'Reconnect to the database. Optional database argument.', + aliases=('\\r', ), case_sensitive=True) + special.register_special_command(self.refresh_completions, 'rehash', + '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',)) + special.register_special_command( + self.change_table_format, 'tableformat', '\\T', + 'Change the table format used to output results.', + aliases=('\\T',), case_sensitive=True) + special.register_special_command(self.execute_from_file, 'source', '\\. filename', + 'Execute commands from file.', aliases=('\\.',)) + special.register_special_command(self.change_prompt_format, 'prompt', + '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True) + + def change_table_format(self, arg, **_): + try: + self.formatter.format_name = arg + yield (None, None, None, + 'Changed table format to {}'.format(arg)) + except ValueError: + msg = 'Table format {} not recognized. Allowed formats:'.format( + arg) + for table_type in self.formatter.supported_formats: + msg += "\n\t{}".format(table_type) + yield (None, None, None, msg) + + def change_db(self, arg, **_): + if not arg: + click.secho( + "No database selected", + err=True, fg="red" + ) + return + + self.sqlexecute.change_db(arg) + + yield (None, None, None, 'You are now connected to database "%s" as ' + 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) + + def execute_from_file(self, arg, **_): + if not arg: + message = 'Missing required argument, filename.' + return [(None, None, None, message)] + try: + with open(os.path.expanduser(arg)) as f: + query = f.read() + except IOError as e: + return [(None, None, None, str(e))] + + if (self.destructive_warning and + confirm_destructive_query(query) is False): + message = 'Wise choice. Command execution stopped.' + return [(None, None, None, message)] + + return self.sqlexecute.run(query) + + def change_prompt_format(self, arg, **_): + """ + Change the prompt format. + """ + if not arg: + message = 'Missing required argument, format.' + return [(None, None, None, message)] + + self.prompt_format = self.get_prompt(arg) + return [(None, None, None, "Changed prompt format to %s" % arg)] + + def initialize_logging(self): + + log_file = os.path.expanduser(self.config['main']['log_file']) + log_level = self.config['main']['log_level'] + + level_map = {'CRITICAL': logging.CRITICAL, + 'ERROR': logging.ERROR, + 'WARNING': logging.WARNING, + 'INFO': logging.INFO, + 'DEBUG': logging.DEBUG + } + + # Disable logging if value is NONE by switching to a no-op handler + # Set log level to a high value so it doesn't even waste cycles getting called. + if log_level.upper() == "NONE": + handler = logging.NullHandler() + log_level = "CRITICAL" + elif dir_path_exists(log_file): + handler = logging.FileHandler(log_file) + else: + self.echo( + 'Error: Unable to open the log file "{}".'.format(log_file), + err=True, fg='red') + return + + formatter = logging.Formatter( + '%(asctime)s (%(process)d/%(threadName)s) ' + '%(name)s %(levelname)s - %(message)s') + + handler.setFormatter(formatter) + + root_logger = logging.getLogger('mycli') + root_logger.addHandler(handler) + root_logger.setLevel(level_map[log_level.upper()]) + + logging.captureWarnings(True) + + root_logger.debug('Initializing mycli logging.') + root_logger.debug('Log file %r.', log_file) + + + def read_my_cnf_files(self, files, keys): + """ + Reads a list of config files and merges them. The last one will win. + :param files: list of files to read + :param keys: list of keys to retrieve + :returns: tuple, with None for missing keys. + """ + cnf = read_config_files(files, list_values=False) + + sections = ['client', 'mysqld'] + if self.login_path and self.login_path != 'client': + sections.append(self.login_path) + + if self.defaults_suffix: + sections.extend([sect + self.defaults_suffix for sect in sections]) + + def get(key): + result = None + for sect in cnf: + if sect in sections and key in cnf[sect]: + result = strip_matching_quotes(cnf[sect][key]) + return result + + return {x: get(x) for x in keys} + + def merge_ssl_with_cnf(self, ssl, cnf): + """Merge SSL configuration dict with cnf dict""" + + merged = {} + merged.update(ssl) + prefix = 'ssl-' + for k, v in cnf.items(): + # skip unrelated options + if not k.startswith(prefix): + continue + if v is None: + continue + # special case because PyMySQL argument is significantly different + # from commandline + if k == 'ssl-verify-server-cert': + merged['check_hostname'] = v + else: + # use argument name just strip "ssl-" prefix + arg = k[len(prefix):] + merged[arg] = v + + return merged + + def connect(self, database='', user='', passwd='', host='', port='', + socket='', charset='', local_infile='', ssl='', + ssh_user='', ssh_host='', ssh_port='', + ssh_password='', ssh_key_filename=''): + + cnf = {'database': None, + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'socket': None, + 'default-character-set': None, + 'local-infile': None, + 'loose-local-infile': None, + 'ssl-ca': None, + 'ssl-cert': None, + 'ssl-key': None, + 'ssl-cipher': None, + 'ssl-verify-serer-cert': None, + } + + cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) + + # Fall back to config values only if user did not specify a value. + + database = database or cnf['database'] + # Socket interface not supported for SSH connections + if port or host or ssh_host or ssh_port: + socket = '' + else: + socket = socket or cnf['socket'] or guess_socket_location() + user = user or cnf['user'] or os.getenv('USER') + host = host or cnf['host'] + port = port or cnf['port'] + ssl = ssl or {} + + passwd = passwd or cnf['password'] + charset = charset or cnf['default-character-set'] or 'utf8' + + # Favor whichever local_infile option is set. + for local_infile_option in (local_infile, cnf['local-infile'], + cnf['loose-local-infile'], False): + try: + local_infile = str_to_bool(local_infile_option) + break + except (TypeError, ValueError): + pass + + ssl = self.merge_ssl_with_cnf(ssl, cnf) + # prune lone check_hostname=False + if not any(v for v in ssl.values()): + ssl = None + + # Connect to the database. + + def _connect(): + try: + self.sqlexecute = SQLExecute( + database, user, passwd, host, port, socket, charset, + local_infile, ssl, ssh_user, ssh_host, ssh_port, + ssh_password, ssh_key_filename + ) + except OperationalError as e: + if ('Access denied for user' in e.args[1]): + new_passwd = click.prompt('Password', hide_input=True, + show_default=False, type=str, err=True) + self.sqlexecute = SQLExecute( + database, user, new_passwd, host, port, socket, + charset, local_infile, ssl, ssh_user, ssh_host, + ssh_port, ssh_password, ssh_key_filename + ) + else: + raise e + + try: + if not WIN and socket: + socket_owner = getpwuid(os.stat(socket).st_uid).pw_name + self.echo( + f"Connecting to socket {socket}, owned by user {socket_owner}") + try: + _connect() + except OperationalError as e: + # These are "Can't open socket" and 2x "Can't connect" + if [code for code in (2001, 2002, 2003) if code == e.args[0]]: + self.logger.debug('Database connection failed: %r.', e) + self.logger.error( + "traceback: %r", traceback.format_exc()) + self.logger.debug('Retrying over TCP/IP') + self.echo( + "Failed to connect to local MySQL server through socket '{}':".format(socket)) + self.echo(str(e), err=True) + self.echo( + 'Retrying over TCP/IP', err=True) + + # Else fall back to TCP/IP localhost + socket = "" + host = 'localhost' + port = 3306 + _connect() + else: + raise e + else: + host = host or 'localhost' + port = port or 3306 + + # Bad ports give particularly daft error messages + try: + port = int(port) + except ValueError as e: + self.echo("Error: Invalid port number: '{0}'.".format(port), + err=True, fg='red') + exit(1) + + _connect() + except Exception as e: # Connecting to a database could fail. + self.logger.debug('Database connection failed: %r.', e) + self.logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + exit(1) + + def handle_editor_command(self, text): + """Editor command is any query that is prefixed or suffixed by a '\e'. + The reason for a while loop is because a user might edit a query + multiple times. For eg: + + "select * from \e"<enter> to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + + """ + + while special.editor_command(text): + filename = special.get_filename(text) + query = (special.get_editor_query(text) or + self.get_last_query()) + sql, message = special.open_external_editor(filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + text = self.prompt_app.prompt(default=sql) + break + except KeyboardInterrupt: + sql = "" + + continue + return text + + def run_cli(self): + iterations = 0 + sqlexecute = self.sqlexecute + logger = self.logger + self.configure_pager() + + if self.smart_completion: + self.refresh_completions() + + author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') + sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') + + history_file = os.path.expanduser( + os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) + if dir_path_exists(history_file): + history = FileHistory(history_file) + else: + history = None + self.echo( + 'Error: Unable to open the history file "{}". ' + 'Your query history will not be saved.'.format(history_file), + err=True, fg='red') + + key_bindings = mycli_bindings(self) + + if not self.less_chatty: + print(' '.join(sqlexecute.server_type())) + print('mycli', __version__) + print('Chat: https://gitter.im/dbcli/mycli') + print('Mail: https://groups.google.com/forum/#!forum/mycli-users') + print('Home: http://mycli.net') + print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file])) + + def get_message(): + prompt = self.get_prompt(self.prompt_format) + if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: + prompt = self.get_prompt('\\d> ') + return [('class:prompt', prompt)] + + def get_continuation(width, *_): + if self.multiline_continuation_char: + left_padding = width - len(self.multiline_continuation_char) + continuation = " " * \ + max((left_padding - 1), 0) + \ + self.multiline_continuation_char + " " + else: + continuation = " " + return [('class:continuation', continuation)] + + def show_suggestion_tip(): + return iterations < 2 + + def one_iteration(text=None): + if text is None: + try: + text = self.prompt_app.prompt() + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + + try: + text = self.handle_editor_command(text) + except RuntimeError as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + return + + if not text.strip(): + return + + if self.destructive_warning: + destroy = confirm_destructive_query(text) + if destroy is None: + pass # Query was not destructive. Nothing to do here. + elif destroy is True: + self.echo('Your call!') + else: + self.echo('Wise choice!') + return + else: + destroy = True + + # Keep track of whether or not the query is mutating. In case + # of a multi-statement query, the overall query is considered + # mutating if any one of the component statements is mutating + mutating = False + + try: + logger.debug('sql: %r', text) + + special.write_tee(self.get_prompt(self.prompt_format) + text) + if self.logfile: + self.logfile.write('\n# %s\n' % datetime.now()) + self.logfile.write(text) + self.logfile.write('\n') + + successful = False + start = time() + res = sqlexecute.run(text) + self.formatter.query = text + successful = True + result_count = 0 + for title, cur, headers, status in res: + logger.debug("headers: %r", headers) + logger.debug("rows: %r", cur) + logger.debug("status: %r", status) + threshold = 1000 + if (is_select(status) and + cur and cur.rowcount > threshold): + self.echo('The result set has more than {} rows.'.format( + threshold), fg='red') + if not confirm('Do you want to continue?'): + self.echo("Aborted!", err=True, fg='red') + break + + if self.auto_vertical_output: + max_width = self.prompt_app.output.get_size().columns + else: + max_width = None + + formatted = self.format_output( + title, cur, headers, special.is_expanded_output(), + max_width) + + t = time() - start + try: + if result_count > 0: + self.echo('') + try: + self.output(formatted, status) + except KeyboardInterrupt: + pass + if special.is_timing_enabled(): + self.echo('Time: %0.03fs' % t) + except KeyboardInterrupt: + pass + + start = time() + result_count += 1 + mutating = mutating or destroy or is_mutating(status) + special.unset_once_if_written() + except EOFError as e: + raise e + except KeyboardInterrupt: + # get last connection id + connection_id_to_kill = sqlexecute.connection_id + logger.debug("connection id to kill: %r", connection_id_to_kill) + # Restart connection to the database + sqlexecute.connect() + try: + for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill): + status_str = str(status).lower() + if status_str.find('ok') > -1: + logger.debug("cancelled query, connection id: %r, sql: %r", + connection_id_to_kill, text) + self.echo("cancelled query", err=True, fg='red') + except Exception as e: + self.echo('Encountered error while cancelling query: {}'.format(e), + err=True, fg='red') + except NotImplementedError: + self.echo('Not Yet Implemented.', fg="yellow") + except OperationalError as e: + logger.debug("Exception: %r", e) + if (e.args[0] in (2003, 2006, 2013)): + logger.debug('Attempting to reconnect.') + self.echo('Reconnecting...', fg='yellow') + try: + sqlexecute.connect() + logger.debug('Reconnected successfully.') + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. + except OperationalError as e: + logger.debug('Reconnect failed. e: %r', e) + self.echo(str(e), err=True, fg='red') + # If reconnection failed, don't proceed further. + return + else: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + except Exception as e: + logger.error("sql: %r, error: %r", text, e) + logger.error("traceback: %r", traceback.format_exc()) + self.echo(str(e), err=True, fg='red') + else: + if is_dropping_database(text, self.sqlexecute.dbname): + self.sqlexecute.dbname = None + self.sqlexecute.connect() + + # Refresh the table names and column names if necessary. + if need_completion_refresh(text): + self.refresh_completions( + reset=need_completion_reset(text)) + finally: + if self.logfile is False: + self.echo("Warning: This query was not logged.", + err=True, fg='red') + query = Query(text, successful, mutating) + self.query_history.append(query) + + get_toolbar_tokens = create_toolbar_tokens_func( + self, show_suggestion_tip) + if self.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with self._completer_lock: + + if self.key_bindings == 'vi': + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + self.prompt_app = PromptSession( + lexer=PygmentsLexer(MyCliLexer), + reserve_space_for_menu=self.get_reserved_space(), + message=get_message, + prompt_continuation=get_continuation, + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ConditionalProcessor( + processor=HighlightMatchingBracketProcessor( + chars='[](){}'), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() + )], + tempfile_suffix='.sql', + completer=DynamicCompleter(lambda: self.completer), + history=history, + auto_suggest=AutoSuggestFromHistory(), + complete_while_typing=True, + multiline=cli_is_multiline(self), + style=style_factory(self.syntax_style, self.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True + ) + + try: + while True: + one_iteration() + iterations += 1 + except EOFError: + special.close_tee() + if not self.less_chatty: + self.echo('Goodbye!') + + def log_output(self, output): + """Log the output in the audit log, if it's enabled.""" + if self.logfile: + click.echo(output, file=self.logfile) + + def echo(self, s, **kwargs): + """Print a message to stdout. + + The message will be logged in the audit log, if enabled. + + All keyword arguments are passed to click.echo(). + + """ + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status=None): + """Get the output margin (number of rows for the prompt, footer and + timing message.""" + margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1 + if special.is_timing_enabled(): + margin += 1 + if status: + margin += 1 + status.count('\n') + + return margin + + + def output(self, output, status=None): + """Output text to stdout or a pager command. + + The status text is not outputted to pager or files. + + The message will be logged in the audit log, if enabled. The + message will be written to the tee file, if enabled. The + message will be written to the output file, if enabled. + + """ + if output: + size = self.prompt_app.output.get_size() + + margin = self.get_output_margin(status) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + + if fits or output_via_pager: + # buffering + buf.append(line) + if len(line) > size.columns or i > (size.rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + # doesn't fit, use pager + output_via_pager = True + + if not output_via_pager: + # doesn't fit, flush buffer + for line in buf: + click.secho(line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + def newlinewrapper(text): + for line in text: + yield line + "\n" + click.echo_via_pager(newlinewrapper(buf)) + else: + for line in buf: + click.secho(line) + + if status: + self.log_output(status) + click.secho(status) + + def configure_pager(self): + # Provide sane defaults for less if they are empty. + if not os.environ.get('LESS'): + os.environ['LESS'] = '-RXF' + + cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) + if cnf['pager']: + special.set_pager(cnf['pager']) + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'): + special.disable_pager() + + def refresh_completions(self, reset=False): + if reset: + with self._completer_lock: + self.completer.reset_completions() + self.completion_refresher.refresh( + self.sqlexecute, self._on_completions_refreshed, + {'smart_completion': self.smart_completion, + 'supported_formats': self.formatter.supported_formats, + 'keyword_casing': self.completer.keyword_casing}) + + return [(None, None, None, + 'Auto-completion refresh started in the background.')] + + def _on_completions_refreshed(self, new_completer): + """Swap the completer object in cli with the newly created completer. + """ + with self._completer_lock: + self.completer = new_completer + + if self.prompt_app: + # After refreshing, redraw the CLI to clear the statusbar + # "Refreshing completions..." indicator + self.prompt_app.app.invalidate() + + def get_completions(self, text, cursor_positition): + with self._completer_lock: + return self.completer.get_completions( + Document(text=text, cursor_position=cursor_positition), None) + + def get_prompt(self, string): + sqlexecute = self.sqlexecute + host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host + now = datetime.now() + string = string.replace('\\u', sqlexecute.user or '(none)') + string = string.replace('\\h', host or '(none)') + string = string.replace('\\d', sqlexecute.dbname or '(none)') + string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') + string = string.replace('\\n', "\n") + string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) + string = string.replace('\\m', now.strftime('%M')) + string = string.replace('\\P', now.strftime('%p')) + string = string.replace('\\R', now.strftime('%H')) + string = string.replace('\\r', now.strftime('%I')) + string = string.replace('\\s', now.strftime('%S')) + string = string.replace('\\p', str(sqlexecute.port)) + string = string.replace('\\A', self.dsn_alias or '(none)') + string = string.replace('\\_', ' ') + return string + + def run_query(self, query, new_line=True): + """Runs *query*.""" + results = self.sqlexecute.run(query) + for result in results: + title, cur, headers, status = result + self.formatter.query = query + output = self.format_output(title, cur, headers) + for line in output: + click.echo(line, nl=new_line) + + def format_output(self, title, cur, headers, expanded=False, + max_width=None): + expanded = expanded or self.formatter.format_name == 'vertical' + output = [] + + output_kwargs = { + 'dialect': 'unix', + 'disable_numparse': True, + 'preserve_whitespace': True, + 'style': self.output_style + } + + if not self.formatter.format_name in sql_format.supported_formats: + output_kwargs["preprocessors"] = (preprocessors.align_decimals, ) + + if title: # Only print the title if it's not None. + output = itertools.chain(output, [title]) + + if cur: + column_types = None + if hasattr(cur, 'description'): + def get_col_type(col): + col_type = FIELD_TYPES.get(col[1], str) + return col_type if type(col_type) is type else str + column_types = [get_col_type(col) for col in cur.description] + + if max_width is not None: + cur = list(cur) + + formatted = self.formatter.format_output( + cur, headers, format_name='vertical' if expanded else None, + column_types=column_types, + **output_kwargs) + + if isinstance(formatted, str): + formatted = formatted.splitlines() + formatted = iter(formatted) + + if (not expanded and max_width and headers and cur): + first_line = next(formatted) + if len(strip_ansi(first_line)) > max_width: + formatted = self.formatter.format_output( + cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + else: + formatted = itertools.chain([first_line], formatted) + + output = itertools.chain(output, formatted) + + + return output + + def get_reserved_space(self): + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = .45 + max_reserved_space = 8 + _, height = click.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) + + def get_last_query(self): + """Get the last query executed or None.""" + return self.query_history[-1][0] if self.query_history else None + + +@click.command() +@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.') +@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors ' + '$MYSQL_TCP_PORT.') +@click.option('-u', '--user', help='User name to connect to the database.') +@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.') +@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str, + help='Password to connect to the database.') +@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str, + help='Password to connect to the database.') +@click.option('--ssh-user', help='User name to connect to ssh server.') +@click.option('--ssh-host', help='Host name to connect to ssh server.') +@click.option('--ssh-port', default=22, help='Port to connect to ssh server.') +@click.option('--ssh-password', help='Password to connect to ssh server.') +@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') +@click.option('--ssh-config-path', help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config') +@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') +@click.option('--ssl-ca', help='CA file in PEM format.', + type=click.Path(exists=True)) +@click.option('--ssl-capath', help='CA directory.') +@click.option('--ssl-cert', help='X509 cert in PEM format.', + type=click.Path(exists=True)) +@click.option('--ssl-key', help='X509 key in PEM format.', + type=click.Path(exists=True)) +@click.option('--ssl-cipher', help='SSL cipher to use.') +@click.option('--ssl-verify-server-cert', is_flag=True, + help=('Verify server\'s "Common Name" in its cert against ' + 'hostname used when connecting. This option is disabled ' + 'by default.')) +# as of 2016-02-15 revocation list is not supported by underling PyMySQL +# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) +@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.') +@click.option('-v', '--verbose', is_flag=True, help='Verbose output.') +@click.option('-D', '--database', 'dbname', help='Database to use.') +@click.option('-d', '--dsn', default='', envvar='DSN', + help='Use DSN configured into the [alias_dsn] section of myclirc file.') +@click.option('--list-dsn', 'list_dsn', is_flag=True, + help='list of DSN configured into the [alias_dsn] section of myclirc file.') +@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).') +@click.option('-R', '--prompt', 'prompt', + help='Prompt format (Default: "{0}").'.format( + MyCli.default_prompt)) +@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'), + help='Log every query and its results to a file.') +@click.option('--defaults-group-suffix', type=str, + help='Read MySQL config groups with the specified suffix.') +@click.option('--defaults-file', type=click.Path(), + help='Only read MySQL options from the given file.') +@click.option('--myclirc', type=click.Path(), default="~/.myclirc", + help='Location of myclirc file.') +@click.option('--auto-vertical-output', is_flag=True, + help='Automatically switch to vertical output mode if the result is wider than the terminal width.') +@click.option('-t', '--table', is_flag=True, + help='Display batch output in table format.') +@click.option('--csv', is_flag=True, + help='Display batch output in CSV format.') +@click.option('--warn/--no-warn', default=None, + help='Warn before running a destructive query.') +@click.option('--local-infile', type=bool, + help='Enable/disable LOAD DATA LOCAL INFILE.') +@click.option('--login-path', type=str, + help='Read this path from the login file.') +@click.option('-e', '--execute', type=str, + help='Execute command and quit.') +@click.argument('database', default='', nargs=1) +def cli(database, user, host, port, socket, password, dbname, + version, verbose, prompt, logfile, defaults_group_suffix, + defaults_file, login_path, auto_vertical_output, local_infile, + ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, + ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, + list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, + ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host): + """A MySQL terminal client with auto-completion and syntax highlighting. + + \b + Examples: + - mycli my_database + - mycli -u my_user -h my_host.com my_database + - mycli mysql://my_user@my_host.com:3306/my_database + + """ + + if version: + print('Version:', __version__) + sys.exit(0) + + mycli = MyCli(prompt=prompt, logfile=logfile, + defaults_suffix=defaults_group_suffix, + defaults_file=defaults_file, login_path=login_path, + auto_vertical_output=auto_vertical_output, warn=warn, + myclirc=myclirc) + if list_dsn: + try: + alias_dsn = mycli.config['alias_dsn'] + except KeyError as err: + click.secho('Invalid DSNs found in the config file. '\ + 'Please check the "[alias_dsn]" section in myclirc.', + err=True, fg='red') + exit(1) + except Exception as e: + click.secho(str(e), err=True, fg='red') + exit(1) + for alias, value in alias_dsn.items(): + if verbose: + click.secho("{} : {}".format(alias, value)) + else: + click.secho(alias) + sys.exit(0) + if list_ssh_config: + ssh_config = read_ssh_config(ssh_config_path) + for host in ssh_config.get_hostnames(): + if verbose: + host_config = ssh_config.lookup(host) + click.secho("{} : {}".format( + host, host_config.get('hostname'))) + else: + click.secho(host) + sys.exit(0) + # Choose which ever one has a valid value. + database = dbname or database + + ssl = { + 'ca': ssl_ca and os.path.expanduser(ssl_ca), + 'cert': ssl_cert and os.path.expanduser(ssl_cert), + 'key': ssl_key and os.path.expanduser(ssl_key), + 'capath': ssl_capath, + 'cipher': ssl_cipher, + 'check_hostname': ssl_verify_server_cert, + } + + # remove empty ssl options + ssl = {k: v for k, v in ssl.items() if v is not None} + + dsn_uri = None + + # Treat the database argument as a DSN alias if we're missing + # other connection information. + if (mycli.config['alias_dsn'] and database and '://' not in database + and not any([user, password, host, port, login_path])): + dsn, database = database, '' + + if database and '://' in database: + dsn_uri, database = database, '' + + if dsn: + try: + dsn_uri = mycli.config['alias_dsn'][dsn] + except KeyError: + click.secho('Could not find the specified DSN in the config file. ' + 'Please check the "[alias_dsn]" section in your ' + 'myclirc.', err=True, fg='red') + exit(1) + else: + mycli.dsn_alias = dsn + + if dsn_uri: + uri = urlparse(dsn_uri) + if not database: + database = uri.path[1:] # ignore the leading fwd slash + if not user: + user = unquote(uri.username) + if not password and uri.password is not None: + password = unquote(uri.password) + if not host: + host = uri.hostname + if not port: + port = uri.port + + if ssh_config_host: + ssh_config = read_ssh_config( + ssh_config_path + ).lookup(ssh_config_host) + ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') + ssh_user = ssh_user if ssh_user else ssh_config.get('user') + if ssh_config.get('port') and ssh_port == 22: + # port has a default value, overwrite it if it's in the config + ssh_port = int(ssh_config.get('port')) + ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( + 'identityfile', [None])[0] + + ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) + + mycli.connect( + database=database, + user=user, + passwd=password, + host=host, + port=port, + socket=socket, + local_infile=local_infile, + ssl=ssl, + ssh_user=ssh_user, + ssh_host=ssh_host, + ssh_port=ssh_port, + ssh_password=ssh_password, + ssh_key_filename=ssh_key_filename + ) + + mycli.logger.debug('Launch Params: \n' + '\tdatabase: %r' + '\tuser: %r' + '\thost: %r' + '\tport: %r', database, user, host, port) + + # --execute argument + if execute: + try: + if csv: + mycli.formatter.format_name = 'csv' + elif not table: + mycli.formatter.format_name = 'tsv' + + mycli.run_query(execute) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg='red') + exit(1) + + if sys.stdin.isatty(): + mycli.run_cli() + else: + stdin = click.get_text_stream('stdin') + try: + stdin_text = stdin.read() + except MemoryError: + click.secho('Failed! Ran out of memory.', err=True, fg='red') + click.secho('You might want to try the official mysql client.', err=True, fg='red') + click.secho('Sorry... :(', err=True, fg='red') + exit(1) + + try: + sys.stdin = open('/dev/tty') + except (IOError, OSError): + mycli.logger.warning('Unable to open TTY as stdin.') + + if (mycli.destructive_warning and + confirm_destructive_query(stdin_text) is False): + exit(0) + try: + new_line = True + + if csv: + mycli.formatter.format_name = 'csv' + elif not table: + mycli.formatter.format_name = 'tsv' + + mycli.run_query(stdin_text, new_line=new_line) + exit(0) + except Exception as e: + click.secho(str(e), err=True, fg='red') + exit(1) + + +def need_completion_refresh(queries): + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop or change db.""" + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ('alter', 'create', 'use', '\\r', + '\\u', 'connect', 'drop', 'rename'): + return True + except Exception: + return False + + +def need_completion_reset(queries): + """Determines if the statement is a database switch such as 'use' or '\\u'. + When a database is changed the existing completions must be reset before we + start the completion refresh for the new database. + """ + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ('use', '\\u'): + return True + except Exception: + return False + + +def is_mutating(status): + """Determines if the statement is mutating based on the status.""" + if not status: + return False + + mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', + 'replace', 'truncate', 'load', 'rename']) + return status.split(None, 1)[0].lower() in mutating + + +def is_select(status): + """Returns true if the first word in status is 'select'.""" + if not status: + return False + return status.split(None, 1)[0].lower() == 'select' + + +def thanks_picker(files=()): + contents = [] + for line in fileinput.input(files=files): + m = re.match('^ *\* (.*)', line) + if m: + contents.append(m.group(1)) + return choice(contents) + + +@prompt_register('edit-and-execute-command') +def edit_and_execute(event): + """Different from the prompt-toolkit default, we want to have a choice not + to execute a query after editing, hence validate_and_handle=False.""" + buff = event.current_buffer + buff.open_in_editor(validate_and_handle=False) + + +def read_ssh_config(ssh_config_path): + ssh_config = paramiko.config.SSHConfig() + try: + with open(ssh_config_path) as f: + ssh_config.parse(f) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + click.secho( + f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', + err=True, fg='red' + ) + sys.exit(1) + except FileNotFoundError as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + else: + return ssh_config + + +if __name__ == "__main__": + cli() diff --git a/mycli/myclirc b/mycli/myclirc new file mode 100644 index 0000000..534b201 --- /dev/null +++ b/mycli/myclirc @@ -0,0 +1,121 @@ +# vi: ft=dosini +[main] + +# Enables context sensitive auto-completion. If this is disabled the all +# possible completions will be listed. +smart_completion = True + +# Multi-line mode allows breaking up the sql statements into multiple lines. If +# this is set to True, then the end of the statements must have a semi-colon. +# If this is set to False then sql statements can't be split into multiple +# lines. End of line (return) is considered as the end of the statement. +multi_line = False + +# Destructive warning mode will alert you before executing a sql statement +# that may cause harm to the database such as "drop table", "drop database" +# or "shutdown". +destructive_warning = True + +# log_file location. +log_file = ~/.mycli.log + +# Default log level. Possible values: "CRITICAL", "ERROR", "WARNING", "INFO" +# and "DEBUG". "NONE" disables logging. +log_level = INFO + +# Log every query and its results to a file. Enable this by uncommenting the +# line below. +# audit_log = ~/.mycli-audit.log + +# Timing of sql statments and table rendering. +timing = True + +# Table format. Possible values: ascii, double, github, +# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html, +# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv. +# Recommended: ascii +table_format = ascii + +# Syntax coloring style. Possible values (many support the "-dark" suffix): +# manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, +# friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, +# fruity. +# Screenshots at http://mycli.net/syntax +syntax_style = default + +# Keybindings: Possible values: emacs, vi. +# Emacs mode: Ctrl-A is home, Ctrl-E is end. All emacs keybindings are available in the REPL. +# When Vi mode is enabled you can use modal editing features offered by Vi in the REPL. +key_bindings = emacs + +# Enabling this option will show the suggestions in a wider menu. Thus more items are suggested. +wider_completion_menu = False + +# MySQL prompt +# \D - The full current date +# \d - Database name +# \h - Hostname of the server +# \m - Minutes of the current time +# \n - Newline +# \P - AM/PM +# \p - Port +# \R - The current time, in 24-hour military time (0–23) +# \r - The current time, standard 12-hour time (1–12) +# \s - Seconds of the current time +# \t - Product type (Percona, MySQL, MariaDB) +# \A - DSN alias name (from the [alias_dsn] section) +# \u - Username +prompt = '\t \u@\h:\d> ' +prompt_continuation = '->' + +# Skip intro info on startup and outro info on exit +less_chatty = False + +# Use alias from --login-path instead of host name in prompt +login_path_as_host = False + +# Cause result sets to be displayed vertically if they are too wide for the current window, +# and using normal tabular format otherwise. (This applies to statements terminated by ; or \G.) +auto_vertical_output = False + +# keyword casing preference. Possible values "lower", "upper", "auto" +keyword_casing = auto + +# disabled pager on startup +enable_pager = True + +# Custom colors for the completion menu, toolbar, etc. +[colors] +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' +selected = '#ffffff bg:#6666aa' +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# style classes for colored table output +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" + +# Favorite queries. +[favorite_queries] + +# Use the -d option to reference a DSN. +# Special characters in passwords and other strings can be escaped with URL encoding. +[alias_dsn] +# example_dsn = mysql://[user[:password]@][host][:port][/dbname] diff --git a/mycli/packages/__init__.py b/mycli/packages/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/mycli/packages/__init__.py diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py new file mode 100644 index 0000000..2b19c32 --- /dev/null +++ b/mycli/packages/completion_engine.py @@ -0,0 +1,295 @@ +import os +import sys +import sqlparse +from sqlparse.sql import Comparison, Identifier, Where +from sqlparse.compat import text_type +from .parseutils import last_word, extract_tables, find_prev_keyword +from .special import parse_special_command + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + word_before_cursor = last_word(text_before_cursor, + include='many_punctuations') + + identifier = None + + # here should be removed once sqlparse has been fixed + try: + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. + if word_before_cursor: + if word_before_cursor.endswith( + '(') or word_before_cursor.startswith('\\'): + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse( + text_before_cursor[:-len(word_before_cursor)]) + + # word_before_cursor may include a schema qualification, like + # "schema_name.partial_name" or "schema_name.", so parse it + # separately + p = sqlparse.parse(word_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[0], Identifier): + identifier = p.tokens[0] + else: + parsed = sqlparse.parse(text_before_cursor) + except (TypeError, AttributeError): + return [{'type': 'keyword'}] + + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds the + # current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(text_type(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + statement = None + + # Check for special commands and handle those separately + if statement: + # Be careful here because trivial whitespace is parsed as a statement, + # but the statement won't have a first token + tok1 = statement.token_first() + if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')): + return suggest_special(text_before_cursor) + + last_token = statement and statement.token_prev(len(statement.tokens))[1] or '' + + return suggest_based_on_last_token(last_token, text_before_cursor, + full_text, identifier) + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return [{'type': 'special'}] + + if cmd in ('\\u', '\\r'): + return [{'type': 'database'}] + + if cmd in ('\\T'): + return [{'type': 'table_format'}] + + if cmd in ['\\f', '\\fs', '\\fd']: + return [{'type': 'favoritequery'}] + + if cmd in ['\\dt', '\\dt+']: + return [ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}, + ] + elif cmd in ['\\.', 'source']: + return[{'type': 'file_name'}] + + return [{'type': 'keyword'}, {'type': 'special'}] + + +def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, + full_text, identifier) + else: + token_v = token.value.lower() + + is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']]) + + if not token: + return [{'type': 'keyword'}, {'type': 'special'}] + elif token_v.endswith('('): + p = sqlparse.parse(text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token('where', + text_before_cursor, full_text, identifier) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + idx, prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == 'exists': + return [{'type': 'keyword'}] + else: + return column_suggestions + + # Get the token before the parens + idx, prev_tok = p.token_prev(len(p.tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = extract_tables(full_text) + + # suggest columns that are present in more than one table + return [{'type': 'column', 'tables': tables, 'drop_unique': True}] + elif p.token_first().value.lower() == 'select': + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(text_before_cursor, + 'all_punctuations').startswith('('): + return [{'type': 'keyword'}] + elif p.token_first().value.lower() == 'show': + return [{'type': 'show'}] + + # We're probably in a function argument list + return [{'type': 'column', 'tables': extract_tables(full_text)}] + elif token_v in ('set', 'order by', 'distinct'): + return [{'type': 'column', 'tables': extract_tables(full_text)}] + elif token_v == 'as': + # Don't suggest anything for an alias + return [] + elif token_v in ('show'): + return [{'type': 'show'}] + elif token_v in ('to',): + p = sqlparse.parse(text_before_cursor)[0] + if p.token_first().value.lower() == 'change': + return [{'type': 'change'}] + else: + return [{'type': 'user'}] + elif token_v in ('user', 'for'): + return [{'type': 'user'}] + elif token_v in ('select', 'where', 'having'): + # Check for a table alias or schema qualification + parent = (identifier and identifier.get_parent_name()) or [] + + tables = extract_tables(full_text) + if parent: + tables = [t for t in tables if identifies(parent, *t)] + return [{'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}] + else: + aliases = [alias or table for (schema, table, alias) in tables] + return [{'type': 'column', 'tables': tables}, + {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': aliases}, + {'type': 'keyword'}] + elif (token_v.endswith('join') and token.is_keyword) or (token_v in + ('copy', 'from', 'update', 'into', 'describe', 'truncate', + 'desc', 'explain')): + schema = (identifier and identifier.get_parent_name()) or [] + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [{'type': 'table', 'schema': schema}] + + if not schema: + # Suggest schemas + suggest.insert(0, {'type': 'schema'}) + + # Only tables can be TRUNCATED, otherwise suggest views + if token_v != 'truncate': + suggest.append({'type': 'view', 'schema': schema}) + + return suggest + + elif token_v in ('table', 'view', 'function'): + # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>' + rel_type = token_v + schema = (identifier and identifier.get_parent_name()) or [] + if schema: + return [{'type': rel_type, 'schema': schema}] + else: + return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] + elif token_v == 'on': + tables = extract_tables(full_text) # [(schema, table, alias), ...] + parent = (identifier and identifier.get_parent_name()) or [] + if parent: + # "ON parent.<suggestion>" + # parent can be either a schema name or table alias + tables = [t for t in tables if identifies(parent, *t)] + return [{'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}] + else: + # ON <suggestion> + # Use table alias if there is one, otherwise the table name + aliases = [alias or table for (schema, table, alias) in tables] + suggest = [{'type': 'alias', 'aliases': aliases}] + + # The lists of 'aliases' could be empty if we're trying to complete + # a GRANT query. eg: GRANT SELECT, INSERT ON <tab> + # In that case we just suggest all tables. + if not aliases: + suggest.append({'type': 'table', 'schema': parent}) + return suggest + + elif token_v in ('use', 'database', 'template', 'connect'): + # "\c <db", "use <db>", "DROP DATABASE <db>", + # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" + return [{'type': 'database'}] + elif token_v == 'tableformat': + return [{'type': 'table_format'}] + elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']: + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + if prev_keyword: + return suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier) + else: + return [] + else: + return [{'type': 'keyword'}] + + +def identifies(id, schema, table, alias): + return id == alias or id == table or ( + schema and (id == schema + '.' + table)) diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py new file mode 100644 index 0000000..79fe26d --- /dev/null +++ b/mycli/packages/filepaths.py @@ -0,0 +1,106 @@ +import os +import platform + + +if os.name == "posix": + if platform.system() == "Darwin": + DEFAULT_SOCKET_DIRS = ("/tmp",) + else: + DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") +else: + DEFAULT_SOCKET_DIRS = () + + +def list_path(root_dir): + """List directory if exists. + + :param root_dir: str + :return: list + + """ + res = [] + if os.path.isdir(root_dir): + for name in os.listdir(root_dir): + res.append(name) + return res + + +def complete_path(curr_dir, last_dir): + """Return the path to complete that matches the last entered component. + + If the last entered component is ~, expanded path would not + match, so return all of the available paths. + + :param curr_dir: str + :param last_dir: str + :return: str + + """ + if not last_dir or curr_dir.startswith(last_dir): + return curr_dir + elif last_dir == '~': + return os.path.join(last_dir, curr_dir) + + +def parse_path(root_dir): + """Split path into head and last component for the completer. + + Also return position where last component starts. + + :param root_dir: str path + :return: tuple of (string, string, int) + + """ + base_dir, last_dir, position = '', '', 0 + if root_dir: + base_dir, last_dir = os.path.split(root_dir) + position = -len(last_dir) if last_dir else 0 + return base_dir, last_dir, position + + +def suggest_path(root_dir): + """List all files and subdirectories in a directory. + + If the directory is not specified, suggest root directory, + user directory, current and parent directory. + + :param root_dir: string: directory to list + :return: list + + """ + if not root_dir: + return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] + + if '~' in root_dir: + root_dir = os.path.expanduser(root_dir) + + if not os.path.exists(root_dir): + root_dir, _ = os.path.split(root_dir) + + return list_path(root_dir) + + +def dir_path_exists(path): + """Check if the directory path exists for a given file. + + For example, for a file /home/user/.cache/mycli/log, check if + /home/user/.cache/mycli exists. + + :param str path: The file path. + :return: Whether or not the directory path exists. + + """ + return os.path.exists(os.path.dirname(path)) + + +def guess_socket_location(): + """Try to guess the location of the default mysql socket file.""" + socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) + for directory in socket_dirs: + for r, dirs, files in os.walk(directory, topdown=True): + for filename in files: + name, ext = os.path.splitext(filename) + if name.startswith("mysql") and ext in ('.socket', '.sock'): + return os.path.join(r, filename) + dirs[:] = [d for d in dirs if d.startswith("mysql")] + return None diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py new file mode 100644 index 0000000..045b00e --- /dev/null +++ b/mycli/packages/paramiko_stub/__init__.py @@ -0,0 +1,28 @@ +"""A module to import instead of paramiko when it is not available (to avoid +checking for paramiko all over the place). + +When paramiko is first envoked, it simply shuts down mycli, telling +user they either have to install paramiko or should not use SSH +features. + +""" + + +class Paramiko: + def __getattr__(self, name): + import sys + from textwrap import dedent + print(dedent(""" + To enable certain SSH features you need to install paramiko: + + pip install paramiko + + It is required for the following configuration options: + --list-ssh-config + --ssh-config-host + --ssh-host + """)) + sys.exit(1) + + +paramiko = Paramiko() diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py new file mode 100644 index 0000000..e3b383e --- /dev/null +++ b/mycli/packages/parseutils.py @@ -0,0 +1,267 @@ +import re +import sqlparse +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +cleanup_regex = { + # This matches only alphanumerics and underscores. + 'alphanum_underscore': re.compile(r'(\w+)$'), + # This matches everything except spaces, parens, colon, and comma + 'many_punctuations': re.compile(r'([^():,\s]+)$'), + # This matches everything except spaces, parens, colon, comma, and period + 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), + # This matches everything except a space. + 'all_punctuations': re.compile('([^\s]+)$'), + } + +def last_word(text, include='alphanum_underscore'): + """ + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + """ + + if not text: # Empty string + return '' + + if text[-1].isspace(): + return '' + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return '' + + +# This code is borrowed from sqlparse example script. +# <url> +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', + 'UPDATE', 'CREATE', 'DELETE'): + return True + return False + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # StopIteration. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif item.ttype is Keyword and ( + not item.value.upper() == 'FROM') and ( + not item.value.upper().endswith('JOIN')): + return + else: + yield item + elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and + item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if (identifier.ttype is Keyword and + identifier.value.upper() == 'FROM'): + tbl_prefix_seen = True + break + +def extract_table_identifiers(token_stream): + """yields tuples of (schema_name, table_name, table_alias)""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + except AttributeError: + continue + if real_name: + yield (schema_name, real_name, identifier.get_alias()) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + + if real_name: + yield (schema_name, real_name, item.get_alias()) + else: + name = item.get_name() + yield (None, name, item.get_alias() or name) + elif isinstance(item, Function): + yield (None, item.get_name(), item.get_name()) + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statment. + + Returns a list of (schema, table, alias) tuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == 'insert' + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + return list(extract_table_identifiers(stream)) + +def find_prev_keyword(sql): + """ Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, '' + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + + for t in reversed(flattened): + if t.value == '(' or (t.is_keyword and ( + t.value.upper() not in logical_operators)): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = ''.join(tok.value for tok in flattened[:idx+1]) + return t, text + + return None, '' + + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def query_has_where_clause(query): + """Check if the query contains a where-clause.""" + return any( + isinstance(token, sqlparse.sql.Where) + for token_list in sqlparse.parse(query) + for token in token_list + ) + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') + for query in sqlparse.split(queries): + if query: + if query_starts_with(query, keywords) is True: + return True + elif query_starts_with( + query, ['update'] + ) is True and not query_has_where_clause(query): + return True + + return False + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote.""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +if __name__ == '__main__': + sql = 'select * from (select t. from tabl t' + print (extract_tables(sql)) + + +def is_dropping_database(queries, dbname): + """Determine if the query is dropping a specific database.""" + result = False + if dbname is None: + return False + + def normalize_db_name(db): + return db.lower().strip('`"') + + dbname = normalize_db_name(dbname) + + for query in sqlparse.parse(queries): + keywords = [t for t in query.tokens if t.is_keyword] + if len(keywords) < 2: + continue + if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( + "database", + "schema", + ): + database_token = next( + (t for t in query.tokens if isinstance(t, Identifier)), None + ) + if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: + result = keywords[0].normalized == "DROP" + else: + return result diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py new file mode 100644 index 0000000..fb1e431 --- /dev/null +++ b/mycli/packages/prompt_utils.py @@ -0,0 +1,54 @@ +import sys +import click +from .parseutils import is_destructive + + +class ConfirmBoolParamType(click.ParamType): + name = 'confirmation' + + def convert(self, value, param, ctx): + if isinstance(value, bool): + return bool(value) + value = value.lower() + if value in ('yes', 'y'): + return True + elif value in ('no', 'n'): + return False + self.fail('%s is not a valid boolean' % value, param, ctx) + + def __repr__(self): + return 'BOOL' + + +BOOLEAN_TYPE = ConfirmBoolParamType() + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ("You're about to run a destructive command.\n" + "Do you want to proceed? (y/n)") + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=BOOLEAN_TYPE) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py new file mode 100644 index 0000000..92bcca6 --- /dev/null +++ b/mycli/packages/special/__init__.py @@ -0,0 +1,10 @@ +__all__ = [] + +def export(defn): + """Decorator to explicitly mark functions that are exposed in a lib.""" + globals()[defn.__name__] = defn + __all__.append(defn.__name__) + return defn + +from . import dbcommands +from . import iocommands diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py new file mode 100644 index 0000000..ed90e4c --- /dev/null +++ b/mycli/packages/special/dbcommands.py @@ -0,0 +1,157 @@ +import logging +import os +import platform +from mycli import __version__ +from mycli.packages.special import iocommands +from mycli.packages.special.utils import format_uptime +from .main import special_command, RAW_QUERY, PARSED_QUERY +from pymysql import ProgrammingError + +log = logging.getLogger(__name__) + + +@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.', + arg_type=PARSED_QUERY, case_sensitive=True) +def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): + if arg: + query = 'SHOW FIELDS FROM {0}'.format(arg) + else: + query = 'SHOW TABLES' + log.debug(query) + cur.execute(query) + tables = cur.fetchall() + status = '' + if cur.description: + headers = [x[0] for x in cur.description] + else: + return [(None, None, None, '')] + + if verbose and arg: + query = 'SHOW CREATE TABLE {0}'.format(arg) + log.debug(query) + cur.execute(query) + status = cur.fetchone()[1] + + return [(None, tables, headers, status)] + +@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True) +def list_databases(cur, **_): + query = 'SHOW DATABASES' + log.debug(query) + cur.execute(query) + if cur.description: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, '')] + else: + return [(None, None, None, '')] + +@special_command('status', '\\s', 'Get status information from the server.', + arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True) +def status(cur, **_): + query = 'SHOW GLOBAL STATUS;' + log.debug(query) + try: + cur.execute(query) + except ProgrammingError: + # Fallback in case query fail, as it does with Mysql 4 + query = 'SHOW STATUS;' + log.debug(query) + cur.execute(query) + status = dict(cur.fetchall()) + + query = 'SHOW GLOBAL VARIABLES;' + log.debug(query) + cur.execute(query) + variables = dict(cur.fetchall()) + + # prepare in case keys are bytes, as with Python 3 and Mysql 4 + if (isinstance(list(variables)[0], bytes) and + isinstance(list(status)[0], bytes)): + variables = {k.decode('utf-8'): v.decode('utf-8') for k, v + in variables.items()} + status = {k.decode('utf-8'): v.decode('utf-8') for k, v + in status.items()} + + # Create output buffers. + title = [] + output = [] + footer = [] + + title.append('--------------') + + # Output the mycli client information. + implementation = platform.python_implementation() + version = platform.python_version() + client_info = [] + client_info.append('mycli {0},'.format(__version__)) + client_info.append('running on {0} {1}'.format(implementation, version)) + title.append(' '.join(client_info) + '\n') + + # Build the output that will be displayed as a table. + output.append(('Connection id:', cur.connection.thread_id())) + + query = 'SELECT DATABASE(), USER();' + log.debug(query) + cur.execute(query) + db, user = cur.fetchone() + if db is None: + db = '' + + output.append(('Current database:', db)) + output.append(('Current user:', user)) + + if iocommands.is_pager_enabled(): + if 'PAGER' in os.environ: + pager = os.environ['PAGER'] + else: + pager = 'System default' + else: + pager = 'stdout' + output.append(('Current pager:', pager)) + + output.append(('Server version:', '{0} {1}'.format( + variables['version'], variables['version_comment']))) + output.append(('Protocol version:', variables['protocol_version'])) + + if 'unix' in cur.connection.host_info.lower(): + host_info = cur.connection.host_info + else: + host_info = '{0} via TCP/IP'.format(cur.connection.host) + + output.append(('Connection:', host_info)) + + query = ('SELECT @@character_set_server, @@character_set_database, ' + '@@character_set_client, @@character_set_connection LIMIT 1;') + log.debug(query) + cur.execute(query) + charset = cur.fetchone() + output.append(('Server characterset:', charset[0])) + output.append(('Db characterset:', charset[1])) + output.append(('Client characterset:', charset[2])) + output.append(('Conn. characterset:', charset[3])) + + if 'TCP/IP' in host_info: + output.append(('TCP port:', cur.connection.port)) + else: + output.append(('UNIX socket:', variables['socket'])) + + output.append(('Uptime:', format_uptime(status['Uptime']))) + + # Print the current server statistics. + stats = [] + stats.append('Connections: {0}'.format(status['Threads_connected'])) + if 'Queries' in status: + stats.append('Queries: {0}'.format(status['Queries'])) + stats.append('Slow queries: {0}'.format(status['Slow_queries'])) + stats.append('Opens: {0}'.format(status['Opened_tables'])) + stats.append('Flush tables: {0}'.format(status['Flush_commands'])) + stats.append('Open tables: {0}'.format(status['Open_tables'])) + if 'Queries' in status: + queries_per_second = int(status['Queries']) / int(status['Uptime']) + stats.append('Queries per second avg: {:.3f}'.format( + queries_per_second)) + stats = ' '.join(stats) + footer.append('\n' + stats) + + footer.append('--------------') + return [('\n'.join(title), output, '', '\n'.join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py new file mode 100644 index 0000000..994b134 --- /dev/null +++ b/mycli/packages/special/delimitercommand.py @@ -0,0 +1,80 @@ +import re +import sqlparse + + +class DelimiterCommand(object): + def __init__(self): + self._delimiter = ';' + + def _split(self, sql): + """Temporary workaround until sqlparse.split() learns about custom + delimiters.""" + + placeholder = "\ufffc" # unicode object replacement character + + if self._delimiter == ';': + return sqlparse.split(sql) + + # We must find a string that original sql does not contain. + # Most likely, our placeholder is enough, but if not, keep looking + while placeholder in sql: + placeholder += placeholder[0] + sql = sql.replace(';', placeholder) + sql = sql.replace(self._delimiter, ';') + + split = sqlparse.split(sql) + + return [ + stmt.replace(';', self._delimiter).replace(placeholder, ';') + for stmt in split + ] + + def queries_iter(self, input): + """Iterate over queries in the input string.""" + + queries = self._split(input) + while queries: + for sql in queries: + delimiter = self._delimiter + sql = queries.pop(0) + if sql.endswith(delimiter): + trailing_delimiter = True + sql = sql.strip(delimiter) + else: + trailing_delimiter = False + + yield sql + + # if the delimiter was changed by the last command, + # re-split everything, and if we previously stripped + # the delimiter, append it to the end + if self._delimiter != delimiter: + combined_statement = ' '.join([sql] + queries) + if trailing_delimiter: + combined_statement += delimiter + queries = self._split(combined_statement)[1:] + + def set(self, arg, **_): + """Change delimiter. + + Since `arg` is everything that follows the DELIMITER token + after sqlparse (it may include other statements separated by + the new delimiter), we want to set the delimiter to the first + word of it. + + """ + match = arg and re.search(r'[^\s]+', arg) + if not match: + message = 'Missing required argument, delimiter' + return [(None, None, None, message)] + + delimiter = match.group() + if delimiter.lower() == 'delimiter': + return [(None, None, None, 'Invalid delimiter "delimiter"')] + + self._delimiter = delimiter + return [(None, None, None, "Changed delimiter to {}".format(delimiter))] + + @property + def current(self): + return self._delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py new file mode 100644 index 0000000..0b91400 --- /dev/null +++ b/mycli/packages/special/favoritequeries.py @@ -0,0 +1,63 @@ +class FavoriteQueries(object): + + section_name = 'favorite_queries' + + usage = ''' +Favorite Queries are a way to save frequently used queries +with a short name. +Examples: + + # Save a new favorite query. + > \\fs simple select * from abc where a is not Null; + + # List all favorite queries. + > \\f + ╒════════╤═══════════════════════════════════════╕ + │ Name │ Query │ + ╞════════╪═══════════════════════════════════════╡ + │ simple │ SELECT * FROM abc where a is not NULL │ + ╘════════╧═══════════════════════════════════════╛ + + # Run a favorite query. + > \\f simple + ╒════════╤════════╕ + │ a │ b │ + ╞════════╪════════╡ + │ 日本語 │ 日本語 │ + ╘════════╧════════╛ + + # Delete a favorite query. + > \\fd simple + simple: Deleted +''' + + # Class-level variable, for convenience to use as a singleton. + instance = None + + def __init__(self, config): + self.config = config + + @classmethod + def from_config(cls, config): + return FavoriteQueries(config) + + def list(self): + return self.config.get(self.section_name, []) + + def get(self, name): + return self.config.get(self.section_name, {}).get(name, None) + + def save(self, name, query): + self.config.encoding = 'utf-8' + if self.section_name not in self.config: + self.config[self.section_name] = {} + self.config[self.section_name][name] = query + self.config.write() + + def delete(self, name): + try: + del self.config[self.section_name][name] + except KeyError: + return '%s: Not Found.' % name + self.config.write() + return '%s: Deleted' % name diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py new file mode 100644 index 0000000..11dca8d --- /dev/null +++ b/mycli/packages/special/iocommands.py @@ -0,0 +1,453 @@ +import os +import re +import locale +import logging +import subprocess +import shlex +from io import open +from time import sleep + +import click +import sqlparse + +from . import export +from .main import special_command, NO_QUERY, PARSED_QUERY +from .favoritequeries import FavoriteQueries +from .delimitercommand import DelimiterCommand +from .utils import handle_cd_command +from mycli.packages.prompt_utils import confirm_destructive_query + +TIMING_ENABLED = False +use_expanded_output = False +PAGER_ENABLED = True +tee_file = None +once_file = None +written_to_once_file = False +delimiter_command = DelimiterCommand() + + +@export +def set_timing_enabled(val): + global TIMING_ENABLED + TIMING_ENABLED = val + +@export +def set_pager_enabled(val): + global PAGER_ENABLED + PAGER_ENABLED = val + + +@export +def is_pager_enabled(): + return PAGER_ENABLED + +@export +@special_command('pager', '\\P [command]', + 'Set PAGER. Print the query results via PAGER.', + arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) +def set_pager(arg, **_): + if arg: + os.environ['PAGER'] = arg + msg = 'PAGER set to %s.' % arg + set_pager_enabled(True) + else: + if 'PAGER' in os.environ: + msg = 'PAGER set to %s.' % os.environ['PAGER'] + else: + # This uses click's default per echo_via_pager. + msg = 'Pager enabled.' + set_pager_enabled(True) + + return [(None, None, None, msg)] + +@export +@special_command('nopager', '\\n', 'Disable pager, print to stdout.', + arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True) +def disable_pager(): + set_pager_enabled(False) + return [(None, None, None, 'Pager disabled.')] + +@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True) +def toggle_timing(): + global TIMING_ENABLED + TIMING_ENABLED = not TIMING_ENABLED + message = "Timing is " + message += "on." if TIMING_ENABLED else "off." + return [(None, None, None, message)] + +@export +def is_timing_enabled(): + return TIMING_ENABLED + +@export +def set_expanded_output(val): + global use_expanded_output + use_expanded_output = val + +@export +def is_expanded_output(): + return use_expanded_output + +_logger = logging.getLogger(__name__) + +@export +def editor_command(command): + """ + Is this an external editor command? + :param command: string + """ + # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check + # for both conditions. + return command.strip().endswith('\\e') or command.strip().startswith('\\e') + +@export +def get_filename(sql): + if sql.strip().startswith('\\e'): + command, _, filename = sql.partition(' ') + return filename.strip() or None + + +@export +def get_editor_query(sql): + """Get the query part of an editor command.""" + sql = sql.strip() + + # The reason we can't simply do .strip('\e') is that it strips characters, + # not a substring. So it'll strip "e" in the end of the sql also! + # Ex: "select * from style\e" -> "select * from styl". + pattern = re.compile('(^\\\e|\\\e$)') + while pattern.search(sql): + sql = pattern.sub('', sql) + + return sql + + +@export +def open_external_editor(filename=None, sql=None): + """Open external editor, wait for the user to type in their query, return + the query. + + :return: list with one tuple, query as first element. + + """ + + message = None + filename = filename.strip().split(' ', 1)[0] if filename else None + + sql = sql or '' + MARKER = '# Type your query above this line.\n' + + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), + filename=filename, extension='.sql') + + if filename: + try: + with open(filename) as f: + query = f.read() + except IOError: + message = 'Error reading file: %s.' % filename + + if query is not None: + query = query.split(MARKER, 1)[0].rstrip('\n') + else: + # Don't return None for the caller to deal with. + # Empty string is ok. + query = sql + + return (query, message) + + +@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) +def execute_favorite_query(cur, arg, **_): + """Returns (title, rows, headers, status)""" + if arg == '': + for result in list_favorite_queries(): + yield result + + """Parse out favorite name and optional substitution parameters""" + name, _, arg_str = arg.partition(' ') + args = shlex.split(arg_str) + + query = FavoriteQueries.instance.get(name) + if query is None: + message = "No favorite query: %s" % (name) + yield (None, None, None, message) + else: + query, arg_error = subst_favorite_query_args(query, args) + if arg_error: + yield (None, None, None, arg_error) + else: + for sql in sqlparse.split(query): + sql = sql.rstrip(';') + title = '> %s' % (sql) + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + +def list_favorite_queries(): + """List of all favorite queries. + Returns (title, rows, headers, status)""" + + headers = ["Name", "Query"] + rows = [(r, FavoriteQueries.instance.get(r)) + for r in FavoriteQueries.instance.list()] + + if not rows: + status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage + else: + status = '' + return [('', rows, headers, status)] + + +def subst_favorite_query_args(query, args): + """replace positional parameters ($1...$N) in query.""" + for idx, val in enumerate(args): + subst_var = '$' + str(idx + 1) + if subst_var not in query: + return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + + query = query.replace(subst_var, val) + + match = re.search('\\$\d+', query) + if match: + return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + + return [query, None] + +@special_command('\\fs', '\\fs name query', 'Save a favorite query.') +def save_favorite_query(arg, **_): + """Save a new favorite query. + Returns (title, rows, headers, status)""" + + usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + name, _, query = arg.partition(' ') + + # If either name or query is missing then print the usage and complain. + if (not name) or (not query): + return [(None, None, None, + usage + 'Err: Both name and query are required.')] + + FavoriteQueries.instance.save(name, query) + return [(None, None, None, "Saved.")] + + +@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') +def delete_favorite_query(arg, **_): + """Delete an existing favorite query.""" + usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage + if not arg: + return [(None, None, None, usage)] + + status = FavoriteQueries.instance.delete(arg) + + return [(None, None, None, status)] + + +@special_command('system', 'system [command]', + 'Execute a system shell commmand.') +def execute_system_command(arg, **_): + """Execute a system shell command.""" + usage = "Syntax: system [command].\n" + + if not arg: + return [(None, None, None, usage)] + + try: + command = arg.strip() + if command.startswith('cd'): + ok, error_message = handle_cd_command(arg) + if not ok: + return [(None, None, None, error_message)] + return [(None, None, None, '')] + + args = arg.split(' ') + process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, error = process.communicate() + response = output if not error else error + + # Python 3 returns bytes. This needs to be decoded to a string. + if isinstance(response, bytes): + encoding = locale.getpreferredencoding(False) + response = response.decode(encoding) + + return [(None, None, None, response)] + except OSError as e: + return [(None, None, None, 'OSError: %s' % e.strerror)] + + +def parseargfile(arg): + if arg.startswith('-o '): + mode = "w" + filename = arg[3:] + else: + mode = 'a' + filename = arg + + if not filename: + raise TypeError('You must provide a filename.') + + return {'file': os.path.expanduser(filename), 'mode': mode} + + +@special_command('tee', 'tee [-o] filename', + 'Append all results to an output file (overwrite using -o).') +def set_tee(arg, **_): + global tee_file + + try: + tee_file = open(**parseargfile(arg)) + except (IOError, OSError) as e: + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + + return [(None, None, None, "")] + +@export +def close_tee(): + global tee_file + if tee_file: + tee_file.close() + tee_file = None + + +@special_command('notee', 'notee', 'Stop writing results to an output file.') +def no_tee(arg, **_): + close_tee() + return [(None, None, None, "")] + +@export +def write_tee(output): + global tee_file + if tee_file: + click.echo(output, file=tee_file, nl=False) + click.echo(u'\n', file=tee_file, nl=False) + tee_file.flush() + + +@special_command('\\once', '\\o [-o] filename', + 'Append next result to an output file (overwrite using -o).', + aliases=('\\o', )) +def set_once(arg, **_): + global once_file, written_to_once_file + + once_file = parseargfile(arg) + written_to_once_file = False + + return [(None, None, None, "")] + + +@export +def write_once(output): + global once_file, written_to_once_file + if output and once_file: + try: + f = open(**once_file) + except (IOError, OSError) as e: + once_file = None + raise OSError("Cannot write to file '{}': {}".format( + e.filename, e.strerror)) + with f: + click.echo(output, file=f, nl=False) + click.echo(u"\n", file=f, nl=False) + written_to_once_file = True + + +@export +def unset_once_if_written(): + """Unset the once file, if it has been written to.""" + global once_file + if written_to_once_file: + once_file = None + + +@special_command( + 'watch', + 'watch [seconds] [-c] query', + 'Executes the query every [seconds] seconds (by default 5).' +) +def watch_query(arg, **kwargs): + usage = """Syntax: watch [seconds] [-c] query. + * seconds: The interval at the query will be repeated, in seconds. + By default 5. + * -c: Clears the screen between every iteration. +""" + if not arg: + yield (None, None, None, usage) + return + seconds = 5 + clear_screen = False + statement = None + while statement is None: + arg = arg.strip() + if not arg: + # Oops, we parsed all the arguments without finding a statement + yield (None, None, None, usage) + return + (current_arg, _, arg) = arg.partition(' ') + try: + seconds = float(current_arg) + continue + except ValueError: + pass + if current_arg == '-c': + clear_screen = True + continue + statement = '{0!s} {1!s}'.format(current_arg, arg) + destructive_prompt = confirm_destructive_query(statement) + if destructive_prompt is False: + click.secho("Wise choice!") + return + elif destructive_prompt is True: + click.secho("Your call!") + cur = kwargs['cur'] + sql_list = [ + (sql.rstrip(';'), "> {0!s}".format(sql)) + for sql in sqlparse.split(statement) + ] + old_pager_enabled = is_pager_enabled() + while True: + if clear_screen: + click.clear() + try: + # Somewhere in the code the pager its activated after every yield, + # so we disable it in every iteration + set_pager_enabled(False) + for (sql, title) in sql_list: + cur.execute(sql) + if cur.description: + headers = [x[0] for x in cur.description] + yield (title, cur, headers, None) + else: + yield (title, None, None, None) + sleep(seconds) + except KeyboardInterrupt: + # This prints the Ctrl-C character in its own line, which prevents + # to print a line with the cursor positioned behind the prompt + click.secho("", nl=True) + return + finally: + set_pager_enabled(old_pager_enabled) + + +@export +@special_command('delimiter', None, 'Change SQL delimiter.') +def set_delimiter(arg, **_): + return delimiter_command.set(arg) + + +@export +def get_current_delimiter(): + return delimiter_command.current + + +@export +def split_queries(input): + for query in delimiter_command.queries_iter(input): + yield query diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py new file mode 100644 index 0000000..dddba66 --- /dev/null +++ b/mycli/packages/special/main.py @@ -0,0 +1,118 @@ +import logging +from collections import namedtuple + +from . import export + +log = logging.getLogger(__name__) + +NO_QUERY = 0 +PARSED_QUERY = 1 +RAW_QUERY = 2 + +SpecialCommand = namedtuple('SpecialCommand', + ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden', + 'case_sensitive']) + +COMMANDS = {} + +@export +class CommandNotFound(Exception): + pass + +@export +def parse_special_command(sql): + command, _, arg = sql.partition(' ') + verbose = '+' in command + command = command.strip().replace('+', '') + return (command, verbose, arg.strip()) + +@export +def special_command(command, shortcut, description, arg_type=PARSED_QUERY, + hidden=False, case_sensitive=False, aliases=()): + def wrapper(wrapped): + register_special_command(wrapped, command, shortcut, description, + arg_type, hidden, case_sensitive, aliases) + return wrapped + return wrapper + +@export +def register_special_command(handler, command, shortcut, description, + arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()): + cmd = command.lower() if not case_sensitive else command + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, + arg_type, hidden, case_sensitive) + for alias in aliases: + cmd = alias.lower() if not case_sensitive else alias + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, + arg_type, case_sensitive=case_sensitive, + hidden=True) + +@export +def execute(cur, sql): + """Execute a special command and return the results. If the special command + is not supported a KeyError will be raised. + """ + command, verbose, arg = parse_special_command(sql) + + if (command not in COMMANDS) and (command.lower() not in COMMANDS): + raise CommandNotFound + + try: + special_cmd = COMMANDS[command] + except KeyError: + special_cmd = COMMANDS[command.lower()] + if special_cmd.case_sensitive: + raise CommandNotFound('Command not found: %s' % command) + + # "help <SQL KEYWORD> is a special case. We want built-in help, not + # mycli help here. + if command == 'help' and arg: + return show_keyword_help(cur=cur, arg=arg) + + if special_cmd.arg_type == NO_QUERY: + return special_cmd.handler() + elif special_cmd.arg_type == PARSED_QUERY: + return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) + elif special_cmd.arg_type == RAW_QUERY: + return special_cmd.handler(cur=cur, query=sql) + +@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?')) +def show_help(): # All the parameters are ignored. + headers = ['Command', 'Shortcut', 'Description'] + result = [] + + for _, value in sorted(COMMANDS.items()): + if not value.hidden: + result.append((value.command, value.shortcut, value.description)) + return [(None, result, headers, None)] + +def show_keyword_help(cur, arg): + """ + Call the built-in "show <command>", to display help for an SQL keyword. + :param cur: cursor + :param arg: string + :return: list + """ + keyword = arg.strip('"').strip("'") + query = "help '{0}'".format(keyword) + log.debug(query) + cur.execute(query) + if cur.description and cur.rowcount > 0: + headers = [x[0] for x in cur.description] + return [(None, cur, headers, '')] + else: + return [(None, None, None, 'No help found for {0}.'.format(keyword))] + + +@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', )) +@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY) +def quit(*_args): + raise EOFError + + +@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', + arg_type=NO_QUERY, case_sensitive=True) +@special_command('\\G', '\\G', 'Display current query results vertically.', + arg_type=NO_QUERY, case_sensitive=True) +def stub(): + raise NotImplementedError diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py new file mode 100644 index 0000000..ef96093 --- /dev/null +++ b/mycli/packages/special/utils.py @@ -0,0 +1,46 @@ +import os +import subprocess + +def handle_cd_command(arg): + """Handles a `cd` shell command by calling python's os.chdir.""" + CD_CMD = 'cd' + tokens = arg.split(CD_CMD + ' ') + directory = tokens[-1] if len(tokens) > 1 else None + if not directory: + return False, "No folder name was provided." + try: + os.chdir(directory) + subprocess.call(['pwd']) + return True, None + except OSError as e: + return False, e.strerror + +def format_uptime(uptime_in_seconds): + """Format number of seconds into human-readable string. + + :param uptime_in_seconds: The server uptime in seconds. + :returns: A human-readable string representing the uptime. + + >>> uptime = format_uptime('56892') + >>> print(uptime) + 15 hours 48 min 12 sec + """ + + m, s = divmod(int(uptime_in_seconds), 60) + h, m = divmod(m, 60) + d, h = divmod(h, 24) + + uptime_values = [] + + for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')): + if value == 0 and not uptime_values: + # Don't include a value/unit if the unit isn't applicable to + # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. + continue + elif value == 1 and unit.endswith('s'): + # Remove the "s" if the unit is singular. + unit = unit[:-1] + uptime_values.append('{0} {1}'.format(value, unit)) + + uptime = ' '.join(uptime_values) + return uptime diff --git a/mycli/packages/tabular_output/__init__.py b/mycli/packages/tabular_output/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/mycli/packages/tabular_output/__init__.py diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py new file mode 100644 index 0000000..730e633 --- /dev/null +++ b/mycli/packages/tabular_output/sql_format.py @@ -0,0 +1,63 @@ +"""Format adapter for sql.""" + +from cli_helpers.utils import filter_dict_by_key +from mycli.packages.parseutils import extract_tables + +supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', + 'sql-update-2', ) + +preprocessors = () + + +def escape_for_sql_statement(value): + if isinstance(value, bytes): + return f"X'{value.hex()}'" + else: + return formatter.mycli.sqlexecute.conn.escape(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = "`DUAL`" + if table_format == 'sql-insert': + h = "`, `".join(headers) + yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) + for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith('sql-update'): + s = table_format.split('-') + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield "UPDATE {} SET".format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v)) + if prefix == " ": + prefix = ", " + f = "`{}` = {}" + where = (f.format(headers[i], escape_for_sql_statement( + d[i])) for i in range(keys)) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {'table_format': sql_format}) diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py new file mode 100644 index 0000000..20611be --- /dev/null +++ b/mycli/sqlcompleter.py @@ -0,0 +1,435 @@ +import logging +from re import compile, escape +from collections import Counter + +from prompt_toolkit.completion import Completer, Completion + +from .packages.completion_engine import suggest_type +from .packages.parseutils import last_word +from .packages.filepaths import parse_path, complete_path, suggest_path +from .packages.special.favoritequeries import FavoriteQueries + +_logger = logging.getLogger(__name__) + + +class SQLCompleter(Completer): + keywords = ['ACCESS', 'ADD', 'ALL', 'ALTER TABLE', 'AND', 'ANY', 'AS', + 'ASC', 'AUTO_INCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', + 'BIGINT', 'BINARY', 'BY', 'CASE', 'CHANGE MASTER TO', 'CHAR', + 'CHARACTER SET', 'CHECK', 'COLLATE', 'COLUMN', 'COMMENT', + 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT', + 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT', + 'DELETE FROM', 'DESC', 'DESCRIBE', 'DROP', + 'ELSE', 'END', 'ENGINE', 'ESCAPE', 'EXISTS', 'FILE', 'FLOAT', + 'FOR', 'FOREIGN KEY', 'FORMAT', 'FROM', 'FULL', 'FUNCTION', + 'GRANT', 'GROUP BY', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', + 'INCREMENT', 'INDEX', 'INSERT INTO', 'INT', 'INTEGER', + 'INTERVAL', 'INTO', 'IS', 'JOIN', 'KEY', 'LEFT', 'LEVEL', + 'LIKE', 'LIMIT', 'LOCK', 'LOGS', 'LONG', 'MASTER', + 'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER', + 'OFFSET', 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER', + 'PASSWORD', 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST', + 'PURGE', 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET', + 'REVOKE', 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT', + 'SAVEPOINT', 'SELECT', 'SESSION', 'SET', 'SHARE', 'SHOW', + 'SLAVE', 'SMALLINT', 'SMALLINT', 'START', 'STOP', 'TABLE', + 'THEN', 'TINYINT', 'TO', 'TRANSACTION', 'TRIGGER', 'TRUNCATE', + 'UNION', 'UNIQUE', 'UNSIGNED', 'UPDATE', 'USE', 'USER', + 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WHERE', 'WITH'] + + functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT', + 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID', + 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', 'UNIX_TIMESTAMP'] + + show_items = [] + + change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER', + 'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY', + 'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE', + 'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS', + 'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH', + 'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER', + 'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS'] + + users = [] + + def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'): + super(self.__class__, self).__init__() + self.smart_completion = smart_completion + self.reserved_words = set() + for x in self.keywords: + self.reserved_words.update(x.split()) + self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$") + + self.special_commands = [] + self.table_formats = supported_formats + if keyword_casing not in ('upper', 'lower', 'auto'): + keyword_casing = 'auto' + self.keyword_casing = keyword_casing + self.reset_completions() + + def escape_name(self, name): + if name and ((not self.name_pattern.match(name)) + or (name.upper() in self.reserved_words) + or (name.upper() in self.functions)): + name = '`%s`' % name + + return name + + def unescape_name(self, name): + """Unquote a string.""" + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def extend_special_commands(self, special_commands): + # Special commands are not part of all_completions since they can only + # be at the beginning of a line. + self.special_commands.extend(special_commands) + + def extend_database_names(self, databases): + self.databases.extend(databases) + + def extend_keywords(self, additional_keywords): + self.keywords.extend(additional_keywords) + self.all_completions.update(additional_keywords) + + def extend_show_items(self, show_items): + for show_item in show_items: + self.show_items.extend(show_item) + self.all_completions.update(show_item) + + def extend_change_items(self, change_items): + for change_item in change_items: + self.change_items.extend(change_item) + self.all_completions.update(change_item) + + def extend_users(self, users): + for user in users: + self.users.extend(user) + self.all_completions.update(user) + + def extend_schemata(self, schema): + if schema is None: + return + metadata = self.dbmetadata['tables'] + metadata[schema] = {} + + # dbmetadata.values() are the 'tables' and 'functions' dicts + for metadata in self.dbmetadata.values(): + metadata[schema] = {} + self.all_completions.update(schema) + + def extend_relations(self, data, kind): + """Extend metadata for tables or views + + :param data: list of (rel_name, ) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'data' is a generator object. It can throw an exception while being + # consumed. This could happen if the user has launched the app without + # specifying a database name. This exception must be handled to prevent + # crashing. + try: + data = [self.escaped_names(d) for d in data] + except Exception: + data = [] + + # dbmetadata['tables'][$schema_name][$table_name] should be a list of + # column names. Default to an asterisk + metadata = self.dbmetadata[kind] + for relname in data: + try: + metadata[self.dbname][relname[0]] = ['*'] + except KeyError: + _logger.error('%r %r listed in unrecognized schema %r', + kind, relname[0], self.dbname) + self.all_completions.add(relname[0]) + + def extend_columns(self, column_data, kind): + """Extend column metadata + + :param column_data: list of (rel_name, column_name) tuples + :param kind: either 'tables' or 'views' + :return: + """ + # 'column_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + column_data = [self.escaped_names(d) for d in column_data] + except Exception: + column_data = [] + + metadata = self.dbmetadata[kind] + for relname, column in column_data: + metadata[self.dbname][relname].append(column) + self.all_completions.add(column) + + def extend_functions(self, func_data): + # 'func_data' is a generator object. It can throw an exception while + # being consumed. This could happen if the user has launched the app + # without specifying a database name. This exception must be handled to + # prevent crashing. + try: + func_data = [self.escaped_names(d) for d in func_data] + except Exception: + func_data = [] + + # dbmetadata['functions'][$schema_name][$function_name] should return + # function metadata. + metadata = self.dbmetadata['functions'] + + for func in func_data: + metadata[self.dbname][func[0]] = None + self.all_completions.add(func[0]) + + def set_dbname(self, dbname): + self.dbname = dbname + + def reset_completions(self): + self.databases = [] + self.users = [] + self.show_items = [] + self.dbname = '' + self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}} + self.all_completions = set(self.keywords + self.functions) + + @staticmethod + def find_matches(text, collection, start_only=False, fuzzy=True, casing=None): + """Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. + + If `start_only` is True, the text will match an available + completion only at the beginning. Otherwise, a completion is + considered a match if the text appears anywhere within it. + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + """ + last = last_word(text, include='most_punctuations') + text = last.lower() + + completions = [] + + if fuzzy: + regex = '.*?'.join(map(escape, text)) + pat = compile('(%s)' % regex) + for item in sorted(collection): + r = pat.search(item.lower()) + if r: + completions.append((len(r.group()), r.start(), item)) + else: + match_end_limit = len(text) if start_only else None + for item in sorted(collection): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + completions.append((len(text), match_point, item)) + + if casing == 'auto': + casing = 'lower' if last and last[-1].islower() else 'upper' + + def apply_case(kw): + if casing == 'upper': + return kw.upper() + return kw.lower() + + return (Completion(z if casing is None else apply_case(z), -len(text)) + for x, y, z in sorted(completions)) + + def get_completions(self, document, complete_event, smart_completion=None): + word_before_cursor = document.get_word_before_cursor(WORD=True) + if smart_completion is None: + smart_completion = self.smart_completion + + # If smart_completion is off then match any word that starts with + # 'word_before_cursor'. + if not smart_completion: + return self.find_matches(word_before_cursor, self.all_completions, + start_only=True, fuzzy=False) + + completions = [] + suggestions = suggest_type(document.text, document.text_before_cursor) + + for suggestion in suggestions: + + _logger.debug('Suggestion type: %r', suggestion['type']) + + if suggestion['type'] == 'column': + tables = suggestion['tables'] + _logger.debug("Completion column scope: %r", tables) + scoped_cols = self.populate_scoped_cols(tables) + if suggestion.get('drop_unique'): + # drop_unique is used for 'tb11 JOIN tbl2 USING (...' + # which should suggest only columns that appear in more than + # one table + scoped_cols = [ + col for (col, count) in Counter(scoped_cols).items() + if count > 1 and col != '*' + ] + + cols = self.find_matches(word_before_cursor, scoped_cols) + completions.extend(cols) + + elif suggestion['type'] == 'function': + # suggest user-defined functions using substring matching + funcs = self.populate_schema_objects(suggestion['schema'], + 'functions') + user_funcs = self.find_matches(word_before_cursor, funcs) + completions.extend(user_funcs) + + # suggest hardcoded functions using startswith matching only if + # there is no schema qualifier. If a schema qualifier is + # present it probably denotes a table. + # eg: SELECT * FROM users u WHERE u. + if not suggestion['schema']: + predefined_funcs = self.find_matches(word_before_cursor, + self.functions, + start_only=True, + fuzzy=False, + casing=self.keyword_casing) + completions.extend(predefined_funcs) + + elif suggestion['type'] == 'table': + tables = self.populate_schema_objects(suggestion['schema'], + 'tables') + tables = self.find_matches(word_before_cursor, tables) + completions.extend(tables) + + elif suggestion['type'] == 'view': + views = self.populate_schema_objects(suggestion['schema'], + 'views') + views = self.find_matches(word_before_cursor, views) + completions.extend(views) + + elif suggestion['type'] == 'alias': + aliases = suggestion['aliases'] + aliases = self.find_matches(word_before_cursor, aliases) + completions.extend(aliases) + + elif suggestion['type'] == 'database': + dbs = self.find_matches(word_before_cursor, self.databases) + completions.extend(dbs) + + elif suggestion['type'] == 'keyword': + keywords = self.find_matches(word_before_cursor, self.keywords, + start_only=True, + fuzzy=False, + casing=self.keyword_casing) + completions.extend(keywords) + + elif suggestion['type'] == 'show': + show_items = self.find_matches(word_before_cursor, + self.show_items, + start_only=False, + fuzzy=True, + casing=self.keyword_casing) + completions.extend(show_items) + + elif suggestion['type'] == 'change': + change_items = self.find_matches(word_before_cursor, + self.change_items, + start_only=False, + fuzzy=True) + completions.extend(change_items) + elif suggestion['type'] == 'user': + users = self.find_matches(word_before_cursor, self.users, + start_only=False, + fuzzy=True) + completions.extend(users) + + elif suggestion['type'] == 'special': + special = self.find_matches(word_before_cursor, + self.special_commands, + start_only=True, + fuzzy=False) + completions.extend(special) + elif suggestion['type'] == 'favoritequery': + queries = self.find_matches(word_before_cursor, + FavoriteQueries.instance.list(), + start_only=False, fuzzy=True) + completions.extend(queries) + elif suggestion['type'] == 'table_format': + formats = self.find_matches(word_before_cursor, + self.table_formats, + start_only=True, fuzzy=False) + completions.extend(formats) + elif suggestion['type'] == 'file_name': + file_names = self.find_files(word_before_cursor) + completions.extend(file_names) + + return completions + + def find_files(self, word): + """Yield matching directory or file names. + + :param word: + :return: iterable + + """ + base_path, last_path, position = parse_path(word) + paths = suggest_path(word) + for name in sorted(paths): + suggestion = complete_path(name, last_path) + if suggestion: + yield Completion(suggestion, position) + + def populate_scoped_cols(self, scoped_tbls): + """Find all columns in a set of scoped_tables + :param scoped_tbls: list of (schema, table, alias) tuples + :return: list of column names + """ + columns = [] + meta = self.dbmetadata + + for tbl in scoped_tbls: + # A fully qualified schema.relname reference or default_schema + # DO NOT escape schema names. + schema = tbl[0] or self.dbname + relname = tbl[1] + escaped_relname = self.escape_name(tbl[1]) + + # We don't know if schema.relname is a table or view. Since + # tables and views cannot share the same name, we can check one + # at a time + try: + columns.extend(meta['tables'][schema][relname]) + + # Table exists, so don't bother checking for a view + continue + except KeyError: + try: + columns.extend(meta['tables'][schema][escaped_relname]) + # Table exists, so don't bother checking for a view + continue + except KeyError: + pass + + try: + columns.extend(meta['views'][schema][relname]) + except KeyError: + pass + + return columns + + def populate_schema_objects(self, schema, obj_type): + """Returns list of tables or functions for a (optional) schema""" + metadata = self.dbmetadata[obj_type] + schema = schema or self.dbname + + try: + objects = metadata[schema].keys() + except KeyError: + # schema doesn't exist + objects = [] + + return objects diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py new file mode 100644 index 0000000..c68af0f --- /dev/null +++ b/mycli/sqlexecute.py @@ -0,0 +1,313 @@ +import logging +import pymysql +import sqlparse +from .packages import special +from pymysql.constants import FIELD_TYPE +from pymysql.converters import (convert_datetime, + convert_timedelta, convert_date, conversions, + decoders) +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko + +_logger = logging.getLogger(__name__) + +FIELD_TYPES = decoders.copy() +FIELD_TYPES.update({ + FIELD_TYPE.NULL: type(None) +}) + +class SQLExecute(object): + + databases_query = '''SHOW DATABASES''' + + tables_query = '''SHOW TABLES''' + + version_query = '''SELECT @@VERSION''' + + version_comment_query = '''SELECT @@VERSION_COMMENT''' + version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' + + show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' + + users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' + + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + + table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns + where table_schema = '%s' + order by table_name,ordinal_position''' + + def __init__(self, database, user, password, host, port, socket, charset, + local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, + ssh_key_filename): + self.dbname = database + self.user = user + self.password = password + self.host = host + self.port = port + self.socket = socket + self.charset = charset + self.local_infile = local_infile + self.ssl = ssl + self._server_type = None + self.connection_id = None + self.ssh_user = ssh_user + self.ssh_host = ssh_host + self.ssh_port = ssh_port + self.ssh_password = ssh_password + self.ssh_key_filename = ssh_key_filename + self.connect() + + def connect(self, database=None, user=None, password=None, host=None, + port=None, socket=None, charset=None, local_infile=None, + ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, + ssh_password=None, ssh_key_filename=None): + db = (database or self.dbname) + user = (user or self.user) + password = (password or self.password) + host = (host or self.host) + port = (port or self.port) + socket = (socket or self.socket) + charset = (charset or self.charset) + local_infile = (local_infile or self.local_infile) + ssl = (ssl or self.ssl) + ssh_user = (ssh_user or self.ssh_user) + ssh_host = (ssh_host or self.ssh_host) + ssh_port = (ssh_port or self.ssh_port) + ssh_password = (ssh_password or self.ssh_password) + ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) + _logger.debug( + 'Connection DB Params: \n' + '\tdatabase: %r' + '\tuser: %r' + '\thost: %r' + '\tport: %r' + '\tsocket: %r' + '\tcharset: %r' + '\tlocal_infile: %r' + '\tssl: %r' + '\tssh_user: %r' + '\tssh_host: %r' + '\tssh_port: %r' + '\tssh_password: %r' + '\tssh_key_filename: %r', + db, user, host, port, socket, charset, local_infile, ssl, + ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename + ) + conv = conversions.copy() + conv.update({ + FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), + FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), + }) + + defer_connect = False + + if ssh_host: + defer_connect = True + + conn = pymysql.connect( + database=db, user=user, password=password, host=host, port=port, + unix_socket=socket, use_unicode=True, charset=charset, + autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE, + local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + defer_connect=defer_connect + ) + + if ssh_host: + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + client.connect( + ssh_host, ssh_port, ssh_user, ssh_password, + key_filename=ssh_key_filename + ) + chan = client.get_transport().open_channel( + 'direct-tcpip', + (host, port), + ('0.0.0.0', 0), + ) + conn.connect(chan) + + if hasattr(self, 'conn'): + self.conn.close() + self.conn = conn + # Update them after the connection is made to ensure that it was a + # successful connection. + self.dbname = db + self.user = user + self.password = password + self.host = host + self.port = port + self.socket = socket + self.charset = charset + self.ssl = ssl + # retrieve connection id + self.reset_connection_id() + + def run(self, statement): + """Execute the sql in the database and return the results. The results + are a list of tuples. Each tuple has 4 values + (title, rows, headers, status). + """ + + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None) + + # Split the sql into separate queries and run each one. + # Unless it's saving a favorite query, in which case we + # want to save them all together. + if statement.startswith('\\fs'): + components = [statement] + else: + components = special.split_queries(statement) + + for sql in components: + # \G is treated specially since we have to set the expanded output. + if sql.endswith('\\G'): + special.set_expanded_output(True) + sql = sql[:-2].strip() + + cur = self.conn.cursor() + try: # Special command + _logger.debug('Trying a dbspecial command. sql: %r', sql) + for result in special.execute(cur, sql): + yield result + except special.CommandNotFound: # Regular SQL + _logger.debug('Regular sql statement. sql: %r', sql) + cur.execute(sql) + while True: + yield self.get_result(cur) + + # PyMySQL returns an extra, empty result set with stored + # procedures. We skip it (rowcount is zero and no + # description). + if not cur.nextset() or (not cur.rowcount and cur.description is None): + break + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" + title = headers = None + + # cursor.description is not None for queries that return result sets, + # e.g. SELECT or SHOW. + if cursor.description is not None: + headers = [x[0] for x in cursor.description] + status = '{0} row{1} in set' + else: + _logger.debug('No rows in result.') + status = 'Query OK, {0} row{1} affected' + status = status.format(cursor.rowcount, + '' if cursor.rowcount == 1 else 's') + + return (title, cursor if cursor.description else None, headers, status) + + def tables(self): + """Yields table names""" + + with self.conn.cursor() as cur: + _logger.debug('Tables Query. sql: %r', self.tables_query) + cur.execute(self.tables_query) + for row in cur: + yield row + + def table_columns(self): + """Yields (table name, column name) pairs""" + with self.conn.cursor() as cur: + _logger.debug('Columns Query. sql: %r', self.table_columns_query) + cur.execute(self.table_columns_query % self.dbname) + for row in cur: + yield row + + def databases(self): + with self.conn.cursor() as cur: + _logger.debug('Databases Query. sql: %r', self.databases_query) + cur.execute(self.databases_query) + return [x[0] for x in cur.fetchall()] + + def functions(self): + """Yields tuples of (schema_name, function_name)""" + + with self.conn.cursor() as cur: + _logger.debug('Functions Query. sql: %r', self.functions_query) + cur.execute(self.functions_query % self.dbname) + for row in cur: + yield row + + def show_candidates(self): + with self.conn.cursor() as cur: + _logger.debug('Show Query. sql: %r', self.show_candidates_query) + try: + cur.execute(self.show_candidates_query) + except pymysql.DatabaseError as e: + _logger.error('No show completions due to %r', e) + yield '' + else: + for row in cur: + yield (row[0].split(None, 1)[-1], ) + + def users(self): + with self.conn.cursor() as cur: + _logger.debug('Users Query. sql: %r', self.users_query) + try: + cur.execute(self.users_query) + except pymysql.DatabaseError as e: + _logger.error('No user completions due to %r', e) + yield '' + else: + for row in cur: + yield row + + def server_type(self): + if self._server_type: + return self._server_type + with self.conn.cursor() as cur: + _logger.debug('Version Query. sql: %r', self.version_query) + cur.execute(self.version_query) + version = cur.fetchone()[0] + if version[0] == '4': + _logger.debug('Version Comment. sql: %r', + self.version_comment_query_mysql4) + cur.execute(self.version_comment_query_mysql4) + version_comment = cur.fetchone()[1].lower() + if isinstance(version_comment, bytes): + # with python3 this query returns bytes + version_comment = version_comment.decode('utf-8') + else: + _logger.debug('Version Comment. sql: %r', + self.version_comment_query) + cur.execute(self.version_comment_query) + version_comment = cur.fetchone()[0].lower() + + if 'mariadb' in version_comment: + product_type = 'mariadb' + elif 'percona' in version_comment: + product_type = 'percona' + else: + product_type = 'mysql' + + self._server_type = (product_type, version) + return self._server_type + + def get_connection_id(self): + if not self.connection_id: + self.reset_connection_id() + return self.connection_id + + def reset_connection_id(self): + # Remember current connection id + _logger.debug('Get current connection id') + res = self.run('select connection_id()') + for title, cur, headers, status in res: + self.connection_id = cur.fetchone()[0] + _logger.debug('Current connection id: %s', self.connection_id) + + def change_db(self, db): + self.conn.select_db(db) + self.dbname = db diff --git a/release.py b/release.py new file mode 100755 index 0000000..30c41b3 --- /dev/null +++ b/release.py @@ -0,0 +1,119 @@ +"""A script to publish a release of mycli to PyPI.""" + +import io +from optparse import OptionParser +import re +import subprocess +import sys + +import click + +DEBUG = False +CONFIRM_STEPS = False +DRY_RUN = False + + +def skip_step(): + """ + Asks for user's response whether to run a step. Default is yes. + :return: boolean + """ + global CONFIRM_STEPS + + if CONFIRM_STEPS: + return not click.confirm('--- Run this step?', default=True) + return False + + +def run_step(*args): + """ + Prints out the command and asks if it should be run. + If yes (default), runs it. + :param args: list of strings (command and args) + """ + global DRY_RUN + + cmd = args + print(' '.join(cmd)) + if skip_step(): + print('--- Skipping...') + elif DRY_RUN: + print('--- Pretending to run...') + else: + subprocess.check_output(cmd) + + +def version(version_file): + _version_re = re.compile( + r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)') + + with open(version_file) as f: + ver = _version_re.search(f.read()).group('version') + + return ver + + +def commit_for_release(version_file, ver): + run_step('git', 'reset') + run_step('git', 'add', version_file) + run_step('git', 'commit', '--message', + 'Releasing version {}'.format(ver)) + + +def create_git_tag(tag_name): + run_step('git', 'tag', tag_name) + + +def create_distribution_files(): + run_step('python', 'setup.py', 'sdist', 'bdist_wheel') + + +def upload_distribution_files(): + run_step('twine', 'upload', 'dist/*') + + +def push_to_github(): + run_step('git', 'push', 'origin', 'master') + + +def push_tags_to_github(): + run_step('git', 'push', '--tags', 'origin') + + +def checklist(questions): + for question in questions: + if not click.confirm('--- {}'.format(question), default=False): + sys.exit(1) + + +if __name__ == '__main__': + if DEBUG: + subprocess.check_output = lambda x: x + + ver = version('mycli/__init__.py') + print('Releasing Version:', ver) + + parser = OptionParser() + parser.add_option( + "-c", "--confirm-steps", action="store_true", dest="confirm_steps", + default=False, help=("Confirm every step. If the step is not " + "confirmed, it will be skipped.") + ) + parser.add_option( + "-d", "--dry-run", action="store_true", dest="dry_run", + default=False, help="Print out, but not actually run any steps." + ) + + popts, pargs = parser.parse_args() + CONFIRM_STEPS = popts.confirm_steps + DRY_RUN = popts.dry_run + + if not click.confirm('Are you sure?', default=False): + sys.exit(1) + + commit_for_release('mycli/__init__.py', ver) + create_git_tag('v{}'.format(ver)) + create_distribution_files() + push_to_github() + push_tags_to_github() + upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..8e206a5 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,14 @@ +mock +pytest!=3.3.0 +pytest-cov==2.4.0 +tox +twine==1.12.1 +behave +pexpect +coverage==5.0.4 +codecov==2.0.9 +autopep8==1.3.3 +colorama==0.4.1 +git+https://github.com/hayd/pep8radius.git # --error-status option not released +click>=7.0 +paramiko==2.7.1 diff --git a/screenshots/main.gif b/screenshots/main.gif Binary files differnew file mode 100644 index 0000000..8973195 --- /dev/null +++ b/screenshots/main.gif diff --git a/screenshots/tables.png b/screenshots/tables.png Binary files differnew file mode 100644 index 0000000..1d6afcf --- /dev/null +++ b/screenshots/tables.png diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..e533c7b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,18 @@ +[bdist_wheel] +universal = 1 + +[tool:pytest] +addopts = --capture=sys + --showlocals + --doctest-modules + --doctest-ignore-import-errors + --ignore=setup.py + --ignore=mycli/magic.py + --ignore=mycli/packages/parseutils.py + --ignore=test/features + +[pep8] +rev = master +docformatter = True +diff = True +error-status = True diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..156cd1a --- /dev/null +++ b/setup.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python + +import ast +import re +import subprocess +import sys + +from setuptools import Command, find_packages, setup +from setuptools.command.test import test as TestCommand + +_version_re = re.compile(r'__version__\s+=\s+(.*)') + +with open('mycli/__init__.py') as f: + version = ast.literal_eval(_version_re.search( + f.read()).group(1)) + +description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.' + +install_requirements = [ + 'click >= 7.0', + 'Pygments >= 1.6', + 'prompt_toolkit>=3.0.0,<4.0.0', + 'PyMySQL >= 0.9.2', + 'sqlparse>=0.3.0,<0.4.0', + 'configobj >= 5.0.5', + 'cryptography >= 1.0.0', + 'cli_helpers[styles] > 1.1.0', +] + + +class lint(Command): + description = 'check code against PEP 8 (and fix violations)' + + user_options = [ + ('branch=', 'b', 'branch/revision to compare against (e.g. master)'), + ('fix', 'f', 'fix the violations in place'), + ('error-status', 'e', 'return an error code on failed PEP check'), + ] + + def initialize_options(self): + """Set the default options.""" + self.branch = 'master' + self.fix = False + self.error_status = True + + def finalize_options(self): + pass + + def run(self): + cmd = 'pep8radius {}'.format(self.branch) + if self.fix: + cmd += ' --in-place' + if self.error_status: + cmd += ' --error-status' + sys.exit(subprocess.call(cmd, shell=True)) + + +class test(TestCommand): + + user_options = [ + ('pytest-args=', 'a', 'Arguments to pass to pytest'), + ('behave-args=', 'b', 'Arguments to pass to pytest') + ] + + def initialize_options(self): + TestCommand.initialize_options(self) + self.pytest_args = '' + self.behave_args = '' + + def run_tests(self): + unit_test_errno = subprocess.call( + 'pytest test/ ' + self.pytest_args, + shell=True + ) + cli_errno = subprocess.call( + 'behave test/features ' + self.behave_args, + shell=True + ) + sys.exit(unit_test_errno or cli_errno) + + +setup( + name='mycli', + author='Mycli Core Team', + author_email='mycli-dev@googlegroups.com', + version=version, + url='http://mycli.net', + packages=find_packages(), + package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']}, + description=description, + long_description=description, + install_requires=install_requirements, + entry_points={ + 'console_scripts': ['mycli = mycli.main:cli'], + }, + cmdclass={'lint': lint, 'test': test}, + python_requires=">=3.6", + classifiers=[ + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: Unix', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: SQL', + 'Topic :: Database', + 'Topic :: Database :: Front-Ends', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + extras_require={ + 'ssh': ['paramiko'], + }, +) diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/__init__.py diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..cf6d721 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,29 @@ +import pytest +from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, + db_connection, SSH_USER, SSH_HOST, SSH_PORT) +import mycli.sqlexecute + + +@pytest.yield_fixture(scope="function") +def connection(): + create_db('_test_db') + connection = db_connection('_test_db') + yield connection + + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return mycli.sqlexecute.SQLExecute( + database='_test_db', user=USER, + host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, + local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, + ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + ) diff --git a/test/features/__init__.py b/test/features/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/features/__init__.py diff --git a/test/features/auto_vertical.feature b/test/features/auto_vertical.feature new file mode 100644 index 0000000..aa95718 --- /dev/null +++ b/test/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature new file mode 100644 index 0000000..a12e899 --- /dev/null +++ b/test/features/basic_commands.feature @@ -0,0 +1,19 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/test/features/crud_database.feature b/test/features/crud_database.feature new file mode 100644 index 0000000..0c298b6 --- /dev/null +++ b/test/features/crud_database.feature @@ -0,0 +1,26 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we create database + then we see database created + when we drop database + then we confirm the destructive warning + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected + + Scenario: create and drop default database + When we create database + then we see database created + when we connect to tmp database + then we see database connected + when we drop database + then we confirm the destructive warning + then we see database dropped and no default database diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature new file mode 100644 index 0000000..3384efd --- /dev/null +++ b/test/features/crud_table.feature @@ -0,0 +1,49 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we update table + then we see record updated + when we select from table + then we see data selected + when we delete from table + then we confirm the destructive warning + then we see record deleted + when we drop table + then we confirm the destructive warning + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: select null values + When we connect to test database + then we see database connected + when we select null + then we see null selected + + Scenario: confirm destructive query + When we query "create table foo(x integer);" + and we query "delete from foo;" + and we answer the destructive warning with "y" + then we see text "Your call!" + + Scenario: decline destructive query + When we query "delete from foo;" + and we answer the destructive warning with "n" + then we see text "Wise choice!" + + Scenario: no destructive warning if disabled in config + When we run dbcli with --no-warn + and we query "create table blabla(x integer);" + and we query "delete from blabla;" + Then we see text "Query OK" + + Scenario: confirm destructive query with invalid response + When we query "delete from foo;" + then we answer the destructive warning with invalid "1" and see text "is not a valid boolean" diff --git a/test/features/db_utils.py b/test/features/db_utils.py new file mode 100644 index 0000000..c29dedb --- /dev/null +++ b/test/features/db_utils.py @@ -0,0 +1,87 @@ +import pymysql + + +def create_db(hostname='localhost', username=None, password=None, + dbname=None): + """Create test database. + + :param hostname: string + :param username: string + :param password: string + :param dbname: string + :return: + + """ + cn = pymysql.connect( + host=hostname, + user=username, + password=password, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + with cn.cursor() as cr: + cr.execute('drop database if exists ' + dbname) + cr.execute('create database ' + dbname) + + cn.close() + + cn = create_cn(hostname, password, username, dbname) + return cn + + +def create_cn(hostname, password, username, dbname): + """Open connection to database. + + :param hostname: + :param password: + :param username: + :param dbname: string + :return: psycopg2.connection + + """ + cn = pymysql.connect( + host=hostname, + user=username, + password=password, + db=dbname, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + return cn + + +def drop_db(hostname='localhost', username=None, password=None, + dbname=None): + """Drop database. + + :param hostname: string + :param username: string + :param password: string + :param dbname: string + + """ + cn = pymysql.connect( + host=hostname, + user=username, + password=password, + db=dbname, + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + with cn.cursor() as cr: + cr.execute('drop database if exists ' + dbname) + + close_cn(cn) + + +def close_cn(cn=None): + """Close connection. + + :param connection: pymysql.connection + + """ + if cn: + cn.close() diff --git a/test/features/environment.py b/test/features/environment.py new file mode 100644 index 0000000..1a49dbe --- /dev/null +++ b/test/features/environment.py @@ -0,0 +1,132 @@ +import os +import sys +from tempfile import mkstemp + +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect + +from steps.wrappers import run_cli, wait_prompt + +test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') + + +def before_all(context): + """Set env parameters.""" + os.environ['LINES'] = "100" + os.environ['COLUMNS'] = "100" + os.environ['EDITOR'] = 'ex' + os.environ['LC_ALL'] = 'en_US.utf8' + os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' + + test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + login_path_file = os.path.join(test_dir, 'mylogin.cnf') + os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, + '.coveragerc') + + context.exit_sent = False + + vi = '_'.join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get( + 'my_test_db', None) or "mycli_behave_tests" + db_name_full = '{0}_{1}'.format(db_name, vi) + + # Store get params from config/environment variables + context.conf = { + 'host': context.config.userdata.get( + 'my_test_host', + os.getenv('PYTEST_HOST', 'localhost') + ), + 'user': context.config.userdata.get( + 'my_test_user', + os.getenv('PYTEST_USER', 'root') + ), + 'pass': context.config.userdata.get( + 'my_test_pass', + os.getenv('PYTEST_PASSWORD', None) + ), + 'cli_command': context.config.userdata.get( + 'my_cli_command', None) or + sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + 'dbname': db_name, + 'dbname_tmp': db_name_full + '_tmp', + 'vi': vi, + 'pager_boundary': '---boundary---', + } + + _, my_cnf = mkstemp() + with open(my_cnf, 'w') as f: + f.write( + '[client]\n' + 'pager={0} {1} {2}\n'.format( + sys.executable, os.path.join(context.package_root, + 'test/features/wrappager.py'), + context.conf['pager_boundary']) + ) + context.conf['defaults-file'] = my_cnf + context.conf['myclirc'] = os.path.join(context.package_root, 'test', + 'myclirc') + + context.cn = dbutils.create_db(context.conf['host'], context.conf['user'], + context.conf['pass'], + context.conf['dbname']) + + context.fixture_data = fixutils.read_fixture_files() + + +def after_all(context): + """Unset env parameters.""" + dbutils.close_cn(context.cn) + dbutils.drop_db(context.conf['host'], context.conf['user'], + context.conf['pass'], context.conf['dbname']) + + # Restore env vars. + #for k, v in context.pgenv.items(): + # if k in os.environ and v is None: + # del os.environ[k] + # elif v: + # os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def before_scenario(context, _): + with open(test_log_file, 'w') as f: + f.write('') + run_cli(context) + wait_prompt(context) + + +def after_scenario(context, _): + """Cleans up after each test complete.""" + with open(test_log_file) as f: + for line in f: + if 'error' in line.lower(): + raise RuntimeError(f'Error in log file: {line}') + + if hasattr(context, 'cli') and not context.exit_sent: + # Quit nicely. + if not context.atprompt: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + context.cli.expect_exact( + '{0}@{1}:{2}> '.format( + user, host, dbname + ), + timeout=5 + ) + context.cli.sendcontrol('d') + context.cli.expect_exact(pexpect.EOF, timeout=5) + +# TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import ipdb; ipdb.set_trace() diff --git a/test/features/fixture_data/help.txt b/test/features/fixture_data/help.txt new file mode 100644 index 0000000..deb499a --- /dev/null +++ b/test/features/fixture_data/help.txt @@ -0,0 +1,24 @@ ++--------------------------+-----------------------------------------------+ +| Command | Description | +|--------------------------+-----------------------------------------------| +| \# | Refresh auto-completions. | +| \? | Show Help. | +| \c[onnect] database_name | Change to a new database. | +| \d [pattern] | List or describe tables, views and sequences. | +| \dT[S+] [pattern] | List data types | +| \df[+] [pattern] | List functions. | +| \di[+] [pattern] | List indexes. | +| \dn[+] [pattern] | List schemas. | +| \ds[+] [pattern] | List sequences. | +| \dt[+] [pattern] | List tables. | +| \du[+] [pattern] | List roles. | +| \dv[+] [pattern] | List views. | +| \e [file] | Edit the query with external editor. | +| \l | List databases. | +| \n[+] [name] | List or execute named queries. | +| \nd [name [query]] | Delete a named query. | +| \ns name query | Save a named query. | +| \refresh | Refresh auto-completions. | +| \timing | Toggle timing of commands. | +| \x | Toggle expanded output. | ++--------------------------+-----------------------------------------------+ diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt new file mode 100644 index 0000000..657db7d --- /dev/null +++ b/test/features/fixture_data/help_commands.txt @@ -0,0 +1,29 @@ ++-------------+----------------------------+------------------------------------------------------------+
+| Command | Shortcut | Description |
++-------------+----------------------------+------------------------------------------------------------+
+| \G | \G | Display current query results vertically. |
+| \dt | \dt[+] [table] | List or describe tables. |
+| \e | \e | Edit command with editor (uses $EDITOR). |
+| \f | \f [name [args..]] | List or execute favorite queries. |
+| \fd | \fd [name] | Delete a favorite query. |
+| \fs | \fs name query | Save a favorite query. |
+| \l | \l | List databases. |
+| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
+| \timing | \t | Toggle timing of commands. |
+| connect | \r | Reconnect to the database. Optional database argument. |
+| exit | \q | Exit. |
+| help | \? | Show this help. |
+| nopager | \n | Disable pager, print to stdout. |
+| notee | notee | Stop writing results to an output file. |
+| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
+| prompt | \R | Change prompt format. |
+| quit | \q | Quit. |
+| rehash | \# | Refresh auto-completions. |
+| source | \. filename | Execute commands from file. |
+| status | \s | Get status information from the server. |
+| system | system [command] | Execute a system shell commmand. |
+| tableformat | \T | Change the table format used to output results. |
+| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
+| use | \u | Change to a new database. |
+| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
++-------------+----------------------------+------------------------------------------------------------+
diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py new file mode 100644 index 0000000..f85e0f6 --- /dev/null +++ b/test/features/fixture_utils.py @@ -0,0 +1,29 @@ +import os +import io + + +def read_fixture_lines(filename): + """Read lines of text from file. + + :param filename: string name + :return: list of strings + + """ + lines = [] + for line in open(filename): + lines.append(line.strip()) + return lines + + +def read_fixture_files(): + """Read all files inside fixture_data directory.""" + fixture_dict = {} + + current_dir = os.path.dirname(__file__) + fixture_dir = os.path.join(current_dir, 'fixture_data/') + for filename in os.listdir(fixture_dir): + if filename not in ['.', '..']: + fullname = os.path.join(fixture_dir, filename) + fixture_dict[filename] = read_fixture_lines(fullname) + + return fixture_dict diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature new file mode 100644 index 0000000..95366eb --- /dev/null +++ b/test/features/iocommands.feature @@ -0,0 +1,47 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type "select * from abc" in the editor + and we exit the editor + then we see dbcli prompt + and we see "select * from abc" in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we select "select 123456" + and we wait for prompt + and we notee output + and we wait for prompt + then we see 123456 in tee output + + Scenario: set delimiter + When we query "delimiter $" + then delimiter is set to "$" + + Scenario: set delimiter twice + When we query "delimiter $" + and we query "delimiter ]]" + then delimiter is set to "]]" + + Scenario: set delimiter and query on same line + When we query "select 123; delimiter $ select 456 $ delimiter %" + then we see result "123" + and we see result "456" + and delimiter is set to "%" + + Scenario: send output to file + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "system cat /tmp/output1.sql" + then we see result "123" + + Scenario: send output to file two times + When we query "\o /tmp/output1.sql" + and we query "select 123" + and we query "\o /tmp/output2.sql" + and we query "select 456" + and we query "system cat /tmp/output2.sql" + then we see result "456" +
\ No newline at end of file diff --git a/test/features/named_queries.feature b/test/features/named_queries.feature new file mode 100644 index 0000000..5e681ec --- /dev/null +++ b/test/features/named_queries.feature @@ -0,0 +1,24 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we use a named query + then we see the named query executed + when we delete a named query + then we see the named query deleted + + Scenario: save, use and delete named queries with parameters + When we connect to test database + then we see database connected + when we save a named query with parameters + then we see the named query saved + when we use named query with parameters + then we see the named query with parameters executed + when we use named query with too few parameters + then we see the named query with parameters fail with missing parameters + when we use named query with too many parameters + then we see the named query with parameters fail with extra parameters diff --git a/test/features/specials.feature b/test/features/specials.feature new file mode 100644 index 0000000..bb36757 --- /dev/null +++ b/test/features/specials.feature @@ -0,0 +1,7 @@ +Feature: Special commands + + @wip + Scenario: run refresh command + When we refresh completions + and we wait for prompt + then we see completions refresh started diff --git a/test/features/steps/__init__.py b/test/features/steps/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/test/features/steps/__init__.py diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py new file mode 100644 index 0000000..974740d --- /dev/null +++ b/test/features/steps/auto_vertical.py @@ -0,0 +1,45 @@ +from textwrap import dedent + +from behave import then, when + +import wrappers + + +@when('we run dbcli with {arg}') +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=arg.split('=')) + + +@when('we execute a small query') +def step_execute_small_query(context): + context.cli.sendline('select 1') + + +@when('we execute a large query') +def step_execute_large_query(context): + context.cli.sendline( + 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + + +@then('we see small results in horizontal format') +def step_see_small_results(context): + wrappers.expect_pager(context, dedent("""\ + +---+\r + | 1 |\r + +---+\r + | 1 |\r + +---+\r + \r + """), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see large results in vertical format') +def step_see_large_results(context): + rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] + expected = ('***************************[ 1. row ]' + '***************************\r\n' + + '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + + wrappers.expect_pager(context, expected, timeout=10) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py new file mode 100644 index 0000000..425ef67 --- /dev/null +++ b/test/features/steps/basic_commands.py @@ -0,0 +1,100 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when +from textwrap import dedent +import tempfile +import wrappers + + +@when('we run dbcli') +def step_run_cli(context): + wrappers.run_cli(context) + + +@when('we wait for prompt') +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """Send Ctrl + D to hopefully exit.""" + context.cli.sendcontrol('d') + context.exit_sent = True + + +@when('we send "\?" command') +def step_send_help(context): + """Send \? + + to see help. + + """ + context.cli.sendline('\\?') + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we send source command') +def step_send_source_command(context): + with tempfile.NamedTemporaryFile() as f: + f.write(b'\?') + f.flush() + context.cli.sendline('\. {0}'.format(f.name)) + wrappers.expect_exact( + context, context.conf['pager_boundary'] + '\r\n', timeout=5) + + +@when(u'we run query to check application_name') +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" + ) + + +@then(u'we see found') +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf['pager_boundary'] + '\r' + dedent(''' + +-------+\r + | found |\r + +-------+\r + | found |\r + +-------+\r + \r + ''') + context.conf['pager_boundary'], + timeout=5 + ) + + +@then(u'we confirm the destructive warning') +def step_confirm_destructive_command(context): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline('y') + + +@when(u'we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + + +@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): + """Confirm destructive command.""" + wrappers.expect_exact( + context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + context.cli.sendline(confirmation) + wrappers.expect_exact(context, text, timeout=2) + # we must exit the Click loop, or the feature will hang + context.cli.sendline('n') diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py new file mode 100644 index 0000000..a0bfa53 --- /dev/null +++ b/test/features/steps/crud_database.py @@ -0,0 +1,112 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import pexpect + +import wrappers +from behave import when, then + + +@when('we create database') +def step_db_create(context): + """Send create database.""" + context.cli.sendline('create database {0};'.format( + context.conf['dbname_tmp'])) + + context.response = { + 'database_name': context.conf['dbname_tmp'] + } + + +@when('we drop database') +def step_db_drop(context): + """Send drop database.""" + context.cli.sendline('drop database {0};'.format( + context.conf['dbname_tmp'])) + + +@when('we connect to test database') +def step_db_connect_test(context): + """Send connect to database.""" + db_name = context.conf['dbname'] + context.currentdb = db_name + context.cli.sendline('use {0};'.format(db_name)) + + +@when('we connect to tmp database') +def step_db_connect_tmp(context): + """Send connect to database.""" + db_name = context.conf['dbname_tmp'] + context.currentdb = db_name + context.cli.sendline('use {0}'.format(db_name)) + + +@when('we connect to dbserver') +def step_db_connect_dbserver(context): + """Send connect to database.""" + context.currentdb = 'mysql' + context.cli.sendline('use mysql') + + +@then('dbcli exits') +def step_wait_exit(context): + """Make sure the cli exits.""" + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then('we see dbcli prompt') +def step_see_prompt(context): + """Wait to see the prompt.""" + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + wrappers.expect_exact(context, '{0}@{1}:{2}> '.format( + user, host, dbname), timeout=5) + context.atprompt = True + + +@then('we see help output') +def step_see_help(context): + for expected_line in context.fixture_data['help_commands.txt']: + wrappers.expect_exact(context, expected_line + '\r\n', timeout=1) + + +@then('we see database created') +def step_see_db_created(context): + """Wait to see create database output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see database dropped') +def step_see_db_dropped(context): + """Wait to see drop database output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see database dropped and no default database') +def step_see_db_dropped_no_default(context): + """Wait to see drop database output.""" + user = context.conf['user'] + host = context.conf['host'] + database = '(none)' + context.currentdb = None + + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, '{0}@{1}:{2}> '.format( + user, host, database), timeout=5) + + context.atprompt = True + + +@then('we see database connected') +def step_see_db_connected(context): + """Wait to see drop database output.""" + wrappers.expect_exact( + context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, '"', timeout=2) + wrappers.expect_exact(context, ' as user "{0}"'.format( + context.conf['user']), timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py new file mode 100644 index 0000000..f715f0c --- /dev/null +++ b/test/features/steps/crud_table.py @@ -0,0 +1,112 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then +from textwrap import dedent + + +@when('we create table') +def step_create_table(context): + """Send create table.""" + context.cli.sendline('create table a(x text);') + + +@when('we insert into table') +def step_insert_into_table(context): + """Send insert into table.""" + context.cli.sendline('''insert into a(x) values('xxx');''') + + +@when('we update table') +def step_update_table(context): + """Send insert into table.""" + context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + + +@when('we select from table') +def step_select_from_table(context): + """Send select from table.""" + context.cli.sendline('select * from a;') + + +@when('we delete from table') +def step_delete_from_table(context): + """Send deete from table.""" + context.cli.sendline('''delete from a where x = 'yyy';''') + + +@when('we drop table') +def step_drop_table(context): + """Send drop table.""" + context.cli.sendline('drop table a;') + + +@then('we see table created') +def step_see_table_created(context): + """Wait to see create table output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@then('we see record inserted') +def step_see_record_inserted(context): + """Wait to see insert output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see record updated') +def step_see_record_updated(context): + """Wait to see update output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see data selected') +def step_see_data_selected(context): + """Wait to see select output.""" + wrappers.expect_pager( + context, dedent("""\ + +-----+\r + | x |\r + +-----+\r + | yyy |\r + +-----+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then('we see record deleted') +def step_see_data_deleted(context): + """Wait to see delete output.""" + wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + + +@then('we see table dropped') +def step_see_table_dropped(context): + """Wait to see drop output.""" + wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + + +@when('we select null') +def step_select_null(context): + """Send select null.""" + context.cli.sendline('select null;') + + +@then('we see null selected') +def step_see_null_selected(context): + """Wait to see null output.""" + wrappers.expect_pager( + context, dedent("""\ + +--------+\r + | NULL |\r + +--------+\r + | <null> |\r + +--------+\r + \r + """), timeout=2) + wrappers.expect_exact(context, '1 row in set', timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py new file mode 100644 index 0000000..bbabf43 --- /dev/null +++ b/test/features/steps/iocommands.py @@ -0,0 +1,105 @@ +import os +import wrappers + +from behave import when, then +from textwrap import dedent + + +@when('we start external editor providing a file name') +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline('\e {0}'.format( + os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we type "{query}" in the editor') +def step_edit_type_sql(context, query): + context.cli.sendline('i') + context.cli.sendline(query) + context.cli.sendline('.') + wrappers.expect_exact(context, '\r\n:', timeout=2) + + +@when('we exit the editor') +def step_edit_quit(context): + context.cli.sendline('x') + wrappers.expect_exact(context, "written", timeout=2) + + +@then('we see "{query}" in prompt') +def step_edit_done_sql(context, query): + for match in query.split(' '): + wrappers.expect_exact(context, match, timeout=5) + # Cleanup the command line. + context.cli.sendcontrol('c') + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + + +@when(u'we tee output') +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline('tee {0}'.format( + os.path.basename(context.tee_file_name))) + + +@when(u'we select "select {param}"') +def step_query_select_number(context, param): + context.cli.sendline(u'select {}'.format(param)) + wrappers.expect_pager(context, dedent(u"""\ + +{dashes}+\r + | {param} |\r + +{dashes}+\r + | {param} |\r + +{dashes}+\r + \r + """.format(param=param, dashes='-' * (len(param) + 2)) + ), timeout=5) + wrappers.expect_exact(context, '1 row in set', timeout=2) + + +@then(u'we see result "{result}"') +def step_see_result(context, result): + wrappers.expect_exact( + context, + u"| {} |".format(result), + timeout=2 + ) + + +@when(u'we query "{query}"') +def step_query(context, query): + context.cli.sendline(query) + + +@when(u'we notee output') +def step_notee_output(context): + context.cli.sendline('notee') + + +@then(u'we see 123456 in tee output') +def step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert '123456' in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + + +@then(u'delimiter is set to "{delimiter}"') +def delimiter_is_set(context, delimiter): + wrappers.expect_exact( + context, + u'Changed delimiter to {}'.format(delimiter), + timeout=2 + ) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py new file mode 100644 index 0000000..bc1f866 --- /dev/null +++ b/test/features/steps/named_queries.py @@ -0,0 +1,90 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then + + +@when('we save a named query') +def step_save_named_query(context): + """Send \fs command.""" + context.cli.sendline('\\fs foo SELECT 12345') + + +@when('we use a named query') +def step_use_named_query(context): + """Send \f command.""" + context.cli.sendline('\\f foo') + + +@when('we delete a named query') +def step_delete_named_query(context): + """Send \fd command.""" + context.cli.sendline('\\fd foo') + + +@then('we see the named query saved') +def step_see_named_query_saved(context): + """Wait to see query saved.""" + wrappers.expect_exact(context, 'Saved.', timeout=2) + + +@then('we see the named query executed') +def step_see_named_query_executed(context): + """Wait to see select output.""" + wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + + +@then('we see the named query deleted') +def step_see_named_query_deleted(context): + """Wait to see query deleted.""" + wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + + +@when('we save a named query with parameters') +def step_save_named_query_with_parameters(context): + """Send \fs command for query with parameters.""" + context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') + + +@when('we use named query with parameters') +def step_use_named_query_with_parameters(context): + """Send \f command with parameters.""" + context.cli.sendline('\\f foo_args 101 second "third value"') + + +@then('we see the named query with parameters executed') +def step_see_named_query_with_parameters_executed(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'SELECT 101, "second", "third value"', timeout=2) + + +@when('we use named query with too few parameters') +def step_use_named_query_with_too_few_parameters(context): + """Send \f command with missing parameters.""" + context.cli.sendline('\\f foo_args 101') + + +@then('we see the named query with parameters fail with missing parameters') +def step_see_named_query_with_parameters_fail_with_missing_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'missing substitution for $2 in query:', timeout=2) + + +@when('we use named query with too many parameters') +def step_use_named_query_with_too_many_parameters(context): + """Send \f command with extra parameters.""" + context.cli.sendline('\\f foo_args 101 102 103 104') + + +@then('we see the named query with parameters fail with extra parameters') +def step_see_named_query_with_parameters_fail_with_extra_parameters(context): + """Wait to see select output.""" + wrappers.expect_exact( + context, 'query does not have substitution parameter $4:', timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py new file mode 100644 index 0000000..e8b99e3 --- /dev/null +++ b/test/features/steps/specials.py @@ -0,0 +1,27 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +import wrappers +from behave import when, then + + +@when('we refresh completions') +def step_refresh_completions(context): + """Send refresh command.""" + context.cli.sendline('rehash') + + +@then('we see text "{text}"') +def step_see_text(context, text): + """Wait to see given text message.""" + wrappers.expect_exact(context, text, timeout=2) + +@then('we see completions refresh started') +def step_see_refresh_started(context): + """Wait to see refresh output.""" + wrappers.expect_exact( + context, 'Auto-completion refresh started in the background.', timeout=2) diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py new file mode 100644 index 0000000..565ca59 --- /dev/null +++ b/test/features/steps/wrappers.py @@ -0,0 +1,94 @@ +import re +import pexpect +import sys +import textwrap + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.exceptions.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. + actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', + '', context.cli.before) + raise Exception( + textwrap.dedent('''\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + ''').format( + expected, + actual, + context.logfile.getvalue() + ) + ) + + +def expect_pager(context, expected, timeout): + expect_exact(context, "{0}\r\n{1}{0}\r\n".format( + context.conf['pager_boundary'], expected), timeout=timeout) + + +def run_cli(context, run_args=None): + """Run the process using pexpect.""" + run_args = run_args or [] + if context.conf.get('host', None): + run_args.extend(('-h', context.conf['host'])) + if context.conf.get('user', None): + run_args.extend(('-u', context.conf['user'])) + if context.conf.get('pass', None): + run_args.extend(('-p', context.conf['pass'])) + if context.conf.get('dbname', None): + run_args.extend(('-D', context.conf['dbname'])) + if context.conf.get('defaults-file', None): + run_args.extend(('--defaults-file', context.conf['defaults-file'])) + if context.conf.get('myclirc', None): + run_args.extend(('--myclirc', context.conf['myclirc'])) + try: + cli_cmd = context.conf['cli_command'] + except KeyError: + cli_cmd = ( + '{0!s} -c "' + 'import coverage ; ' + 'coverage.process_startup(); ' + 'import mycli.main; ' + 'mycli.main.cli()' + '"' + ).format(sys.executable) + + cmd_parts = [cli_cmd] + run_args + cmd = ' '.join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = context.conf['dbname'] + + +def wait_prompt(context, prompt=None): + """Make sure prompt is displayed.""" + if prompt is None: + user = context.conf['user'] + host = context.conf['host'] + dbname = context.currentdb + prompt = '{0}@{1}:{2}> '.format( + user, host, dbname), + expect_exact(context, prompt, timeout=5) + context.atprompt = True diff --git a/test/features/wrappager.py b/test/features/wrappager.py new file mode 100755 index 0000000..51d4909 --- /dev/null +++ b/test/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/test/myclirc b/test/myclirc new file mode 100644 index 0000000..261bee6 --- /dev/null +++ b/test/myclirc @@ -0,0 +1,12 @@ +# vi: ft=dosini + +# This file is loaded after mycli/myclirc and should override only those +# variables needed for testing. +# To see what every variable does see mycli/myclirc + +[main] + +log_file = ~/.mycli.test.log +log_level = DEBUG +prompt = '\t \u@\h:\d> ' +less_chatty = True diff --git a/test/mylogin.cnf b/test/mylogin.cnf Binary files differnew file mode 100644 index 0000000..1363cc3 --- /dev/null +++ b/test/mylogin.cnf diff --git a/test/test.txt b/test/test.txt new file mode 100644 index 0000000..8d8b211 --- /dev/null +++ b/test/test.txt @@ -0,0 +1 @@ +mycli rocks! diff --git a/test/test_clistyle.py b/test/test_clistyle.py new file mode 100644 index 0000000..f82cdf0 --- /dev/null +++ b/test/test_clistyle.py @@ -0,0 +1,27 @@ +"""Test the mycli.clistyle module.""" +import pytest + +from pygments.style import Style +from pygments.token import Token + +from mycli.clistyle import style_factory + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory(): + """Test that a Pygments Style class is created.""" + header = 'bold underline #ansired' + cli_style = {'Token.Output.Header': header} + style = style_factory('default', cli_style) + + assert isinstance(style(), Style) + assert Token.Output.Header in style.styles + assert header == style.styles[Token.Output.Header] + + +@pytest.mark.skip(reason="incompatible with new prompt toolkit") +def test_style_factory_unknown_name(): + """Test that an unrecognized name will not throw an error.""" + style = style_factory('foobar', {}) + + assert isinstance(style(), Style) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py new file mode 100644 index 0000000..9e7c608 --- /dev/null +++ b/test/test_completion_engine.py @@ -0,0 +1,537 @@ +from mycli.packages.completion_engine import suggest_type +import pytest + + +def sorted_dicts(dicts): + """input is a list of dicts.""" + return sorted(tuple(x.items()) for x in dicts) + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [('sch', 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM tabl WHERE ', + 'SELECT * FROM tabl WHERE (', + 'SELECT * FROM tabl WHERE foo = ', + 'SELECT * FROM tabl WHERE bar OR ', + 'SELECT * FROM tabl WHERE foo = 1 AND ', + 'SELECT * FROM tabl WHERE (bar > 10 AND ', + 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', + 'SELECT * FROM tabl WHERE 10 < ', + 'SELECT * FROM tabl WHERE foo BETWEEN ', + 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', +]) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM tabl WHERE foo IN (', + 'SELECT * FROM tabl WHERE foo IN (bar, ', +]) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = 'SELECT * FROM tabl WHERE foo = ANY(' + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}]) + + +def test_lparen_suggests_cols(): + suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_operand_inside_function_suggests_cols1(): + suggestion = suggest_type( + 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_operand_inside_function_suggests_cols2(): + suggestion = suggest_type( + 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') + assert suggestion == [ + {'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type('SELECT ', 'SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': []}, + {'type': 'column', 'tables': []}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM ', + 'INSERT INTO ', + 'COPY ', + 'UPDATE ', + 'DESCRIBE ', + 'DESC ', + 'EXPLAIN ', + 'SELECT * FROM foo JOIN ', +]) +def test_expression_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM sch.', + 'INSERT INTO sch.', + 'COPY sch.', + 'UPDATE sch.', + 'DESCRIBE sch.', + 'DESC sch.', + 'EXPLAIN sch.', + 'SELECT * FROM foo JOIN sch.', +]) +def test_expression_suggests_qualified_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}, + {'type': 'view', 'schema': 'sch'}]) + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'schema'}]) + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}]) + + +def test_distinct_suggests_cols(): + suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') + assert suggestions == [{'type': 'column', 'tables': []}] + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tbl']}, + {'type': 'column', 'tables': [(None, 'tbl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type('SELECT a, b FROM tbl1, ', + 'SELECT a, b FROM tbl1, ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_insert_into_lparen_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') + assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', + 'SELECT * FROM tabl WHERE col_n') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'table', 'schema': 'tabl'}, + {'type': 'view', 'schema': 'tabl'}, + {'type': 'function', 'schema': 'tabl'}]) + + +def test_dot_suggests_cols_of_an_alias(): + suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', + 'SELECT t1.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': 't1'}, + {'type': 'view', 'schema': 't1'}, + {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, + {'type': 'function', 'schema': 't1'}]) + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', + 'SELECT t1.a, t2.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, + {'type': 'table', 'schema': 't2'}, + {'type': 'view', 'schema': 't2'}, + {'type': 'function', 'schema': 't2'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (', + 'SELECT * FROM foo WHERE EXISTS (', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', + 'SELECT 1 AS', +]) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{'type': 'keyword'}] + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (S', + 'SELECT * FROM foo WHERE EXISTS (S', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', +]) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{'type': 'keyword'}] + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' + suggestions = suggest_type(q, q) + assert suggestions == [ + {'type': 'column', 'tables': [(None, 'foo', 'f')]}, + {'type': 'table', 'schema': 'f'}, + {'type': 'view', 'schema': 'f'}, + {'type': 'function', 'schema': 'f'}] + + +@pytest.mark.parametrize('expression', [ + 'SELECT * FROM (SELECT * FROM ', + 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', + 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', +]) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', + 'SELECT * FROM (SELECT ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['abc']}, + {'type': 'column', 'tables': [(None, 'abc', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', + 'SELECT * FROM (SELECT a, ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', None)]}, + {'type': 'function', 'schema': []}]) + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', + 'SELECT * FROM (SELECT t.') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', 't')]}, + {'type': 'table', 'schema': 't'}, + {'type': 'view', 'schema': 't'}, + {'type': 'function', 'schema': 't'}]) + + +@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) +@pytest.mark.parametrize('tbl_alias', ['', 'foo']) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) + suggestion = suggest_type(text, text) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', +]) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', 'a')]}, + {'type': 'table', 'schema': 'a'}, + {'type': 'view', 'schema': 'a'}, + {'type': 'function', 'schema': 'a'}]) + + +@pytest.mark.parametrize('sql', [ + 'SELECT * FROM abc a JOIN def d ON a.id = d.', + 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', +]) +def test_join_alias_dot_suggests_cols2(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'def', 'd')]}, + {'type': 'table', 'schema': 'd'}, + {'type': 'view', 'schema': 'd'}, + {'type': 'function', 'schema': 'd'}]) + + +@pytest.mark.parametrize('sql', [ + 'select a.x, b.y from abc a join bcd b on ', + 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', +]) +def test_on_suggests_aliases(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + + +@pytest.mark.parametrize('sql', [ + 'select abc.x, bcd.y from abc join bcd on ', + 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', +]) +def test_on_suggests_tables(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + + +@pytest.mark.parametrize('sql', [ + 'select a.x, b.y from abc a join bcd b on a.id = ', + 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', +]) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + + +@pytest.mark.parametrize('sql', [ + 'select abc.x, bcd.y from abc join bcd on ', + 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', +]) +def test_on_suggests_tables_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + + +@pytest.mark.parametrize('col_list', ['', 'col1, ']) +def test_join_using_suggests_common_columns(col_list): + text = 'select * from abc inner join def using (' + col_list + assert suggest_type(text, text) == [ + {'type': 'column', + 'tables': [(None, 'abc', None), (None, 'def', None)], + 'drop_unique': True}] + + +def test_2_statements_2nd_current(): + suggestions = suggest_type('select * from a; select * from ', + 'select * from a; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select * from a; select from b', + 'select * from a; select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + # Should work even if first statement is invalid + suggestions = suggest_type('select * from; select * from ', + 'select * from; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_2_statements_1st_current(): + suggestions = suggest_type('select * from ; select * from b', + 'select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select from a; select * from b', + 'select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['a']}, + {'type': 'column', 'tables': [(None, 'a', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_3_statements_2nd_current(): + suggestions = suggest_type('select * from a; select * from ; select * from c', + 'select * from a; select * from ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + suggestions = suggest_type('select * from a; select from b; select * from c', + 'select * from a; select ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'alias', 'aliases': ['b']}, + {'type': 'column', 'tables': [(None, 'b', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'keyword'}, + ]) + + +def test_create_db_with_template(): + suggestions = suggest_type('create database foo with template ', + 'create database foo with template ') + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}]) + + +@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t']) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == \ + sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + + +def test_specials_not_included_after_initial_token(): + suggestions = suggest_type('create table foo (dt d', + 'create table foo (dt d') + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}]) + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = 'DROP TABLE schema_name.table_name' + suggestions = suggest_type(text, text) + assert suggestions == [{'type': 'table', 'schema': 'schema_name'}] + + +@pytest.mark.parametrize('text', [',', ' ,', 'sel ,']) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_cross_join(): + text = 'select * from v1 cross join v2 JOIN v1.id, ' + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +@pytest.mark.parametrize('expression', [ + 'SELECT 1 AS ', + 'SELECT 1 FROM tabl AS ', +]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +@pytest.mark.parametrize('expression', [ + '\\. ', + 'select 1; \\. ', + 'select 1;\\. ', + 'select 1 ; \\. ', + 'source ', + 'truncate table test; source ', + 'truncate table test ; source ', + 'truncate table test;source ', +]) +def test_source_is_file(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'file_name'}] + + +@pytest.mark.parametrize("expression", [ + "\\f ", +]) +def test_favorite_name_suggestion(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{'type': 'favoritequery'}] + +def test_order_by(): + text = 'select * from foo order by ' + suggestions = suggest_type(text, text) + assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py new file mode 100644 index 0000000..1ed6377 --- /dev/null +++ b/test/test_completion_refresher.py @@ -0,0 +1,88 @@ +import time +import pytest +from mock import Mock, patch + + +@pytest.fixture +def refresher(): + from mycli.completion_refresher import CompletionRefresher + return CompletionRefresher() + + +def test_ctor(refresher): + """Refresher object should contain a few handlers. + + :param refresher: + :return: + + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions', + 'special_commands', 'show_commands'] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + sqlexecute = Mock() + + with patch.object(refresher, '_bg_refresh') as bg_refresh: + actual = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == 'Auto-completion refresh started in the background.' + bg_refresh.assert_called_with(sqlexecute, callbacks, {}) + + +def test_refresh_called_twice(refresher): + """If refresh is called a second time, it should be restarted. + + :param refresher: + :return: + + """ + callbacks = Mock() + + sqlexecute = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == 'Auto-completion refresh started in the background.' + + actual2 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == 'Auto-completion refresh restarted.' + + +def test_refresh_with_callbacks(refresher): + """Callbacks must be called. + + :param refresher: + + """ + callbacks = [Mock()] + sqlexecute_class = Mock() + sqlexecute = Mock() + + with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class): + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert (callbacks[0].call_count == 1) diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 0000000..7f2b244 --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,196 @@ +"""Unit tests for the mycli.config module.""" +from io import BytesIO, StringIO, TextIOWrapper +import os +import struct +import sys +import tempfile +import pytest + +from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, + read_and_decrypt_mylogin_cnf, read_config_file, + str_to_bool, strip_matching_quotes) + +LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), + 'mylogin.cnf')) + + +def open_bmylogin_cnf(name): + """Open contents of *name* in a BytesIO buffer.""" + with open(name, 'rb') as f: + buf = BytesIO() + buf.write(f.read()) + return buf + +def test_read_mylogin_cnf(): + """Tests that a login path file can be read and decrypted.""" + mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) + + assert isinstance(mylogin_cnf, TextIOWrapper) + + contents = mylogin_cnf.read() + for word in ('[test]', 'user', 'password', 'host', 'port'): + assert word in contents + + +def test_decrypt_blank_mylogin_cnf(): + """Test that a blank login path file is handled correctly.""" + mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO()) + assert mylogin_cnf is None + + +def test_corrupted_login_key(): + """Test that a corrupted login path key is handled correctly.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the unused bytes + buf.seek(4) + + # Write null bytes over half the login key + buf.write(b'\0\0\0\0\0\0\0\0\0\0') + + buf.seek(0) + mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) + + assert mylogin_cnf is None + + +def test_corrupted_pad(): + """Tests that a login path file with a corrupted pad is partially read.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the login key + buf.seek(24) + + # Skip option group + len_buf = buf.read(4) + cipher_len, = struct.unpack("<i", len_buf) + buf.read(cipher_len) + + # Corrupt the pad for the user line + len_buf = buf.read(4) + cipher_len, = struct.unpack("<i", len_buf) + buf.read(cipher_len - 1) + buf.write(b'\0') + + buf.seek(0) + mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf)) + contents = mylogin_cnf.read() + for word in ('[test]', 'password', 'host', 'port'): + assert word in contents + assert 'user' not in contents + + +def test_get_mylogin_cnf_path(): + """Tests that the path for .mylogin.cnf is detected.""" + original_env = None + if 'MYSQL_TEST_LOGIN_FILE' in os.environ: + original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + is_windows = sys.platform == 'win32' + + login_cnf_path = get_mylogin_cnf_path() + + if original_env is not None: + os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + + if login_cnf_path is not None: + assert login_cnf_path.endswith('.mylogin.cnf') + + if is_windows is True: + assert 'MySQL' in login_cnf_path + else: + home_dir = os.path.expanduser('~') + assert login_cnf_path.startswith(home_dir) + + +def test_alternate_get_mylogin_cnf_path(): + """Tests that the alternate path for .mylogin.cnf is detected.""" + original_env = None + if 'MYSQL_TEST_LOGIN_FILE' in os.environ: + original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + + _, temp_path = tempfile.mkstemp() + os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path + + login_cnf_path = get_mylogin_cnf_path() + + if original_env is not None: + os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + + assert temp_path == login_cnf_path + + +def test_str_to_bool(): + """Tests that str_to_bool function converts values correctly.""" + + assert str_to_bool(False) is False + assert str_to_bool(True) is True + assert str_to_bool('False') is False + assert str_to_bool('True') is True + assert str_to_bool('TRUE') is True + assert str_to_bool('1') is True + assert str_to_bool('0') is False + assert str_to_bool('on') is True + assert str_to_bool('off') is False + assert str_to_bool('off') is False + + with pytest.raises(ValueError): + str_to_bool('foo') + + with pytest.raises(TypeError): + str_to_bool(None) + + +def test_read_config_file_list_values_default(): + """Test that reading a config file uses list_values by default.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f) + + assert config['main']['weather'] == u"cloudy with a chance of meatballs" + + +def test_read_config_file_list_values_off(): + """Test that you can disable list_values when reading a config file.""" + + f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + config = read_config_file(f, list_values=False) + + assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + + +def test_strip_quotes_with_matching_quotes(): + """Test that a string with matching quotes is unquoted.""" + + s = "May the force be with you." + assert s == strip_matching_quotes('"{}"'.format(s)) + assert s == strip_matching_quotes("'{}'".format(s)) + + +def test_strip_quotes_with_unmatching_quotes(): + """Test that a string with unmatching quotes is not unquoted.""" + + s = "May the force be with you." + assert '"' + s == strip_matching_quotes('"{}'.format(s)) + assert s + "'" == strip_matching_quotes("{}'".format(s)) + + +def test_strip_quotes_with_empty_string(): + """Test that an empty string is handled during unquoting.""" + + assert '' == strip_matching_quotes('') + + +def test_strip_quotes_with_none(): + """Test that None is handled during unquoting.""" + + assert None is strip_matching_quotes(None) + + +def test_strip_quotes_with_quotes(): + """Test that strings with quotes in them are handled during unquoting.""" + + s1 = 'Darth Vader said, "Luke, I am your father."' + assert s1 == strip_matching_quotes(s1) + + s2 = '"Darth Vader said, "Luke, I am your father.""' + assert s2[1:-1] == strip_matching_quotes(s2) diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py new file mode 100644 index 0000000..21e389c --- /dev/null +++ b/test/test_dbspecial.py @@ -0,0 +1,42 @@ +from mycli.packages.completion_engine import suggest_type +from .test_completion_engine import sorted_dicts +from mycli.packages.special.utils import format_uptime + + +def test_u_suggests_databases(): + suggestions = suggest_type('\\u ', '\\u ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'database'}]) + + +def test_describe_table(): + suggestions = suggest_type('\\dt', '\\dt ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_list_or_show_create_tables(): + suggestions = suggest_type('\\dt+', '\\dt+ ') + assert sorted_dicts(suggestions) == sorted_dicts([ + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + {'type': 'schema'}]) + + +def test_format_uptime(): + seconds = 59 + assert '59 sec' == format_uptime(seconds) + + seconds = 120 + assert '2 min 0 sec' == format_uptime(seconds) + + seconds = 54890 + assert '15 hours 14 min 50 sec' == format_uptime(seconds) + + seconds = 598244 + assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds) + + seconds = 522600 + assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds) diff --git a/test/test_main.py b/test/test_main.py new file mode 100644 index 0000000..3f92bd1 --- /dev/null +++ b/test/test_main.py @@ -0,0 +1,494 @@ +import os + +import click +from click.testing import CliRunner + +from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from .utils import USER, HOST, PORT, PASSWORD, dbtest, run + +from textwrap import dedent +from collections import namedtuple + +from tempfile import NamedTemporaryFile +from textwrap import dedent + + +test_dir = os.path.abspath(os.path.dirname(__file__)) +project_dir = os.path.dirname(test_dir) +default_config_file = os.path.join(project_dir, 'test', 'myclirc') +login_path_file = os.path.join(test_dir, 'mylogin.cnf') + +os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, + '--password', PASSWORD, '--myclirc', default_config_file, + '--defaults-file', default_config_file, + '_test_db'] + + +@dbtest +def test_execute_arg(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) + + assert result.exit_code == 0 + assert 'abc' in result.output + + expected = 'a\nabc\n' + + assert expected in result.output + + +@dbtest +def test_execute_arg_with_table(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) + expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_arg_with_csv(executor): + run(executor, 'create table test (a text)') + run(executor, 'insert into test values("abc")') + + sql = 'select * from test;' + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) + expected = '"a"\n"abc"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +@dbtest +def test_batch_mode(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert 'count(*)\n3\na\nabc\n' in "".join(result.output) + + +@dbtest +def test_batch_mode_table(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + + sql = ( + 'select count(*) from test;\n' + 'select * from test limit 1;' + ) + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) + + expected = (dedent("""\ + +----------+ + | count(*) | + +----------+ + | 3 | + +----------+ + +-----+ + | a | + +-----+ + | abc | + +-----+""")) + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_batch_mode_csv(executor): + run(executor, '''create table test(a text, b text)''') + run(executor, + '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''') + + sql = 'select * from test;' + + runner = CliRunner() + result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) + + expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +def test_thanks_picker_utf8(): + author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') + sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') + + name = thanks_picker((author_file, sponsor_file)) + assert name and isinstance(name, str) + + +def test_help_strings_end_with_periods(): + """Make sure click options have help text that end with a period.""" + for param in cli.params: + if isinstance(param, click.core.Option): + assert hasattr(param, 'help') + assert param.help.endswith('.') + + +def test_command_descriptions_end_with_periods(): + """Make sure that mycli commands' descriptions end with a period.""" + MyCli() + for _, command in SPECIAL_COMMANDS.items(): + assert command[3].endswith('.') + + +def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): + global clickoutput + clickoutput = "" + m = MyCli(myclirc=default_config_file) + + class TestOutput(): + def get_size(self): + size = namedtuple('Size', 'rows columns') + size.columns, size.rows = terminal_size + return size + + class TestExecute(): + host = 'test' + user = 'test' + dbname = 'test' + port = 0 + + def server_type(self): + return ['test'] + + class PromptBuffer(): + output = TestOutput() + + m.prompt_app = PromptBuffer() + m.sqlexecute = TestExecute() + m.explicit_pager = explicit_pager + + def echo_via_pager(s): + assert expect_pager + global clickoutput + clickoutput += "".join(s) + + def secho(s): + assert not expect_pager + global clickoutput + clickoutput += s + "\n" + + monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager) + monkeypatch.setattr(click, 'secho', secho) + m.output(testdata) + if clickoutput.endswith("\n"): + clickoutput = clickoutput[:-1] + assert clickoutput == "\n".join(testdata) + + +def test_conditional_pager(monkeypatch): + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( + " ") + # User didn't set pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=True + ) + # User didn't set pager, output fits screen -> no pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + # User manually configured pager, output doesn't fit screen -> pager + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + # User manually configured pager, output fit screen -> pager + output( + monkeypatch, + terminal_size=(20, 20), + testdata=testdata, + explicit_pager=True, + expect_pager=True + ) + + SPECIAL_COMMANDS['nopager'].handler() + output( + monkeypatch, + terminal_size=(5, 10), + testdata=testdata, + explicit_pager=False, + expect_pager=False + ) + SPECIAL_COMMANDS['pager'].handler('') + + +def test_reserved_space_is_integer(): + """Make sure that reserved space is returned as an integer.""" + def stub_terminal_size(): + return (5, 5) + + old_func = click.get_terminal_size + + click.get_terminal_size = stub_terminal_size + mycli = MyCli() + assert isinstance(mycli.get_reserved_space(), int) + + click.get_terminal_size = old_func + + +def test_list_dsn(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as myclirc: + myclirc.write(dedent("""\ + [alias_dsn] + test = mysql://test/test + """)) + myclirc.flush() + args = ['--list-dsn', '--myclirc', myclirc.name] + result = runner.invoke(cli, args=args) + assert result.output == "test\n" + result = runner.invoke(cli, args=args + ['--verbose']) + assert result.output == "test : mysql://test/test\n" + + +def test_list_ssh_config(): + runner = CliRunner() + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + result = runner.invoke(cli, args=args) + assert "test\n" in result.output + result = runner.invoke(cli, args=args + ['--verbose']) + assert "test : test.example.com\n" in result.output + + +def test_dsn(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + class Logger: + def debug(self, *args, **args_dict): + pass + def warning(self, *args, **args_dict): + pass + class MockMyCli: + config = {'alias_dsn': {}} + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + def connect(self, **args): + MockMyCli.connect_args = args + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # When a user supplies a DSN as database argument to mycli, + # use these values. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 1 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + MockMyCli.connect_args = None + + # When a use supplies a DSN as database argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "3", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 3 and \ + MockMyCli.connect_args["database"] == "arg_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn), + # use these values. + result = runner.invoke(cli, args=['--dsn', 'test']) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "alias_dsn_user" and \ + MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ + MockMyCli.connect_args["host"] == "alias_dsn_host" and \ + MockMyCli.connect_args["port"] == 4 and \ + MockMyCli.connect_args["database"] == "alias_dsn_database" + + MockMyCli.config = { + 'alias_dsn': { + 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' + } + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn) + # and used command line arguments, use the command line arguments. + result = runner.invoke(cli, args=[ + '--dsn', 'test', '', + "--user", "arg_user", + "--password", "arg_password", + "--host", "arg_host", + "--port", "5", + "--database", "arg_database", + ]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "arg_user" and \ + MockMyCli.connect_args["passwd"] == "arg_password" and \ + MockMyCli.connect_args["host"] == "arg_host" and \ + MockMyCli.connect_args["port"] == 5 and \ + MockMyCli.connect_args["database"] == "arg_database" + + # Use a DSN without password + result = runner.invoke(mycli.main.cli, args=[ + "mysql://dsn_user@dsn_host:6/dsn_database"] + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert \ + MockMyCli.connect_args["user"] == "dsn_user" and \ + MockMyCli.connect_args["passwd"] is None and \ + MockMyCli.connect_args["host"] == "dsn_host" and \ + MockMyCli.connect_args["port"] == 6 and \ + MockMyCli.connect_args["database"] == "dsn_database" + + +def test_ssh_config(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = {'alias_dsn': {}} + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.formatter = Formatter() + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + # Setup temporary configuration + with NamedTemporaryFile(mode="w") as ssh_config: + ssh_config.write(dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """)) + ssh_config.flush() + + # When a user supplies a ssh config. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "joe" and \ + MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ + MockMyCli.connect_args["ssh_port"] == 22222 and \ + MockMyCli.connect_args["ssh_key_filename"] == os.getenv( + "HOME") + "/.ssh/gateway" + + # When a user supplies a ssh config host as argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke(mycli.main.cli, args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", "arg_user", + "--ssh-host", "arg_host", + "--ssh-port", "3", + "--ssh-key-filename", "/path/to/key" + ]) + assert result.exit_code == 0, result.output + \ + " " + str(result.exception) + assert \ + MockMyCli.connect_args["ssh_user"] == "arg_user" and \ + MockMyCli.connect_args["ssh_host"] == "arg_host" and \ + MockMyCli.connect_args["ssh_port"] == 3 and \ + MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py new file mode 100644 index 0000000..14c1bf5 --- /dev/null +++ b/test/test_naive_completion.py @@ -0,0 +1,63 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document + + +@pytest.fixture +def completer(): + import mycli.sqlcompleter as sqlcompleter + return sqlcompleter.SQLCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from mock import Mock + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = '' + position = 0 + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list(map(Completion, sorted(completer.all_completions))) + + +def test_select_keyword_completion(completer, complete_event): + text = 'SEL' + position = len('SEL') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([Completion(text='SELECT', start_position=-3)]) + + +def test_function_name_completion(completer, complete_event): + text = 'SELECT MA' + position = len('SELECT MA') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='MASTER', start_position=-2), + Completion(text='MAX', start_position=-2)]) + + +def test_column_name_completion(completer, complete_event): + text = 'SELECT FROM users' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list(map(Completion, sorted(completer.all_completions))) + + +def test_special_name_completion(completer, complete_event): + text = '\\' + position = len('\\') + result = set(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + # Special commands will NOT be suggested during naive completion mode. + assert result == set() diff --git a/test/test_parseutils.py b/test/test_parseutils.py new file mode 100644 index 0000000..920a08d --- /dev/null +++ b/test/test_parseutils.py @@ -0,0 +1,190 @@ +import pytest +from mycli.packages.parseutils import ( + extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, + is_dropping_database) + + +def test_empty_string(): + tables = extract_tables('') + assert tables == [] + + +def test_simple_select_single_table(): + tables = extract_tables('select * from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_single_table_schema_qualified(): + tables = extract_tables('select * from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_multiple_tables(): + tables = extract_tables('select * from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables('select * from abc.def, ghi.jkl') + assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables('select a,b from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables('select a,b from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables('select a,b from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_with_cols_multiple_tables_with_schema(): + tables = extract_tables('select a,b from abc.def, def.ghi') + assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables('select a, from abc') + assert tables == [(None, 'abc', None)] + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables('select a, from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') + assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # assert tables == [(None, 'abc', None)] + assert tables == [(None, 'abc', 'abc')] + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == [('abc', 'def', None)] + + +def test_simple_update_table(): + tables = extract_tables('update abc set id = 1') + assert tables == [(None, 'abc', None)] + + +def test_simple_update_table_with_schema(): + tables = extract_tables('update abc.def set id = 1') + assert tables == [('abc', 'def', None)] + + +def test_join_table(): + tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') + assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + + +def test_join_table_schema_qualified(): + tables = extract_tables( + 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + + +def test_join_as_table(): + tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + + +def test_query_starts_with(): + query = 'USE test;' + assert query_starts_with(query, ('use', )) is True + + query = 'DROP DATABASE test;' + assert query_starts_with(query, ('use', )) is False + + +def test_query_starts_with_comment(): + query = '# comment\nUSE test;' + assert query_starts_with(query, ('use', )) is True + + +def test_queries_start_with(): + sql = ( + '# comment\n' + 'show databases;' + 'use foo;' + ) + assert queries_start_with(sql, ('show', 'select')) is True + assert queries_start_with(sql, ('use', 'drop')) is True + assert queries_start_with(sql, ('delete', 'update')) is False + + +def test_is_destructive(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'drop database foo;' + ) + assert is_destructive(sql) is True + + +def test_is_destructive_update_with_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1 WHERE id = 1;' + ) + assert is_destructive(sql) is False + + +def test_is_destructive_update_without_where_clause(): + sql = ( + 'use test;\n' + 'show databases;\n' + 'UPDATE test SET x = 1;' + ) + assert is_destructive(sql) is True + + +@pytest.mark.parametrize( + ('sql', 'has_where_clause'), + [ + ('update test set dummy = 1;', False), + ('update test set dummy = 1 where id = 1);', True), + ], +) +def test_query_has_where_clause(sql, has_where_clause): + assert query_has_where_clause(sql) is has_where_clause + + +@pytest.mark.parametrize( + ('sql', 'dbname', 'is_dropping'), + [ + ('select bar from foo', 'foo', False), + ('drop database "foo";', '`foo`', True), + ('drop schema foo', 'foo', True), + ('drop schema foo', 'bar', False), + ('drop database bar', 'foo', False), + ('drop database foo', None, False), + ('drop database foo; create database foo', 'foo', False), + ('drop database foo; create database bar', 'foo', True), + ('select bar from foo; drop database bazz', 'foo', False), + ('select bar from foo; drop database bazz', 'bazz', True), + ('-- dropping database \n ' + 'drop -- really dropping \n ' + 'schema abc -- now it is dropped', + 'abc', + True) + ] +) +def test_is_dropping_database(sql, dbname, is_dropping): + assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_plan.wiki b/test/test_plan.wiki new file mode 100644 index 0000000..43e9083 --- /dev/null +++ b/test/test_plan.wiki @@ -0,0 +1,38 @@ += Gross Checks = + * [ ] Check connecting to a local database. + * [ ] Check connecting to a remote database. + * [ ] Check connecting to a database with a user/password. + * [ ] Check connecting to a non-existent database. + * [ ] Test changing the database. + + == PGExecute == + * [ ] Test successful execution given a cursor. + * [ ] Test unsuccessful execution with a syntax error. + * [ ] Test a series of executions with the same cursor without failure. + * [ ] Test a series of executions with the same cursor with failure. + * [ ] Test passing in a special command. + + == Naive Autocompletion == + * [ ] Input empty string, ask for completions - Everything. + * [ ] Input partial prefix, ask for completions - Stars with prefix. + * [ ] Input fully autocompleted string, ask for completions - Only full match + * [ ] Input non-existent prefix, ask for completions - nothing + * [ ] Input lowercase prefix - case insensitive completions + + == Smart Autocompletion == + * [ ] Input empty string and check if only keywords are returned. + * [ ] Input SELECT prefix and check if only columns and '*' are returned. + * [ ] Input SELECT blah - only keywords are returned. + * [ ] Input SELECT * FROM - Table names only + + == PGSpecial == + * [ ] Test \d + * [ ] Test \d tablename + * [ ] Test \d tablena* + * [ ] Test \d non-existent-tablename + * [ ] Test \d index + * [ ] Test \d sequence + * [ ] Test \d view + + == Exceptionals == + * [ ] Test the 'use' command to change db. diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py new file mode 100644 index 0000000..2373fac --- /dev/null +++ b/test/test_prompt_utils.py @@ -0,0 +1,11 @@ +import click + +from mycli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream('stdin') + assert stdin.isatty() is False + + sql = 'drop database foo;' + assert confirm_destructive_query(sql) is None diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..b66c696 --- /dev/null +++ b/test/test_smart_completion_public_schema_only.py @@ -0,0 +1,385 @@ +import pytest +from mock import patch +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +import mycli.packages.special.main as special + +metadata = { + 'users': ['id', 'email', 'first_name', 'last_name'], + 'orders': ['id', 'ordered_date', 'status'], + 'select': ['id', 'insert', 'ABC'], + 'réveillé': ['id', 'insert', 'ABC'] +} + + +@pytest.fixture +def completer(): + + import mycli.sqlcompleter as sqlcompleter + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables, columns = [], [] + + for table, cols in metadata.items(): + tables.append((table,)) + columns.extend([(table, col) for col in cols]) + + comp.set_dbname('test') + comp.extend_schemata('test') + comp.extend_relations(tables, kind='tables') + comp.extend_columns(columns, kind='tables') + comp.extend_special_commands(special.COMMANDS) + + return comp + + +@pytest.fixture +def complete_event(): + from mock import Mock + return Mock() + + +def test_special_name_completion(completer, complete_event): + text = '\\d' + position = len('\\d') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert result == [Completion(text='\\dt', start_position=-2)] + + +def test_empty_string_completion(completer, complete_event): + text = '' + position = 0 + result = list( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert list(map(Completion, sorted(completer.keywords) + + sorted(completer.special_commands))) == result + + +def test_select_keyword_completion(completer, complete_event): + text = 'SEL' + position = len('SEL') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert list(result) == list([Completion(text='SELECT', start_position=-3)]) + + +def test_table_completion(completer, complete_event): + text = 'SELECT * FROM ' + position = len(text) + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event) + assert list(result) == list([ + Completion(text='`réveillé`', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) + + +def test_function_name_completion(completer, complete_event): + text = 'SELECT MA' + position = len('SELECT MA') + result = completer.get_completions( + Document(text=text, cursor_position=position), complete_event) + assert list(result) == list([Completion(text='MAX', start_position=-2), + Completion(text='MASTER', start_position=-2), + ]) + + +def test_suggested_column_names(completer, complete_event): + """Suggest column and function names when selecting from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT from users' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text='users', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def test_suggested_column_names_in_function(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT MAX( from users' + position = len('SELECT MAX(') + result = completer.get_completions( + Document(text=text, cursor_position=position), + complete_event) + assert list(result) == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_column_names_with_table_dot(completer, complete_event): + """Suggest column names on table name and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT users. from users' + position = len('SELECT users.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT u. from users u' + position = len('SELECT u.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_multiple_column_names(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT id, from users u' + position = len('SELECT id, ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)] + + list(map(Completion, completer.functions)) + + [Completion(text='u', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def test_suggested_multiple_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT u.id, u. from users u' + position = len('SELECT u.id, u.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_multiple_column_names_with_dot(completer, complete_event): + """Suggest column names on table names and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = 'SELECT users.id, users. from users u' + position = len('SELECT users.id, users.') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='email', start_position=0), + Completion(text='first_name', start_position=0), + Completion(text='id', start_position=0), + Completion(text='last_name', start_position=0)]) + + +def test_suggested_aliases_after_on(completer, complete_event): + text = 'SELECT u.name, o.id FROM users u JOIN orders o ON ' + position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + + +def test_suggested_aliases_after_on_right_side(completer, complete_event): + text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ' + position = len( + 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='o', start_position=0), + Completion(text='u', start_position=0)]) + + +def test_suggested_tables_after_on(completer, complete_event): + text = 'SELECT users.name, orders.id FROM users JOIN orders ON ' + position = len('SELECT users.name, orders.id FROM users JOIN orders ON ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0)]) + + +def test_suggested_tables_after_on_right_side(completer, complete_event): + text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ' + position = len( + 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0)]) + + +def test_table_names_after_from(completer, complete_event): + text = 'SELECT * FROM ' + position = len('SELECT * FROM ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='`réveillé`', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='users', start_position=0), + ]) + + +def test_auto_escaped_col_names(completer, complete_event): + text = 'SELECT from `select`' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == [ + Completion(text='*', start_position=0), + Completion(text='`ABC`', start_position=0), + Completion(text='`insert`', start_position=0), + Completion(text='id', start_position=0), + ] + \ + list(map(Completion, completer.functions)) + \ + [Completion(text='`select`', start_position=0)] + \ + list(map(Completion, completer.keywords)) + + +def test_un_escaped_table_names(completer, complete_event): + text = 'SELECT from réveillé' + position = len('SELECT ') + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + assert result == list([ + Completion(text='*', start_position=0), + Completion(text='`ABC`', start_position=0), + Completion(text='`insert`', start_position=0), + Completion(text='id', start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text='réveillé', start_position=0)] + + list(map(Completion, completer.keywords))) + + +def dummy_list_path(dir_name): + dirs = { + '/': [ + 'dir1', + 'file1.sql', + 'file2.sql', + ], + '/dir1': [ + 'subdir1', + 'subfile1.sql', + 'subfile2.sql', + ], + '/dir1/subdir1': [ + 'lastfile.sql', + ], + } + return dirs.get(dir_name, []) + + +@patch('mycli.packages.filepaths.list_path', new=dummy_list_path) +@pytest.mark.parametrize('text,expected', [ + # ('source ', [('~', 0), + # ('/', 0), + # ('.', 0), + # ('..', 0)]), + ('source /', [('dir1', 0), + ('file1.sql', 0), + ('file2.sql', 0)]), + ('source /dir1/', [('subdir1', 0), + ('subfile1.sql', 0), + ('subfile2.sql', 0)]), + ('source /dir1/subdir1/', [('lastfile.sql', 0)]), +]) +def test_file_name_completion(completer, complete_event, text, expected): + position = len(text) + result = list(completer.get_completions( + Document(text=text, cursor_position=position), + complete_event)) + expected = list((Completion(txt, pos) for txt, pos in expected)) + assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py new file mode 100644 index 0000000..b8b3acb --- /dev/null +++ b/test/test_special_iocommands.py @@ -0,0 +1,272 @@ +import os +import stat +import tempfile +from time import time +from mock import patch + +import pytest +from pymysql import ProgrammingError + +import mycli.packages.special + +from .utils import dbtest, db_connection, send_ctrl_c + + +def test_set_get_pager(): + mycli.packages.special.set_pager_enabled(True) + assert mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager_enabled(False) + assert not mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager('less') + assert os.environ['PAGER'] == "less" + mycli.packages.special.set_pager(False) + assert os.environ['PAGER'] == "less" + del os.environ['PAGER'] + mycli.packages.special.set_pager(False) + mycli.packages.special.disable_pager() + assert not mycli.packages.special.is_pager_enabled() + + +def test_set_get_timing(): + mycli.packages.special.set_timing_enabled(True) + assert mycli.packages.special.is_timing_enabled() + mycli.packages.special.set_timing_enabled(False) + assert not mycli.packages.special.is_timing_enabled() + + +def test_set_get_expanded_output(): + mycli.packages.special.set_expanded_output(True) + assert mycli.packages.special.is_expanded_output() + mycli.packages.special.set_expanded_output(False) + assert not mycli.packages.special.is_expanded_output() + + +def test_editor_command(): + assert mycli.packages.special.editor_command(r'hello\e') + assert mycli.packages.special.editor_command(r'\ehello') + assert not mycli.packages.special.editor_command(r'hello') + + assert mycli.packages.special.get_filename(r'\e filename') == "filename" + + os.environ['EDITOR'] = 'true' + mycli.packages.special.open_external_editor(r'select 1') == "select 1" + + +def test_tee_command(): + mycli.packages.special.write_tee(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"tee " + f.name) + mycli.packages.special.write_tee(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"tee -o " + f.name) + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"notee") + mycli.packages.special.write_tee(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + +def test_tee_command_error(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, 'tee') + + with pytest.raises(OSError): + with tempfile.NamedTemporaryFile() as f: + os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + mycli.packages.special.execute(None, 'tee {}'.format(f.name)) + + +@dbtest +def test_favorite_query(): + with db_connection().cursor() as cur: + query = u'select "✔"' + mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) + assert next(mycli.packages.special.execute( + cur, u'\\f check'))[0] == "> " + query + + +def test_once_command(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, u"\\once") + + mycli.packages.special.execute(None, u"\\once /proc/access-denied") + with pytest.raises(OSError): + mycli.packages.special.write_once(u"hello world") + + mycli.packages.special.write_once(u"hello world") # write without file set + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, u"\\once " + f.name) + mycli.packages.special.write_once(u"hello world") + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, u"\\once -o " + f.name) + mycli.packages.special.write_once(u"hello world") + f.seek(0) + assert f.read() == b"hello world\n" + + +def test_parseargfile(): + """Test that parseargfile expands the user directory.""" + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'a'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '~/filename') + + expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), + 'mode': 'w'} + assert expected == mycli.packages.special.iocommands.parseargfile( + '-o ~/filename') + + +def test_parseargfile_no_file(): + """Test that parseargfile raises a TypeError if there is no filename.""" + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('') + + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile('-o ') + + +@dbtest +def test_watch_query_iteration(): + """Test that a single iteration of the result of `watch_query` executes + the desired query and returns the given results.""" + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + with db_connection().cursor() as cur: + result = next(mycli.packages.special.iocommands.watch_query( + arg=query, cur=cur + )) + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +def test_watch_query_full(): + """Test that `watch_query`: + + * Returns the expected results. + * Executes the defined times inside the given interval, in this case with + a 0.3 seconds wait, it should execute 4 times inside a 1 seconds + interval. + * Stops at Ctrl-C + + """ + watch_seconds = 0.3 + wait_interval = 1 + expected_value = "1" + query = "SELECT {0!s}".format(expected_value) + expected_title = '> {0!s}'.format(query) + expected_results = 4 + ctrl_c_process = send_ctrl_c(wait_interval) + with db_connection().cursor() as cur: + results = list( + result for result in mycli.packages.special.iocommands.watch_query( + arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur + ) + ) + ctrl_c_process.join(1) + assert len(results) == expected_results + for result in results: + assert result[0] == expected_title + assert result[2][0] == expected_value + + +@dbtest +@patch('click.clear') +def test_watch_query_clear(clear_mock): + """Test that the screen is cleared with the -c flag of `watch` command + before execute the query.""" + with db_connection().cursor() as cur: + watch_gen = mycli.packages.special.iocommands.watch_query( + arg='0.1 -c select 1;', cur=cur + ) + assert not clear_mock.called + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + + +@dbtest +def test_watch_query_bad_arguments(): + """Test different incorrect combinations of arguments for `watch` + command.""" + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + with pytest.raises(ProgrammingError): + next(watch_query('a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('1 -a select 1;', cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query('-c -a select 1;', cur=cur)) + + +@dbtest +@patch('click.clear') +def test_watch_query_interval_clear(clear_mock): + """Test `watch` command with interval and clear flag.""" + def test_asserts(gen): + clear_mock.reset_mock() + start = time() + next(gen) + assert clear_mock.called + next(gen) + exec_time = time() - start + assert exec_time > seconds and exec_time < (seconds + seconds) + + seconds = 1.0 + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), + cur=cur)) + test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), + cur=cur)) + + +def test_split_sql_by_delimiter(): + for delimiter_str in (';', '$', '😀'): + mycli.packages.special.set_delimiter(delimiter_str) + sql_input = "select 1{} select \ufffc2".format(delimiter_str) + queries = ( + "select 1", + "select \ufffc2" + ) + for query, parsed_query in zip( + queries, mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_switch_delimiter_within_query(): + mycli.packages.special.set_delimiter(';') + sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" + queries = ( + "select 1", + "delimiter $$ select 2 $$ select 3 $$", + "select 2", + "select 3" + ) + for query, parsed_query in zip( + queries, + mycli.packages.special.split_queries(sql_input)): + assert(query == parsed_query) + + +def test_set_delimiter(): + + for delim in ('foo', 'bar'): + mycli.packages.special.set_delimiter(delim) + assert mycli.packages.special.get_current_delimiter() == delim + + +def teardown_function(): + mycli.packages.special.set_delimiter(';') diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py new file mode 100644 index 0000000..c2d38be --- /dev/null +++ b/test/test_sqlexecute.py @@ -0,0 +1,272 @@ +import os + +import pytest +import pymysql + +from .utils import run, dbtest, set_expanded_output, is_expanded_output + + +def assert_result_equal(result, title=None, rows=None, headers=None, + status=None, auto_status=True, assert_contains=False): + """Assert that an sqlexecute.run() result matches the expected values.""" + if status is None and auto_status and rows: + status = '{} row{} in set'.format( + len(rows), 's' if len(rows) > 1 else '') + fields = {'title': title, 'rows': rows, 'headers': headers, + 'status': status} + + if assert_contains: + # Do a loose match on the results using the *in* operator. + for key, field in fields.items(): + if field: + assert field in result[0][key] + else: + # Do an exact match on the fields. + assert result == [fields] + + +@dbtest +def test_conn(executor): + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[('abc',)]) + + +@dbtest +def test_bools(executor): + run(executor, '''create table test(a boolean)''') + run(executor, '''insert into test values(True)''') + results = run(executor, '''select * from test''') + + assert_result_equal(results, headers=['a'], rows=[(1,)]) + + +@dbtest +def test_binary(executor): + run(executor, '''create table bt(geom linestring NOT NULL)''') + run(executor, "INSERT INTO bt VALUES " + "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, '''select * from bt''') + + geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' + b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' + b'\xac\xdeC@') + + assert_result_equal(results, headers=['geom'], rows=[(geom,)]) + + +@dbtest +def test_table_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + + assert set(executor.tables()) == set([('a',), ('b',)]) + assert set(executor.table_columns()) == set( + [('a', 'x'), ('a', 'y'), ('b', 'z')]) + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert '_test_db' in databases + + +@dbtest +def test_invalid_syntax(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, 'invalid syntax!') + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + +@dbtest +def test_invalid_column_name(executor): + with pytest.raises(pymysql.err.OperationalError) as excinfo: + run(executor, 'select invalid command') + assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) + + +@dbtest +def test_unicode_support_in_output(executor): + run(executor, "create table unicodechars(t text)") + run(executor, u"insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + results = run(executor, u"select * from unicodechars") + assert_result_equal(results, headers=['t'], rows=[(u'é',)]) + + +@dbtest +def test_multiple_queries_same_line(executor): + results = run(executor, "select 'foo'; select 'bar'") + + expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], + 'status': '1 row in set'}, + {'title': None, 'headers': ['bar'], 'rows': [('bar',)], + 'status': '1 row in set'}] + assert expected == results + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, "select 'foo'; invalid syntax") + assert 'You have an error in your SQL syntax;' in str(excinfo.value) + + +@dbtest +def test_favorite_query(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-a select * from test where a like 'a%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-a") + assert_result_equal(results, + title="> select * from test where a like 'a%'", + headers=['a'], rows=[('abc',)], auto_status=False) + + results = run(executor, "\\fd test-a") + assert_result_equal(results, status='test-a: Deleted') + + +@dbtest +def test_favorite_query_multiple_statement(executor): + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, + "\\fs test-ad select * from test where a like 'a%'; " + "select * from test where a like 'd%'") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ad") + expected = [{'title': "> select * from test where a like 'a%'", + 'headers': ['a'], 'rows': [('abc',)], 'status': None}, + {'title': "> select * from test where a like 'd%'", + 'headers': ['a'], 'rows': [('def',)], 'status': None}] + assert expected == results + + results = run(executor, "\\fd test-ad") + assert_result_equal(results, status='test-ad: Deleted') + + +@dbtest +def test_favorite_query_expanded_output(executor): + set_expanded_output(False) + run(executor, '''create table test(a text)''') + run(executor, '''insert into test values('abc')''') + + results = run(executor, "\\fs test-ae select * from test") + assert_result_equal(results, status='Saved.') + + results = run(executor, "\\f test-ae \G") + assert is_expanded_output() is True + assert_result_equal(results, title='> select * from test', + headers=['a'], rows=[('abc',)], auto_status=False) + + set_expanded_output(False) + + results = run(executor, "\\fd test-ae") + assert_result_equal(results, status='test-ae: Deleted') + + +@dbtest +def test_special_command(executor): + results = run(executor, '\\?') + assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), + headers='Command', assert_contains=True, + auto_status=False) + + +@dbtest +def test_cd_command_without_a_folder_name(executor): + results = run(executor, 'system cd') + assert_result_equal(results, status='No folder name was provided.') + + +@dbtest +def test_system_command_not_found(executor): + results = run(executor, 'system xyz') + assert_result_equal(results, status='OSError: No such file or directory', + assert_contains=True) + + +@dbtest +def test_system_command_output(executor): + test_dir = os.path.abspath(os.path.dirname(__file__)) + test_file_path = os.path.join(test_dir, 'test.txt') + results = run(executor, 'system cat {0}'.format(test_file_path)) + assert_result_equal(results, status='mycli rocks!\n') + + +@dbtest +def test_cd_command_current_dir(executor): + test_path = os.path.abspath(os.path.dirname(__file__)) + run(executor, 'system cd {0}'.format(test_path)) + assert os.getcwd() == test_path + + +@dbtest +def test_unicode_support(executor): + results = run(executor, u"SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) + + +@dbtest +def test_timestamp_null(executor): + run(executor, '''create table ts_null(a timestamp null)''') + run(executor, '''insert into ts_null values(null)''') + results = run(executor, '''select * from ts_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_datetime_null(executor): + run(executor, '''create table dt_null(a datetime null)''') + run(executor, '''insert into dt_null values(null)''') + results = run(executor, '''select * from dt_null''') + assert_result_equal(results, headers=['a'], + rows=[(None,)]) + + +@dbtest +def test_date_null(executor): + run(executor, '''create table date_null(a date null)''') + run(executor, '''insert into date_null values(null)''') + results = run(executor, '''select * from date_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_time_null(executor): + run(executor, '''create table time_null(a time null)''') + run(executor, '''insert into time_null values(null)''') + results = run(executor, '''select * from time_null''') + assert_result_equal(results, headers=['a'], rows=[(None,)]) + + +@dbtest +def test_multiple_results(executor): + query = '''CREATE PROCEDURE dmtest() + BEGIN + SELECT 1; + SELECT 2; + END''' + executor.conn.cursor().execute(query) + + results = run(executor, 'call dmtest;') + expected = [ + {'title': None, 'rows': [(1,)], 'headers': ['1'], + 'status': '1 row in set'}, + {'title': None, 'rows': [(2,)], 'headers': ['2'], + 'status': '1 row in set'} + ] + assert results == expected diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py new file mode 100644 index 0000000..7d7d000 --- /dev/null +++ b/test/test_tabular_output.py @@ -0,0 +1,118 @@ +"""Test the sql output adapter.""" + +from textwrap import dedent + +from mycli.packages.tabular_output import sql_format +from cli_helpers.tabular_output import TabularOutputFormatter + +from .utils import USER, PASSWORD, HOST, PORT, dbtest + +import pytest +from mycli.main import MyCli + +from pymysql.constants import FIELD_TYPE + + +@pytest.fixture +def mycli(): + cli = MyCli() + cli.connect(None, USER, PASSWORD, HOST, PORT, None) + return cli + + +@dbtest +def test_sql_output(mycli): + """Test the sql output adapter.""" + headers = ['letters', 'number', 'optional', 'float', 'binary'] + + class FakeCursor(object): + def __init__(self): + self.data = [ + ('abc', 1, None, 10.0, b'\xAA'), + ('d', 456, '1', 0.5, b'\xAA\xBB') + ] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + (None, FIELD_TYPE.BLOB) + ] + + def __iter__(self): + return self + + def __next__(self): + if self.data: + return self.data.pop(0) + else: + raise StopIteration() + + def description(self): + return self.description + + # Test sql-update output format + assert list(mycli.change_table_format("sql-update")) == \ + [(None, None, None, 'Changed table format to sql-update')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + actual = "\n".join(output) + assert actual == dedent('''\ + UPDATE `DUAL` SET + `number` = 1 + , `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc'; + UPDATE `DUAL` SET + `number` = 456 + , `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd';''') + # Test sql-update-2 output format + assert list(mycli.change_table_format("sql-update-2")) == \ + [(None, None, None, 'Changed table format to sql-update-2')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + UPDATE `DUAL` SET + `optional` = NULL + , `float` = 10.0e0 + , `binary` = X'aa' + WHERE `letters` = 'abc' AND `number` = 1; + UPDATE `DUAL` SET + `optional` = '1' + , `float` = 0.5e0 + , `binary` = X'aabb' + WHERE `letters` = 'd' AND `number` = 456;''') + # Test sql-insert output format (without table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') + # Test sql-insert output format (with database + table name) + assert list(mycli.change_table_format("sql-insert")) == \ + [(None, None, None, 'Changed table format to sql-insert')] + mycli.formatter.query = "SELECT * FROM `database`.`table`" + output = mycli.format_output(None, FakeCursor(), headers) + assert "\n".join(output) == dedent('''\ + INSERT INTO `database`.`table` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, X'aa') + , ('d', 456, '1', 0.5e0, X'aabb') + ;''') diff --git a/test/utils.py b/test/utils.py new file mode 100644 index 0000000..66b4194 --- /dev/null +++ b/test/utils.py @@ -0,0 +1,94 @@ +import os +import time +import signal +import platform +import multiprocessing + +import pymysql +import pytest + +from mycli.main import special + +PASSWORD = os.getenv('PYTEST_PASSWORD') +USER = os.getenv('PYTEST_USER', 'root') +HOST = os.getenv('PYTEST_HOST', 'localhost') +PORT = int(os.getenv('PYTEST_PORT', 3306)) +CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') +SSH_USER = os.getenv('PYTEST_SSH_USER', None) +SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) +SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) + + +def db_connection(dbname=None): + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, + password=PASSWORD, charset=CHARSET, + local_infile=False) + conn.autocommit = True + return conn + + +try: + db_connection() + CAN_CONNECT_TO_DB = True +except: + CAN_CONNECT_TO_DB = False + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a mysql instance at localhost accessible by user 'root'") + + +def create_db(dbname): + with db_connection().cursor() as cur: + try: + cur.execute('''DROP DATABASE IF EXISTS _test_db''') + cur.execute('''CREATE DATABASE _test_db''') + except: + pass + + +def run(executor, sql, rows_as_list=True): + """Return string output for the sql to be run.""" + result = [] + + for title, rows, headers, status in executor.run(sql): + rows = list(rows) if (rows_as_list and rows) else rows + result.append({'title': title, 'rows': rows, 'headers': headers, + 'status': status}) + + return result + + +def set_expanded_output(is_expanded): + """Pass-through for the tests.""" + return special.set_expanded_output(is_expanded) + + +def is_expanded_output(): + """Pass-through for the tests.""" + return special.is_expanded_output() + + +def send_ctrl_c_to_pid(pid, wait_seconds): + """Sends a Ctrl-C like signal to the given `pid` after `wait_seconds` + seconds.""" + time.sleep(wait_seconds) + system_name = platform.system() + if system_name == "Windows": + os.kill(pid, signal.CTRL_C_EVENT) + else: + os.kill(pid, signal.SIGINT) + + +def send_ctrl_c(wait_seconds): + """Create a process that sends a Ctrl-C like signal to the current process + after `wait_seconds` seconds. + + Returns the `multiprocessing.Process` created. + + """ + ctrl_c_process = multiprocessing.Process( + target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) + ) + ctrl_c_process.start() + return ctrl_c_process @@ -0,0 +1,15 @@ +[tox] +envlist = py36, py37, py38 + +[testenv] +deps = pytest + mock + pexpect + behave + coverage +commands = python setup.py test +passenv = PYTEST_HOST + PYTEST_USER + PYTEST_PASSWORD + PYTEST_PORT + PYTEST_CHARSET |