diff options
68 files changed, 3745 insertions, 3358 deletions
diff --git a/.coveragerc b/.coveragerc index 8d3149f..57ebce1 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,2 @@ [run] -parallel = True source = mycli diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb34daa..ce359d8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,34 +4,21 @@ on: pull_request: paths-ignore: - '**.md' + - 'AUTHORS' jobs: - linux: + build: + runs-on: ubuntu-latest + strategy: matrix: - python-version: [ - '3.8', - '3.9', - '3.10', - '3.11', - '3.12', - ] - include: - - python-version: '3.8' - os: ubuntu-20.04 # MySQL 8.0.36 - - python-version: '3.9' - os: ubuntu-20.04 # MySQL 8.0.36 - - python-version: '3.10' - os: ubuntu-22.04 # MySQL 8.0.36 - - python-version: '3.11' - os: ubuntu-22.04 # MySQL 8.0.36 - - python-version: '3.12' - os: ubuntu-22.04 # MySQL 8.0.36 + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] - runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v1 + with: + version: "latest" - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -43,10 +30,7 @@ jobs: sudo /etc/init.d/mysql start - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-dev.txt - pip install --no-cache-dir -e . + run: uv sync --all-extras -p ${{ matrix.python-version }} - name: Wait for MySQL connection run: | @@ -59,13 +43,7 @@ jobs: PYTEST_PASSWORD: root PYTEST_HOST: 127.0.0.1 run: | - ./setup.py test --pytest-args="--cov-report= --cov=mycli" + uv run tox -e py${{ matrix.python-version }} - - name: Lint - run: | - ./setup.py lint --branch=HEAD - - - name: Coverage - run: | - coverage combine - coverage report + - name: Run Style Checks + run: uv run tox -e style diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..368091d --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,94 @@ +name: Publish Python Package + +on: + release: + types: [created] + +permissions: + contents: read + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Start MySQL + run: | + sudo /etc/init.d/mysql start + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Wait for MySQL connection + run: | + while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do + sleep 5 + done + + - name: Pytest / behave + env: + PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 + run: | + uv run tox -e py${{ matrix.python-version }} + + - name: Run Style Checks + run: uv run tox -e style + + build: + runs-on: ubuntu-latest + needs: [test] + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Install dependencies + run: uv sync --all-extras -p 3.13 + + - name: Build + run: uv build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-packages + path: dist/ + + publish: + name: Publish to PyPI + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + needs: [build] + environment: release + permissions: + id-token: write + steps: + - name: Download distribution packages + uses: actions/download-artifact@v4 + with: + name: python-packages + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/changelog.md b/changelog.md index 6cab6f5..a418a38 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,34 @@ +1.29.2 (2024/12/11) +=================== + +Internal +-------- + +* Exclude tests from the python package. + +1.29.1 (2024/12/11) +=================== + +Internal +-------- + +* Fix the GH actions to publish a new version. + +1.29.0 (NEVER RELEASED) +======================= + +Bug Fixes +---------- + +* fix SSL through SSH jump host by using a true python socket for a tunnel +* Fix mycli crash when connecting to Vitess + +Internal +--------- + +* Modernize to use PEP-621. Use `uv` instead of `pip` in GH actions. +* Remove Python 3.8 and add Python 3.13 in test matrix. + 1.28.0 (2024/11/10) ====================== diff --git a/mycli/AUTHORS b/mycli/AUTHORS index d5a9ce0..b834452 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -98,6 +98,7 @@ Contributors: * Houston Wong * Mohamed Rezk * Ryosuke Kazami + * Cornel Cruceru Created by: diff --git a/mycli/__init__.py b/mycli/__init__.py index b3f408d..bd8e3c3 100644 --- a/mycli/__init__.py +++ b/mycli/__init__.py @@ -1 +1,3 @@ -__version__ = "1.28.0" +import importlib.metadata + +__version__ = importlib.metadata.version("mycli") diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index 81353b6..d9fbf83 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -13,6 +13,7 @@ def cli_is_multiline(mycli): return False else: return not _multiline_exception(doc.text) + return cond @@ -23,33 +24,32 @@ def _multiline_exception(text): # Multi-statement favorite query is a special case. Because there will # be a semicolon separating statements, we can't consider semicolon an # EOL. Let's consider an empty line an EOL instead. - if text.startswith('\\fs'): - return orig.endswith('\n') + if text.startswith("\\fs"): + return orig.endswith("\n") return ( # Special Command - text.startswith('\\') or - + text.startswith("\\") + or # Delimiter declaration - text.lower().startswith('delimiter') or - + text.lower().startswith("delimiter") + or # Ended with the current delimiter (usually a semi-column) - text.endswith(special.get_current_delimiter()) or - - text.endswith('\\g') or - text.endswith('\\G') or - text.endswith(r'\e') or - text.endswith(r'\clip') or - + text.endswith(special.get_current_delimiter()) + or text.endswith("\\g") + or text.endswith("\\G") + or text.endswith(r"\e") + or text.endswith(r"\clip") + or # Exit doesn't need semi-column` - (text == 'exit') or - + (text == "exit") + or # Quit doesn't need semi-column - (text == 'quit') or - + (text == "quit") + or # To all teh vim fans out there - (text == ':q') or - + (text == ":q") + or # just a plain enter without any text - (text == '') + (text == "") ) diff --git a/mycli/clistyle.py b/mycli/clistyle.py index b0ac992..d7bc3fe 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -11,70 +11,69 @@ logger = logging.getLogger(__name__) # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). TOKEN_TO_PROMPT_STYLE = { - Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current', - Token.Menu.Completions.Completion: 'completion-menu.completion', - Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current', - Token.Menu.Completions.Meta: 'completion-menu.meta.completion', - Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta', - Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess - Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess - Token.SelectedText: 'selected', - Token.SearchMatch: 'search', - Token.SearchMatch.Current: 'search.current', - Token.Toolbar: 'bottom-toolbar', - Token.Toolbar.Off: 'bottom-toolbar.off', - Token.Toolbar.On: 'bottom-toolbar.on', - Token.Toolbar.Search: 'search-toolbar', - Token.Toolbar.Search.Text: 'search-toolbar.text', - Token.Toolbar.System: 'system-toolbar', - Token.Toolbar.Arg: 'arg-toolbar', - Token.Toolbar.Arg.Text: 'arg-toolbar.text', - Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid', - Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed', - Token.Output.Header: 'output.header', - Token.Output.OddRow: 'output.odd-row', - Token.Output.EvenRow: 'output.even-row', - Token.Output.Null: 'output.null', - Token.Prompt: 'prompt', - Token.Continuation: 'continuation', + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Output.Null: "output.null", + Token.Prompt: "prompt", + Token.Continuation: "continuation", } # reverse dict for cli_helpers, because they still expect Pygments tokens. -PROMPT_STYLE_TO_TOKEN = { - v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() -} +PROMPT_STYLE_TO_TOKEN = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} # all tokens that the Pygments MySQL lexer can produce OVERRIDE_STYLE_TO_TOKEN = { - 'sql.comment': Token.Comment, - 'sql.comment.multi-line': Token.Comment.Multiline, - 'sql.comment.single-line': Token.Comment.Single, - 'sql.comment.optimizer-hint': Token.Comment.Special, - 'sql.escape': Token.Error, - 'sql.keyword': Token.Keyword, - 'sql.datatype': Token.Keyword.Type, - 'sql.literal': Token.Literal, - 'sql.literal.date': Token.Literal.Date, - 'sql.symbol': Token.Name, - 'sql.quoted-schema-object': Token.Name.Quoted, - 'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape, - 'sql.constant': Token.Name.Constant, - 'sql.function': Token.Name.Function, - 'sql.variable': Token.Name.Variable, - 'sql.number': Token.Number, - 'sql.number.binary': Token.Number.Bin, - 'sql.number.float': Token.Number.Float, - 'sql.number.hex': Token.Number.Hex, - 'sql.number.integer': Token.Number.Integer, - 'sql.operator': Token.Operator, - 'sql.punctuation': Token.Punctuation, - 'sql.string': Token.String, - 'sql.string.double-quouted': Token.String.Double, - 'sql.string.escape': Token.String.Escape, - 'sql.string.single-quoted': Token.String.Single, - 'sql.whitespace': Token.Text, + "sql.comment": Token.Comment, + "sql.comment.multi-line": Token.Comment.Multiline, + "sql.comment.single-line": Token.Comment.Single, + "sql.comment.optimizer-hint": Token.Comment.Special, + "sql.escape": Token.Error, + "sql.keyword": Token.Keyword, + "sql.datatype": Token.Keyword.Type, + "sql.literal": Token.Literal, + "sql.literal.date": Token.Literal.Date, + "sql.symbol": Token.Name, + "sql.quoted-schema-object": Token.Name.Quoted, + "sql.quoted-schema-object.escape": Token.Name.Quoted.Escape, + "sql.constant": Token.Name.Constant, + "sql.function": Token.Name.Function, + "sql.variable": Token.Name.Variable, + "sql.number": Token.Number, + "sql.number.binary": Token.Number.Bin, + "sql.number.float": Token.Number.Float, + "sql.number.hex": Token.Number.Hex, + "sql.number.integer": Token.Number.Integer, + "sql.operator": Token.Operator, + "sql.punctuation": Token.Punctuation, + "sql.string": Token.String, + "sql.string.double-quouted": Token.String.Double, + "sql.string.escape": Token.String.Escape, + "sql.string.single-quoted": Token.String.Single, + "sql.whitespace": Token.Text, } + def parse_pygments_style(token_name, style_object, style_dict): """Parse token type and style string. @@ -87,7 +86,7 @@ def parse_pygments_style(token_name, style_object, style_dict): try: other_token_type = string_to_tokentype(style_dict[token_name]) return token_type, style_object.styles[other_token_type] - except AttributeError as err: + except AttributeError: return token_type, style_dict[token_name] @@ -95,45 +94,39 @@ def style_factory(name, cli_style): try: style = pygments.styles.get_style_by_name(name) except ClassNotFound: - style = pygments.styles.get_style_by_name('native') + style = pygments.styles.get_style_by_name("native") prompt_styles = [] # prompt-toolkit used pygments tokens for styling before, switched to style # names in 2.0. Convert old token types to new style names, for backwards compatibility. for token in cli_style: - if token.startswith('Token.'): + if token.startswith("Token."): # treat as pygments token (1.0) - token_type, style_value = parse_pygments_style( - token, style, cli_style) + token_type, style_value = parse_pygments_style(token, style, cli_style) if token_type in TOKEN_TO_PROMPT_STYLE: prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] prompt_styles.append((prompt_style, style_value)) else: # we don't want to support tokens anymore - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) else: # treat as prompt style name (2.0). See default style names here: # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py prompt_styles.append((token, cli_style[token])) - override_style = Style([('bottom-toolbar', 'noreverse')]) - return merge_styles([ - style_from_pygments_cls(style), - override_style, - Style(prompt_styles) - ]) + override_style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) def style_factory_output(name, cli_style): try: style = pygments.styles.get_style_by_name(name).styles except ClassNotFound: - style = pygments.styles.get_style_by_name('native').styles + style = pygments.styles.get_style_by_name("native").styles for token in cli_style: - if token.startswith('Token.'): - token_type, style_value = parse_pygments_style( - token, style, cli_style) + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) style.update({token_type: style_value}) elif token in PROMPT_STYLE_TO_TOKEN: token_type = PROMPT_STYLE_TO_TOKEN[token] @@ -143,7 +136,7 @@ def style_factory_output(name, cli_style): style.update({token_type: cli_style[token]}) else: # TODO: cli helpers will have to switch to ptk.Style - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) class OutputStyle(PygmentsStyle): default_style = "" diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py index 52b6ee4..54e2eed 100644 --- a/mycli/clitoolbar.py +++ b/mycli/clitoolbar.py @@ -6,52 +6,47 @@ from .packages import special def create_toolbar_tokens_func(mycli, show_fish_help): """Return a function that generates the toolbar tokens.""" + def get_toolbar_tokens(): - result = [('class:bottom-toolbar', ' ')] + result = [("class:bottom-toolbar", " ")] if mycli.multi_line: delimiter = special.get_current_delimiter() result.append( ( - 'class:bottom-toolbar', - ' ({} [{}] will end the line) '.format( - 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter) - )) + "class:bottom-toolbar", + " ({} [{}] will end the line) ".format("Semi-colon" if delimiter == ";" else "Delimiter", delimiter), + ) + ) if mycli.multi_line: - result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) + result.append(("class:bottom-toolbar.on", "[F3] Multiline: ON ")) else: - result.append(('class:bottom-toolbar.off', - '[F3] Multiline: OFF ')) + result.append(("class:bottom-toolbar.off", "[F3] Multiline: OFF ")) if mycli.prompt_app.editing_mode == EditingMode.VI: - result.append(( - 'class:bottom-toolbar.on', - 'Vi-mode ({})'.format(_get_vi_mode()) - )) + result.append(("class:bottom-toolbar.on", "Vi-mode ({})".format(_get_vi_mode()))) if mycli.toolbar_error_message: - result.append( - ('class:bottom-toolbar', ' ' + mycli.toolbar_error_message)) + result.append(("class:bottom-toolbar", " " + mycli.toolbar_error_message)) mycli.toolbar_error_message = None if show_fish_help(): - result.append( - ('class:bottom-toolbar', ' Right-arrow to complete suggestion')) + result.append(("class:bottom-toolbar", " Right-arrow to complete suggestion")) if mycli.completion_refresher.is_refreshing(): - result.append( - ('class:bottom-toolbar', ' Refreshing completions...')) + result.append(("class:bottom-toolbar", " Refreshing completions...")) return result + return get_toolbar_tokens def _get_vi_mode(): """Get the current vi mode for display.""" return { - InputMode.INSERT: 'I', - InputMode.NAVIGATION: 'N', - InputMode.REPLACE: 'R', - InputMode.REPLACE_SINGLE: 'R', - InputMode.INSERT_MULTIPLE: 'M', + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.REPLACE_SINGLE: "R", + InputMode.INSERT_MULTIPLE: "M", }[get_app().vi_state.input_mode] diff --git a/mycli/compat.py b/mycli/compat.py index 2ebfe07..6d06965 100644 --- a/mycli/compat.py +++ b/mycli/compat.py @@ -3,4 +3,4 @@ import sys -WIN = sys.platform in ('win32', 'cygwin') +WIN = sys.platform in ("win32", "cygwin") diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index 5d5f40f..662dd33 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -5,8 +5,8 @@ from collections import OrderedDict from .sqlcompleter import SQLCompleter from .sqlexecute import SQLExecute, ServerSpecies -class CompletionRefresher(object): +class CompletionRefresher(object): refreshers = OrderedDict() def __init__(self): @@ -30,16 +30,14 @@ class CompletionRefresher(object): if self.is_refreshing(): self._restart_refresh.set() - return [(None, None, None, 'Auto-completion refresh restarted.')] + return [(None, None, None, "Auto-completion refresh restarted.")] else: self._completer_thread = threading.Thread( - target=self._bg_refresh, - args=(executor, callbacks, completer_options), - name='completion_refresh') + target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh" + ) self._completer_thread.daemon = True self._completer_thread.start() - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + return [(None, None, None, "Auto-completion refresh started in the background.")] def is_refreshing(self): return self._completer_thread and self._completer_thread.is_alive() @@ -49,10 +47,22 @@ class CompletionRefresher(object): # Create a new pgexecute method to populate the completions. e = sqlexecute - executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, - e.socket, e.charset, e.local_infile, e.ssl, - e.ssh_user, e.ssh_host, e.ssh_port, - e.ssh_password, e.ssh_key_filename) + executor = SQLExecute( + e.dbname, + e.user, + e.password, + e.host, + e.port, + e.socket, + e.charset, + e.local_infile, + e.ssl, + e.ssh_user, + e.ssh_host, + e.ssh_port, + e.ssh_password, + e.ssh_key_filename, + ) # If callbacks is a single function then push it into a list. if callable(callbacks): @@ -76,55 +86,68 @@ class CompletionRefresher(object): for callback in callbacks: callback(completer) + def refresher(name, refreshers=CompletionRefresher.refreshers): """Decorator to add the decorated function to the dictionary of refreshers. Any function decorated with a @refresher will be executed as part of the completion refresh routine.""" + def wrapper(wrapped): refreshers[name] = wrapped return wrapped + return wrapper -@refresher('databases') + +@refresher("databases") def refresh_databases(completer, executor): completer.extend_database_names(executor.databases()) -@refresher('schemata') + +@refresher("schemata") def refresh_schemata(completer, executor): # schemata - In MySQL Schema is the same as database. But for mycli # schemata will be the name of the current database. completer.extend_schemata(executor.dbname) completer.set_dbname(executor.dbname) -@refresher('tables') + +@refresher("tables") def refresh_tables(completer, executor): - completer.extend_relations(executor.tables(), kind='tables') - completer.extend_columns(executor.table_columns(), kind='tables') + table_columns_dbresult = list(executor.table_columns()) + completer.extend_relations(table_columns_dbresult, kind="tables") + completer.extend_columns(table_columns_dbresult, kind="tables") + -@refresher('users') +@refresher("users") def refresh_users(completer, executor): completer.extend_users(executor.users()) + # @refresher('views') # def refresh_views(completer, executor): # completer.extend_relations(executor.views(), kind='views') # completer.extend_columns(executor.view_columns(), kind='views') -@refresher('functions') + +@refresher("functions") def refresh_functions(completer, executor): completer.extend_functions(executor.functions()) if executor.server_info.species == ServerSpecies.TiDB: completer.extend_functions(completer.tidb_functions, builtin=True) -@refresher('special_commands') + +@refresher("special_commands") def refresh_special(completer, executor): completer.extend_special_commands(COMMANDS.keys()) -@refresher('show_commands') + +@refresher("show_commands") def refresh_show_commands(completer, executor): completer.extend_show_items(executor.show_candidates()) -@refresher('keywords') + +@refresher("keywords") def refresh_keywords(completer, executor): if executor.server_info.species == ServerSpecies.TiDB: completer.extend_keywords(completer.tidb_keywords, replace=True) diff --git a/mycli/config.py b/mycli/config.py index 5d71109..4ce5eff 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) def log(logger, level, message): """Logs message to stderr if logging isn't initialized.""" - if logger.parent.name != 'root': + if logger.parent.name != "root": logger.log(level, message) else: print(message, file=sys.stderr) @@ -49,16 +49,13 @@ def read_config_file(f, list_values=True): f = os.path.expanduser(f) try: - config = ConfigObj(f, interpolation=False, encoding='utf8', - list_values=list_values) + config = ConfigObj(f, interpolation=False, encoding="utf8", list_values=list_values) except ConfigObjError as e: - log(logger, logging.WARNING, "Unable to parse line {0} of config file " - "'{1}'.".format(e.line_number, f)) + log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f)) log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: - log(logger, logging.WARNING, "You don't have permission to read " - "config file '{0}'.".format(e.filename)) + log(logger, logging.WARNING, "You don't have permission to read " "config file '{0}'.".format(e.filename)) return None return config @@ -80,15 +77,12 @@ def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: try: with open(config_file) as f: - include_directives = filter( - lambda s: s.startswith('!includedir'), - f - ) + include_directives = filter(lambda s: s.startswith("!includedir"), f) dirs = map(lambda s: s.strip().split()[-1], include_directives) dirs = filter(os.path.isdir, dirs) for dir in dirs: for filename in os.listdir(dir): - if filename.endswith('.cnf'): + if filename.endswith(".cnf"): included_configs.append(os.path.join(dir, filename)) except (PermissionError, UnicodeDecodeError): pass @@ -117,29 +111,31 @@ def read_config_files(files, list_values=True): def create_default_config(list_values=True): import mycli - default_config_file = resources.open_text(mycli, 'myclirc') + + default_config_file = resources.open_text(mycli, "myclirc") return read_config_file(default_config_file, list_values=list_values) def write_default_config(destination, overwrite=False): import mycli - default_config = resources.read_text(mycli, 'myclirc') + + default_config = resources.read_text(mycli, "myclirc") destination = os.path.expanduser(destination) if not overwrite and exists(destination): return - with open(destination, 'w') as f: + with open(destination, "w") as f: f.write(default_config) def get_mylogin_cnf_path(): """Return the path to the login path file or None if it doesn't exist.""" - mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE') + mylogin_cnf_path = os.getenv("MYSQL_TEST_LOGIN_FILE") if mylogin_cnf_path is None: - app_data = os.getenv('APPDATA') - default_dir = os.path.join(app_data, 'MySQL') if app_data else '~' - mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf') + app_data = os.getenv("APPDATA") + default_dir = os.path.join(app_data, "MySQL") if app_data else "~" + mylogin_cnf_path = os.path.join(default_dir, ".mylogin.cnf") mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path) @@ -159,14 +155,14 @@ def open_mylogin_cnf(name): """ try: - with open(name, 'rb') as f: + with open(name, "rb") as f: plaintext = read_and_decrypt_mylogin_cnf(f) except (OSError, IOError, ValueError): - logger.error('Unable to open login path file.') + logger.error("Unable to open login path file.") return None if not isinstance(plaintext, BytesIO): - logger.error('Unable to read login path file.') + logger.error("Unable to read login path file.") return None return TextIOWrapper(plaintext) @@ -181,6 +177,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]): https://github.com/isotopp/mysql-config-coder """ + def realkey(key): """Create the AES key from the login key.""" rkey = bytearray(16) @@ -194,10 +191,7 @@ def encrypt_mylogin_cnf(plaintext: IO[str]): pad_len = buf_len - text_len pad_chr = bytes(chr(pad_len), "utf8") plaintext = plaintext.encode() + pad_chr * pad_len - encrypted_text = b''.join( - [aes.encrypt(plaintext[i: i + 16]) - for i in range(0, len(plaintext), 16)] - ) + encrypted_text = b"".join([aes.encrypt(plaintext[i : i + 16]) for i in range(0, len(plaintext), 16)]) return encrypted_text LOGIN_KEY_LENGTH = 20 @@ -248,7 +242,7 @@ def read_and_decrypt_mylogin_cnf(f): buf = f.read(4) if not buf or len(buf) != 4: - logger.error('Login path file is blank or incomplete.') + logger.error("Login path file is blank or incomplete.") return None # Read the login key. @@ -258,12 +252,12 @@ def read_and_decrypt_mylogin_cnf(f): rkey = [0] * 16 for i in range(LOGIN_KEY_LEN): try: - rkey[i % 16] ^= ord(key[i:i+1]) + rkey[i % 16] ^= ord(key[i : i + 1]) except TypeError: # ord() was unable to get the value of the byte. - logger.error('Unable to generate login path AES key.') + logger.error("Unable to generate login path AES key.") return None - rkey = struct.pack('16B', *rkey) + rkey = struct.pack("16B", *rkey) # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() @@ -274,20 +268,17 @@ def read_and_decrypt_mylogin_cnf(f): len_buf = f.read(MAX_CIPHER_STORE_LEN) if len(len_buf) < MAX_CIPHER_STORE_LEN: break - cipher_len, = struct.unpack("<i", len_buf) + (cipher_len,) = struct.unpack("<i", len_buf) # Read cipher_len bytes from the file and decrypt. cipher = f.read(cipher_len) - plain = _remove_pad( - b''.join([aes.decrypt(cipher[i: i + 16]) - for i in range(0, cipher_len, 16)]) - ) + plain = _remove_pad(b"".join([aes.decrypt(cipher[i : i + 16]) for i in range(0, cipher_len, 16)])) if plain is False: continue plaintext.write(plain) if plaintext.tell() == 0: - logger.error('No data successfully decrypted from login path file.') + logger.error("No data successfully decrypted from login path file.") return None plaintext.seek(0) @@ -299,17 +290,17 @@ def str_to_bool(s): if isinstance(s, bool): return s elif not isinstance(s, basestring): - raise TypeError('argument must be a string') + raise TypeError("argument must be a string") - true_values = ('true', 'on', '1') - false_values = ('false', 'off', '0') + true_values = ("true", "on", "1") + false_values = ("false", "off", "0") if s.lower() in true_values: return True elif s.lower() in false_values: return False else: - raise ValueError('not a recognized boolean value: {0}'.format(s)) + raise ValueError("not a recognized boolean value: {0}".format(s)) def strip_matching_quotes(s): @@ -319,8 +310,7 @@ def strip_matching_quotes(s): values. """ - if (isinstance(s, basestring) and len(s) >= 2 and - s[0] == s[-1] and s[0] in ('"', "'")): + if isinstance(s, basestring) and len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"): s = s[1:-1] return s @@ -332,13 +322,13 @@ def _remove_pad(line): pad_length = ord(line[-1:]) except TypeError: # ord() was unable to get the value of the byte. - logger.warning('Unable to remove pad.') + logger.warning("Unable to remove pad.") return False if pad_length > len(line) or len(set(line[-pad_length:])) != 1: # Pad length should be less than or equal to the length of the # plaintext. The pad should have a single unique byte. - logger.warning('Invalid pad found in login path file.') + logger.warning("Invalid pad found in login path file.") return False return line[:-pad_length] diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py index b084849..e03f728 100644 --- a/mycli/key_bindings.py +++ b/mycli/key_bindings.py @@ -12,22 +12,22 @@ def mycli_bindings(mycli): """Custom key bindings for mycli.""" kb = KeyBindings() - @kb.add('f2') + @kb.add("f2") def _(event): """Enable/Disable SmartCompletion Mode.""" - _logger.debug('Detected F2 key.') + _logger.debug("Detected F2 key.") mycli.completer.smart_completion = not mycli.completer.smart_completion - @kb.add('f3') + @kb.add("f3") def _(event): """Enable/Disable Multiline Mode.""" - _logger.debug('Detected F3 key.') + _logger.debug("Detected F3 key.") mycli.multi_line = not mycli.multi_line - @kb.add('f4') + @kb.add("f4") def _(event): """Toggle between Vi and Emacs mode.""" - _logger.debug('Detected F4 key.') + _logger.debug("Detected F4 key.") if mycli.key_bindings == "vi": event.app.editing_mode = EditingMode.EMACS mycli.key_bindings = "emacs" @@ -35,17 +35,17 @@ def mycli_bindings(mycli): event.app.editing_mode = EditingMode.VI mycli.key_bindings = "vi" - @kb.add('tab') + @kb.add("tab") def _(event): """Force autocompletion at cursor.""" - _logger.debug('Detected <Tab> key.') + _logger.debug("Detected <Tab> key.") b = event.app.current_buffer if b.complete_state: b.complete_next() else: b.start_completion(select_first=True) - @kb.add('c-space') + @kb.add("c-space") def _(event): """ Initialize autocompletion at cursor. @@ -55,7 +55,7 @@ def mycli_bindings(mycli): If the menu is showing, select the next completion. """ - _logger.debug('Detected <C-Space> key.') + _logger.debug("Detected <C-Space> key.") b = event.app.current_buffer if b.complete_state: @@ -63,14 +63,14 @@ def mycli_bindings(mycli): else: b.start_completion(select_first=False) - @kb.add('c-x', 'p', filter=emacs_mode) + @kb.add("c-x", "p", filter=emacs_mode) def _(event): """ Prettify and indent current statement, usually into multiple lines. Only accepts buffers containing single SQL statements. """ - _logger.debug('Detected <C-x p>/> key.') + _logger.debug("Detected <C-x p>/> key.") b = event.app.current_buffer cursorpos_relative = b.cursor_position / max(1, len(b.text)) @@ -78,19 +78,18 @@ def mycli_bindings(mycli): if len(pretty_text) > 0: b.text = pretty_text cursorpos_abs = int(round(cursorpos_relative * len(b.text))) - while 0 < cursorpos_abs < len(b.text) \ - and b.text[cursorpos_abs] in (' ', '\n'): + while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"): cursorpos_abs -= 1 b.cursor_position = min(cursorpos_abs, len(b.text)) - @kb.add('c-x', 'u', filter=emacs_mode) + @kb.add("c-x", "u", filter=emacs_mode) def _(event): """ Unprettify and dedent current statement, usually into one line. Only accepts buffers containing single SQL statements. """ - _logger.debug('Detected <C-x u>/< key.') + _logger.debug("Detected <C-x u>/< key.") b = event.app.current_buffer cursorpos_relative = b.cursor_position / max(1, len(b.text)) @@ -98,18 +97,17 @@ def mycli_bindings(mycli): if len(unpretty_text) > 0: b.text = unpretty_text cursorpos_abs = int(round(cursorpos_relative * len(b.text))) - while 0 < cursorpos_abs < len(b.text) \ - and b.text[cursorpos_abs] in (' ', '\n'): + while 0 < cursorpos_abs < len(b.text) and b.text[cursorpos_abs] in (" ", "\n"): cursorpos_abs -= 1 b.cursor_position = min(cursorpos_abs, len(b.text)) - @kb.add('c-r', filter=emacs_mode) + @kb.add("c-r", filter=emacs_mode) def _(event): """Search history using fzf or default reverse incremental search.""" - _logger.debug('Detected <C-r> key.') + _logger.debug("Detected <C-r> key.") search_history(event) - @kb.add('enter', filter=completion_is_selected) + @kb.add("enter", filter=completion_is_selected) def _(event): """Makes the enter key work as the tab key only when showing the menu. @@ -118,20 +116,20 @@ def mycli_bindings(mycli): (accept current selection). """ - _logger.debug('Detected enter key.') + _logger.debug("Detected enter key.") event.current_buffer.complete_state = None b = event.app.current_buffer b.complete_state = None - @kb.add('escape', 'enter') + @kb.add("escape", "enter") def _(event): """Introduces a line break in multi-line mode, or dispatches the command in single-line mode.""" - _logger.debug('Detected alt-enter key.') + _logger.debug("Detected alt-enter key.") if mycli.multi_line: event.app.current_buffer.validate_and_handle() else: - event.app.current_buffer.insert_text('\n') + event.app.current_buffer.insert_text("\n") return kb diff --git a/mycli/lexer.py b/mycli/lexer.py index 4b14d72..3350d11 100644 --- a/mycli/lexer.py +++ b/mycli/lexer.py @@ -7,6 +7,5 @@ class MyCliLexer(MySqlLexer): """Extends MySQL lexer to add keywords.""" tokens = { - 'root': [(r'\brepair\b', Keyword), - (r'\boffset\b', Keyword), inherit], + "root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit], } diff --git a/mycli/magic.py b/mycli/magic.py index e1611bc..c237ff1 100644 --- a/mycli/magic.py +++ b/mycli/magic.py @@ -5,19 +5,20 @@ import logging _logger = logging.getLogger(__name__) -def load_ipython_extension(ipython): +def load_ipython_extension(ipython): # This is called via the ipython command '%load_ext mycli.magic'. # First, load the sql magic if it isn't already loaded. - if not ipython.find_line_magic('sql'): - ipython.run_line_magic('load_ext', 'sql') + if not ipython.find_line_magic("sql"): + ipython.run_line_magic("load_ext", "sql") # Register our own magic. - ipython.register_magic_function(mycli_line_magic, 'line', 'mycli') + ipython.register_magic_function(mycli_line_magic, "line", "mycli") + def mycli_line_magic(line): - _logger.debug('mycli magic called: %r', line) + _logger.debug("mycli magic called: %r", line) parsed = sql.parse.parse(line, {}) # "get" was renamed to "set" in ipython-sql: # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43 @@ -32,17 +33,17 @@ def mycli_line_magic(line): try: # A corresponding mycli object already exists mycli = conn._mycli - _logger.debug('Reusing existing mycli') + _logger.debug("Reusing existing mycli") except AttributeError: mycli = MyCli() u = conn.session.engine.url - _logger.debug('New mycli: %r', str(u)) + _logger.debug("New mycli: %r", str(u)) mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None) conn._mycli = mycli # For convenience, print the connection alias - print('Connected: {}'.format(conn.name)) + print("Connected: {}".format(conn.name)) try: mycli.run_cli() @@ -54,9 +55,9 @@ def mycli_line_magic(line): q = mycli.query_history[-1] if q.mutating: - _logger.debug('Mutating query detected -- ignoring') + _logger.debug("Mutating query detected -- ignoring") return if q.successful: - ipython = get_ipython() - return ipython.run_cell_magic('sql', line, q.query) + ipython = get_ipython() # noqa: F821 + return ipython.run_cell_magic("sql", line, q.query) diff --git a/mycli/main.py b/mycli/main.py index 4c194ce..e480fea 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -8,8 +8,8 @@ import logging import threading import re import stat -import fileinput from collections import namedtuple + try: from pwd import getpwuid except ImportError: @@ -33,8 +33,7 @@ from prompt_toolkit.shortcuts import PromptSession, CompleteStyle from prompt_toolkit.document import Document from prompt_toolkit.filters import HasFocus, IsDone from prompt_toolkit.formatted_text import ANSI -from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, - ConditionalProcessor) +from prompt_toolkit.layout.processors import HighlightMatchingBracketProcessor, ConditionalProcessor from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.auto_suggest import AutoSuggestFromHistory @@ -50,9 +49,7 @@ from .clistyle import style_factory, style_factory_output from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher -from .config import (write_default_config, get_mylogin_cnf_path, - open_mylogin_cnf, read_config_files, str_to_bool, - strip_matching_quotes) +from .config import write_default_config, get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes from .key_bindings import mycli_bindings from .lexer import MyCliLexer from . import __version__ @@ -82,27 +79,23 @@ except ImportError: from mycli.packages.paramiko_stub import paramiko # Query tuples are used for maintaining history -Query = namedtuple('Query', ['query', 'successful', 'mutating']) +Query = namedtuple("Query", ["query", "successful", "mutating"]) -SUPPORT_INFO = ( - 'Home: http://mycli.net\n' - 'Bug tracker: https://github.com/dbcli/mycli/issues' -) +SUPPORT_INFO = "Home: http://mycli.net\n" "Bug tracker: https://github.com/dbcli/mycli/issues" class MyCli(object): - - default_prompt = '\\t \\u@\\h:\\d> ' - default_prompt_splitln = '\\u@\\h\\n(\\t):\\d>' + default_prompt = "\\t \\u@\\h:\\d> " + default_prompt_splitln = "\\u@\\h\\n(\\t):\\d>" max_len_prompt = 45 defaults_suffix = None # In order of being loaded. Files lower in list override earlier ones. cnf_files = [ - '/etc/my.cnf', - '/etc/mysql/my.cnf', - '/usr/local/etc/my.cnf', - os.path.expanduser('~/.my.cnf'), + "/etc/my.cnf", + "/etc/mysql/my.cnf", + "/usr/local/etc/my.cnf", + os.path.expanduser("~/.my.cnf"), ] # check XDG_CONFIG_HOME exists and not an empty string @@ -110,17 +103,22 @@ class MyCli(object): xdg_config_home = os.environ.get("XDG_CONFIG_HOME") else: xdg_config_home = "~/.config" - system_config_files = [ - '/etc/myclirc', - os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") - ] + system_config_files = ["/etc/myclirc", os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")] pwd_config_file = os.path.join(os.getcwd(), ".myclirc") - def __init__(self, sqlexecute=None, prompt=None, - logfile=None, defaults_suffix=None, defaults_file=None, - login_path=None, auto_vertical_output=False, warn=None, - myclirc="~/.myclirc"): + def __init__( + self, + sqlexecute=None, + prompt=None, + logfile=None, + defaults_suffix=None, + defaults_file=None, + login_path=None, + auto_vertical_output=False, + warn=None, + myclirc="~/.myclirc", + ): self.sqlexecute = sqlexecute self.logfile = logfile self.defaults_suffix = defaults_suffix @@ -135,48 +133,41 @@ class MyCli(object): self.cnf_files = [defaults_file] # Load config. - config_files = (self.system_config_files + - [myclirc] + [self.pwd_config_file]) + config_files = self.system_config_files + [myclirc] + [self.pwd_config_file] c = self.config = read_config_files(config_files) - self.multi_line = c['main'].as_bool('multi_line') - self.key_bindings = c['main']['key_bindings'] - special.set_timing_enabled(c['main'].as_bool('timing')) - self.beep_after_seconds = float(c['main']['beep_after_seconds'] or 0) + self.multi_line = c["main"].as_bool("multi_line") + self.key_bindings = c["main"]["key_bindings"] + special.set_timing_enabled(c["main"].as_bool("timing")) + self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0) FavoriteQueries.instance = FavoriteQueries.from_config(self.config) self.dsn_alias = None - self.formatter = TabularOutputFormatter( - format_name=c['main']['table_format']) + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) sql_format.register_new_formatter(self.formatter) self.formatter.mycli = self - self.syntax_style = c['main']['syntax_style'] - self.less_chatty = c['main'].as_bool('less_chatty') - self.cli_style = c['colors'] - self.output_style = style_factory_output( - self.syntax_style, - self.cli_style - ) - self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') - c_dest_warning = c['main'].as_bool('destructive_warning') + self.syntax_style = c["main"]["syntax_style"] + self.less_chatty = c["main"].as_bool("less_chatty") + self.cli_style = c["colors"] + self.output_style = style_factory_output(self.syntax_style, self.cli_style) + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn - self.login_path_as_host = c['main'].as_bool('login_path_as_host') + self.login_path_as_host = c["main"].as_bool("login_path_as_host") # read from cli argument or user config file - self.auto_vertical_output = auto_vertical_output or \ - c['main'].as_bool('auto_vertical_output') + self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") # Write user config if system config wasn't the last config loaded. if c.filename not in self.system_config_files and not os.path.exists(myclirc): write_default_config(myclirc) # audit log - if self.logfile is None and 'audit_log' in c['main']: + if self.logfile is None and "audit_log" in c["main"]: try: - self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') - except (IOError, OSError) as e: - self.echo('Error: Unable to open the audit log file. Your queries will not be logged.', - err=True, fg='red') + self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a") + except (IOError, OSError): + self.echo("Error: Unable to open the audit log file. Your queries will not be logged.", err=True, fg="red") self.logfile = False self.completion_refresher = CompletionRefresher() @@ -184,20 +175,18 @@ class MyCli(object): self.logger = logging.getLogger(__name__) self.initialize_logging() - prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] - self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ - self.default_prompt - self.multiline_continuation_char = c['main']['prompt_continuation'] - keyword_casing = c['main'].get('keyword_casing', 'auto') + prompt_cnf = self.read_my_cnf_files(self.cnf_files, ["prompt"])["prompt"] + self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt + self.multiline_continuation_char = c["main"]["prompt_continuation"] + keyword_casing = c["main"].get("keyword_casing", "auto") self.query_history = [] # Initialize completer. - self.smart_completion = c['main'].as_bool('smart_completion') + self.smart_completion = c["main"].as_bool("smart_completion") self.completer = SQLCompleter( - self.smart_completion, - supported_formats=self.formatter.supported_formats, - keyword_casing=keyword_casing) + self.smart_completion, supported_formats=self.formatter.supported_formats, keyword_casing=keyword_casing + ) self._completer_lock = threading.Lock() # Register custom special commands. @@ -212,58 +201,61 @@ class MyCli(object): self.cnf_files.append(mylogin_cnf) elif mylogin_cnf_path and not mylogin_cnf: # There was an error reading the login path file. - print('Error: Unable to read login path file.') + print("Error: Unable to read login path file.") self.prompt_app = None def register_special_commands(self): - special.register_special_command(self.change_db, 'use', - '\\u', 'Change to a new database.', aliases=('\\u',)) - special.register_special_command(self.change_db, 'connect', - '\\r', 'Reconnect to the database. Optional database argument.', - aliases=('\\r', ), case_sensitive=True) - special.register_special_command(self.refresh_completions, 'rehash', - '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',)) + special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=("\\u",)) special.register_special_command( - self.change_table_format, 'tableformat', '\\T', - 'Change the table format used to output results.', - aliases=('\\T',), case_sensitive=True) - special.register_special_command(self.execute_from_file, 'source', '\\. filename', - 'Execute commands from file.', aliases=('\\.',)) - special.register_special_command(self.change_prompt_format, 'prompt', - '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True) + self.change_db, + "connect", + "\\r", + "Reconnect to the database. Optional database argument.", + aliases=("\\r",), + case_sensitive=True, + ) + special.register_special_command( + self.refresh_completions, "rehash", "\\#", "Refresh auto-completions.", arg_type=NO_QUERY, aliases=("\\#",) + ) + special.register_special_command( + self.change_table_format, + "tableformat", + "\\T", + "Change the table format used to output results.", + aliases=("\\T",), + case_sensitive=True, + ) + special.register_special_command(self.execute_from_file, "source", "\\. filename", "Execute commands from file.", aliases=("\\.",)) + special.register_special_command( + self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=("\\R",), case_sensitive=True + ) def change_table_format(self, arg, **_): try: self.formatter.format_name = arg - yield (None, None, None, - 'Changed table format to {}'.format(arg)) + yield (None, None, None, "Changed table format to {}".format(arg)) except ValueError: - msg = 'Table format {} not recognized. Allowed formats:'.format( - arg) + msg = "Table format {} not recognized. Allowed formats:".format(arg) for table_type in self.formatter.supported_formats: msg += "\n\t{}".format(table_type) yield (None, None, None, msg) def change_db(self, arg, **_): if not arg: - click.secho( - "No database selected", - err=True, fg="red" - ) + click.secho("No database selected", err=True, fg="red") return - if arg.startswith('`') and arg.endswith('`'): - arg = re.sub(r'^`(.*)`$', r'\1', arg) - arg = re.sub(r'``', r'`', arg) + if arg.startswith("`") and arg.endswith("`"): + arg = re.sub(r"^`(.*)`$", r"\1", arg) + arg = re.sub(r"``", r"`", arg) self.sqlexecute.change_db(arg) - yield (None, None, None, 'You are now connected to database "%s" as ' - 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) + yield (None, None, None, 'You are now connected to database "%s" as ' 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) def execute_from_file(self, arg, **_): if not arg: - message = 'Missing required argument, filename.' + message = "Missing required argument, filename." return [(None, None, None, message)] try: with open(os.path.expanduser(arg)) as f: @@ -271,9 +263,8 @@ class MyCli(object): except IOError as e: return [(None, None, None, str(e))] - if (self.destructive_warning and - confirm_destructive_query(query) is False): - message = 'Wise choice. Command execution stopped.' + if self.destructive_warning and confirm_destructive_query(query) is False: + message = "Wise choice. Command execution stopped." return [(None, None, None, message)] return self.sqlexecute.run(query) @@ -283,23 +274,23 @@ class MyCli(object): Change the prompt format. """ if not arg: - message = 'Missing required argument, format.' + message = "Missing required argument, format." return [(None, None, None, message)] self.prompt_format = self.get_prompt(arg) return [(None, None, None, "Changed prompt format to %s" % arg)] def initialize_logging(self): - - log_file = os.path.expanduser(self.config['main']['log_file']) - log_level = self.config['main']['log_level'] - - level_map = {'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG - } + log_file = os.path.expanduser(self.config["main"]["log_file"]) + log_level = self.config["main"]["log_level"] + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + } # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. @@ -309,26 +300,21 @@ class MyCli(object): elif dir_path_exists(log_file): handler = logging.FileHandler(log_file) else: - self.echo( - 'Error: Unable to open the log file "{}".'.format(log_file), - err=True, fg='red') + self.echo('Error: Unable to open the log file "{}".'.format(log_file), err=True, fg="red") return - formatter = logging.Formatter( - '%(asctime)s (%(process)d/%(threadName)s) ' - '%(name)s %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) " "%(name)s %(levelname)s - %(message)s") handler.setFormatter(formatter) - root_logger = logging.getLogger('mycli') + root_logger = logging.getLogger("mycli") root_logger.addHandler(handler) root_logger.setLevel(level_map[log_level.upper()]) logging.captureWarnings(True) - root_logger.debug('Initializing mycli logging.') - root_logger.debug('Log file %r.', log_file) - + root_logger.debug("Initializing mycli logging.") + root_logger.debug("Log file %r.", log_file) def read_my_cnf_files(self, files, keys): """ @@ -339,16 +325,16 @@ class MyCli(object): """ cnf = read_config_files(files, list_values=False) - sections = ['client', 'mysqld'] + sections = ["client", "mysqld"] key_transformations = { - 'mysqld': { - 'socket': 'default_socket', - 'port': 'default_port', - 'user': 'default_user', + "mysqld": { + "socket": "default_socket", + "port": "default_port", + "user": "default_user", }, } - if self.login_path and self.login_path != 'client': + if self.login_path and self.login_path != "client": sections.append(self.login_path) if self.defaults_suffix: @@ -357,24 +343,19 @@ class MyCli(object): configuration = defaultdict(lambda: None) for key in keys: for section in cnf: - if ( - section not in sections or - key not in cnf[section] - ): + if section not in sections or key not in cnf[section]: continue new_key = key_transformations.get(section, {}).get(key) or key - configuration[new_key] = strip_matching_quotes( - cnf[section][key]) + configuration[new_key] = strip_matching_quotes(cnf[section][key]) return configuration - def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" merged = {} merged.update(ssl) - prefix = 'ssl-' + prefix = "ssl-" for k, v in cnf.items(): # skip unrelated options if not k.startswith(prefix): @@ -383,64 +364,72 @@ class MyCli(object): continue # special case because PyMySQL argument is significantly different # from commandline - if k == 'ssl-verify-server-cert': - merged['check_hostname'] = v + if k == "ssl-verify-server-cert": + merged["check_hostname"] = v else: # use argument name just strip "ssl-" prefix - arg = k[len(prefix):] + arg = k[len(prefix) :] merged[arg] = v return merged - def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl='', - ssh_user='', ssh_host='', ssh_port='', - ssh_password='', ssh_key_filename='', init_command='', password_file=''): - - cnf = {'database': None, - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'socket': None, - 'default_socket': None, - 'default-character-set': None, - 'local-infile': None, - 'loose-local-infile': None, - 'ssl-ca': None, - 'ssl-cert': None, - 'ssl-key': None, - 'ssl-cipher': None, - 'ssl-verify-serer-cert': None, + def connect( + self, + database="", + user="", + passwd="", + host="", + port="", + socket="", + charset="", + local_infile="", + ssl="", + ssh_user="", + ssh_host="", + ssh_port="", + ssh_password="", + ssh_key_filename="", + init_command="", + password_file="", + ): + cnf = { + "database": None, + "user": None, + "password": None, + "host": None, + "port": None, + "socket": None, + "default_socket": None, + "default-character-set": None, + "local-infile": None, + "loose-local-infile": None, + "ssl-ca": None, + "ssl-cert": None, + "ssl-key": None, + "ssl-cipher": None, + "ssl-verify-serer-cert": None, } cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. - database = database or cnf['database'] - user = user or cnf['user'] or os.getenv('USER') - host = host or cnf['host'] - port = port or cnf['port'] + database = database or cnf["database"] + user = user or cnf["user"] or os.getenv("USER") + host = host or cnf["host"] + port = port or cnf["port"] ssl = ssl or {} port = port and int(port) if not port: port = 3306 - if not host or host == 'localhost': - socket = ( - socket or - cnf['socket'] or - cnf['default_socket'] or - guess_socket_location() - ) - + if not host or host == "localhost": + socket = socket or cnf["socket"] or cnf["default_socket"] or guess_socket_location() - passwd = passwd if isinstance(passwd, str) else cnf['password'] - charset = charset or cnf['default-character-set'] or 'utf8' + passwd = passwd if isinstance(passwd, str) else cnf["password"] + charset = charset or cnf["default-character-set"] or "utf8" # Favor whichever local_infile option is set. - for local_infile_option in (local_infile, cnf['local-infile'], - cnf['loose-local-infile'], False): + for local_infile_option in (local_infile, cnf["local-infile"], cnf["loose-local-infile"], False): try: local_infile = str_to_bool(local_infile_option) break @@ -461,21 +450,44 @@ class MyCli(object): def _connect(): try: self.sqlexecute = SQLExecute( - database, user, passwd, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, - ssh_password, ssh_key_filename, init_command + database, + user, + passwd, + host, + port, + socket, + charset, + local_infile, + ssl, + ssh_user, + ssh_host, + ssh_port, + ssh_password, + ssh_key_filename, + init_command, ) except OperationalError as e: if e.args[0] == ERROR_CODE_ACCESS_DENIED: if password_from_file: new_passwd = password_from_file else: - new_passwd = click.prompt('Password', hide_input=True, - show_default=False, type=str, err=True) + new_passwd = click.prompt("Password", hide_input=True, show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( - database, user, new_passwd, host, port, socket, - charset, local_infile, ssl, ssh_user, ssh_host, - ssh_port, ssh_password, ssh_key_filename, init_command + database, + user, + new_passwd, + host, + port, + socket, + charset, + local_infile, + ssl, + ssh_user, + ssh_host, + ssh_port, + ssh_password, + ssh_key_filename, + init_command, ) else: raise e @@ -483,54 +495,48 @@ class MyCli(object): try: if not WIN and socket: socket_owner = getpwuid(os.stat(socket).st_uid).pw_name - self.echo( - f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) + self.echo(f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) try: _connect() except OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" if [code for code in (2001, 2002, 2003) if code == e.args[0]]: - self.logger.debug('Database connection failed: %r.', e) - self.logger.error( - "traceback: %r", traceback.format_exc()) - self.logger.debug('Retrying over TCP/IP') - self.echo( - "Failed to connect to local MySQL server through socket '{}':".format(socket)) + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + self.logger.debug("Retrying over TCP/IP") + self.echo("Failed to connect to local MySQL server through socket '{}':".format(socket)) self.echo(str(e), err=True) - self.echo( - 'Retrying over TCP/IP', err=True) + self.echo("Retrying over TCP/IP", err=True) # Else fall back to TCP/IP localhost socket = "" - host = 'localhost' + host = "localhost" port = 3306 _connect() else: raise e else: - host = host or 'localhost' + host = host or "localhost" port = port or 3306 # Bad ports give particularly daft error messages try: port = int(port) - except ValueError as e: - self.echo("Error: Invalid port number: '{0}'.".format(port), - err=True, fg='red') + except ValueError: + self.echo("Error: Invalid port number: '{0}'.".format(port), err=True, fg="red") exit(1) _connect() except Exception as e: # Connecting to a database could fail. - self.logger.debug('Database connection failed: %r.', e) + self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg="red") exit(1) def get_password_from_file(self, password_file): password_from_file = None if password_file: - if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ - and os.access(password_file, os.R_OK): + if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) and os.access(password_file, os.R_OK): with open(password_file) as fp: password_from_file = fp.readline() password_from_file = password_from_file.rstrip().lstrip() @@ -552,8 +558,7 @@ class MyCli(object): while special.editor_command(text): filename = special.get_filename(text) - query = (special.get_editor_query(text) or - self.get_last_query()) + query = special.get_editor_query(text) or self.get_last_query() sql, message = special.open_external_editor(filename, sql=query) if message: # Something went wrong. Raise an exception and bail. @@ -578,8 +583,7 @@ class MyCli(object): """ if special.clip_command(text): - query = (special.get_clip_query(text) or - self.get_last_query()) + query = special.get_clip_query(text) or self.get_last_query() message = special.copy_query_to_clipboard(sql=query) if message: raise RuntimeError(message) @@ -588,30 +592,30 @@ class MyCli(object): def handle_prettify_binding(self, text): try: - statements = sqlglot.parse(text, read='mysql') - except Exception as e: + statements = sqlglot.parse(text, read="mysql") + except Exception: statements = [] if len(statements) == 1 and statements[0]: - pretty_text = statements[0].sql(pretty=True, pad=4, dialect='mysql') + pretty_text = statements[0].sql(pretty=True, pad=4, dialect="mysql") else: - pretty_text = '' - self.toolbar_error_message = 'Prettify failed to parse statement' + pretty_text = "" + self.toolbar_error_message = "Prettify failed to parse statement" if len(pretty_text) > 0: - pretty_text = pretty_text + ';' + pretty_text = pretty_text + ";" return pretty_text def handle_unprettify_binding(self, text): try: - statements = sqlglot.parse(text, read='mysql') - except Exception as e: + statements = sqlglot.parse(text, read="mysql") + except Exception: statements = [] if len(statements) == 1 and statements[0]: - unpretty_text = statements[0].sql(pretty=False, dialect='mysql') + unpretty_text = statements[0].sql(pretty=False, dialect="mysql") else: - unpretty_text = '' - self.toolbar_error_message = 'Unprettify failed to parse statement' + unpretty_text = "" + self.toolbar_error_message = "Unprettify failed to parse statement" if len(unpretty_text) > 0: - unpretty_text = unpretty_text + ';' + unpretty_text = unpretty_text + ";" return unpretty_text def run_cli(self): @@ -623,24 +627,24 @@ class MyCli(object): if self.smart_completion: self.refresh_completions() - history_file = os.path.expanduser( - os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) + history_file = os.path.expanduser(os.environ.get("MYCLI_HISTFILE", "~/.mycli-history")) if dir_path_exists(history_file): history = FileHistoryWithTimestamp(history_file) else: history = None self.echo( - 'Error: Unable to open the history file "{}". ' - 'Your query history will not be saved.'.format(history_file), - err=True, fg='red') + 'Error: Unable to open the history file "{}". ' "Your query history will not be saved.".format(history_file), + err=True, + fg="red", + ) key_bindings = mycli_bindings(self) if not self.less_chatty: print(sqlexecute.server_info) - print('mycli', __version__) + print("mycli", __version__) print(SUPPORT_INFO) - print('Thanks to the contributor -', thanks_picker()) + print("Thanks to the contributor -", thanks_picker()) def get_message(): prompt = self.get_prompt(self.prompt_format) @@ -650,16 +654,14 @@ class MyCli(object): return ANSI(prompt) def get_continuation(width, *_): - if self.multiline_continuation_char == '': - continuation = '' + if self.multiline_continuation_char == "": + continuation = "" elif self.multiline_continuation_char: left_padding = width - len(self.multiline_continuation_char) - continuation = " " * \ - max((left_padding - 1), 0) + \ - self.multiline_continuation_char + " " + continuation = " " * max((left_padding - 1), 0) + self.multiline_continuation_char + " " else: continuation = " " - return [('class:continuation', continuation)] + return [("class:continuation", continuation)] def show_suggestion_tip(): return iterations < 2 @@ -678,7 +680,7 @@ class MyCli(object): except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg="red") return try: @@ -687,7 +689,7 @@ class MyCli(object): except RuntimeError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg="red") return if not text.strip(): @@ -698,9 +700,9 @@ class MyCli(object): if destroy is None: pass # Query was not destructive. Nothing to do here. elif destroy is True: - self.echo('Your call!') + self.echo("Your call!") else: - self.echo('Wise choice!') + self.echo("Wise choice!") return else: destroy = True @@ -711,13 +713,13 @@ class MyCli(object): mutating = False try: - logger.debug('sql: %r', text) + logger.debug("sql: %r", text) special.write_tee(self.get_prompt(self.prompt_format) + text) if self.logfile: - self.logfile.write('\n# %s\n' % datetime.now()) + self.logfile.write("\n# %s\n" % datetime.now()) self.logfile.write(text) - self.logfile.write('\n') + self.logfile.write("\n") successful = False start = time() @@ -730,12 +732,10 @@ class MyCli(object): logger.debug("rows: %r", cur) logger.debug("status: %r", status) threshold = 1000 - if (is_select(status) and - cur and cur.rowcount > threshold): - self.echo('The result set has more than {} rows.'.format( - threshold), fg='red') - if not confirm('Do you want to continue?'): - self.echo("Aborted!", err=True, fg='red') + if is_select(status) and cur and cur.rowcount > threshold: + self.echo("The result set has more than {} rows.".format(threshold), fg="red") + if not confirm("Do you want to continue?"): + self.echo("Aborted!", err=True, fg="red") break if self.auto_vertical_output: @@ -743,14 +743,12 @@ class MyCli(object): else: max_width = None - formatted = self.format_output( - title, cur, headers, special.is_expanded_output(), - max_width) + formatted = self.format_output(title, cur, headers, special.is_expanded_output(), max_width) t = time() - start try: if result_count > 0: - self.echo('') + self.echo("") try: self.output(formatted, status) except KeyboardInterrupt: @@ -758,7 +756,7 @@ class MyCli(object): if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: self.bell() if special.is_timing_enabled(): - self.echo('Time: %0.03fs' % t) + self.echo("Time: %0.03fs" % t) except KeyboardInterrupt: pass @@ -778,42 +776,40 @@ class MyCli(object): # Restart connection to the database sqlexecute.connect() try: - for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill): + for title, cur, headers, status in sqlexecute.run("kill %s" % connection_id_to_kill): status_str = str(status).lower() - if status_str.find('ok') > -1: - logger.debug("cancelled query, connection id: %r, sql: %r", - connection_id_to_kill, text) - self.echo("cancelled query", err=True, fg='red') + if status_str.find("ok") > -1: + logger.debug("cancelled query, connection id: %r, sql: %r", connection_id_to_kill, text) + self.echo("cancelled query", err=True, fg="red") except Exception as e: - self.echo('Encountered error while cancelling query: {}'.format(e), - err=True, fg='red') + self.echo("Encountered error while cancelling query: {}".format(e), err=True, fg="red") else: - logger.debug("Did not get a connection id, skip cancelling query") + logger.debug("Did not get a connection id, skip cancelling query") except NotImplementedError: - self.echo('Not Yet Implemented.', fg="yellow") + self.echo("Not Yet Implemented.", fg="yellow") except OperationalError as e: logger.debug("Exception: %r", e) - if (e.args[0] in (2003, 2006, 2013)): - logger.debug('Attempting to reconnect.') - self.echo('Reconnecting...', fg='yellow') + if e.args[0] in (2003, 2006, 2013): + logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") try: sqlexecute.connect() - logger.debug('Reconnected successfully.') + logger.debug("Reconnected successfully.") one_iteration(text) return # OK to just return, cuz the recursion call runs to the end. except OperationalError as e: - logger.debug('Reconnect failed. e: %r', e) - self.echo(str(e), err=True, fg='red') + logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") # If reconnection failed, don't proceed further. return else: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg="red") except Exception as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') + self.echo(str(e), err=True, fg="red") else: if is_dropping_database(text, self.sqlexecute.dbname): self.sqlexecute.dbname = None @@ -821,25 +817,21 @@ class MyCli(object): # Refresh the table names and column names if necessary. if need_completion_refresh(text): - self.refresh_completions( - reset=need_completion_reset(text)) + self.refresh_completions(reset=need_completion_reset(text)) finally: if self.logfile is False: - self.echo("Warning: This query was not logged.", - err=True, fg='red') + self.echo("Warning: This query was not logged.", err=True, fg="red") query = Query(text, successful, mutating) self.query_history.append(query) - get_toolbar_tokens = create_toolbar_tokens_func( - self, show_suggestion_tip) + get_toolbar_tokens = create_toolbar_tokens_func(self, show_suggestion_tip) if self.wider_completion_menu: complete_style = CompleteStyle.MULTI_COLUMN else: complete_style = CompleteStyle.COLUMN with self._completer_lock: - - if self.key_bindings == 'vi': + if self.key_bindings == "vi": editing_mode = EditingMode.VI else: editing_mode = EditingMode.EMACS @@ -851,12 +843,12 @@ class MyCli(object): prompt_continuation=get_continuation, bottom_toolbar=get_toolbar_tokens, complete_style=complete_style, - input_processors=[ConditionalProcessor( - processor=HighlightMatchingBracketProcessor( - chars='[](){}'), - filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() - )], - tempfile_suffix='.sql', + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars="[](){}"), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() + ) + ], + tempfile_suffix=".sql", completer=DynamicCompleter(lambda: self.completer), history=history, auto_suggest=AutoSuggestFromHistory(), @@ -869,7 +861,7 @@ class MyCli(object): enable_system_prompt=True, enable_suspend=True, editing_mode=editing_mode, - search_ignore_case=True + search_ignore_case=True, ) try: @@ -879,7 +871,7 @@ class MyCli(object): except EOFError: special.close_tee() if not self.less_chatty: - self.echo('Goodbye!') + self.echo("Goodbye!") def log_output(self, output): """Log the output in the audit log, if it's enabled.""" @@ -898,22 +890,20 @@ class MyCli(object): click.secho(s, **kwargs) def bell(self): - """Print a bell on the stderr. - """ - click.secho('\a', err=True, nl=False) + """Print a bell on the stderr.""" + click.secho("\a", err=True, nl=False) def get_output_margin(self, status=None): """Get the output margin (number of rows for the prompt, footer and timing message.""" - margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1 + margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count("\n") + 1 if special.is_timing_enabled(): margin += 1 if status: - margin += 1 + status.count('\n') + margin += 1 + status.count("\n") return margin - def output(self, output, status=None): """Output text to stdout or a pager command. @@ -957,9 +947,11 @@ class MyCli(object): if buf: if output_via_pager: + def newlinewrapper(text): for line in text: yield line + "\n" + click.echo_via_pager(newlinewrapper(buf)) else: for line in buf: @@ -971,18 +963,18 @@ class MyCli(object): def configure_pager(self): # Provide sane defaults for less if they are empty. - if not os.environ.get('LESS'): - os.environ['LESS'] = '-RXF' + if not os.environ.get("LESS"): + os.environ["LESS"] = "-RXF" - cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) - cnf_pager = cnf['pager'] or self.config['main']['pager'] + cnf = self.read_my_cnf_files(self.cnf_files, ["pager", "skip-pager"]) + cnf_pager = cnf["pager"] or self.config["main"]["pager"] if cnf_pager: special.set_pager(cnf_pager) self.explicit_pager = True else: self.explicit_pager = False - if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'): + if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): special.disable_pager() def refresh_completions(self, reset=False): @@ -990,17 +982,19 @@ class MyCli(object): with self._completer_lock: self.completer.reset_completions() self.completion_refresher.refresh( - self.sqlexecute, self._on_completions_refreshed, - {'smart_completion': self.smart_completion, - 'supported_formats': self.formatter.supported_formats, - 'keyword_casing': self.completer.keyword_casing}) + self.sqlexecute, + self._on_completions_refreshed, + { + "smart_completion": self.smart_completion, + "supported_formats": self.formatter.supported_formats, + "keyword_casing": self.completer.keyword_casing, + }, + ) - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + return [(None, None, None, "Auto-completion refresh started in the background.")] def _on_completions_refreshed(self, new_completer): - """Swap the completer object in cli with the newly created completer. - """ + """Swap the completer object in cli with the newly created completer.""" with self._completer_lock: self.completer = new_completer @@ -1011,27 +1005,26 @@ class MyCli(object): def get_completions(self, text, cursor_positition): with self._completer_lock: - return self.completer.get_completions( - Document(text=text, cursor_position=cursor_positition), None) + return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string): sqlexecute = self.sqlexecute host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host now = datetime.now() - string = string.replace('\\u', sqlexecute.user or '(none)') - string = string.replace('\\h', host or '(none)') - string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_info.species.name) - string = string.replace('\\n', "\n") - string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) - string = string.replace('\\m', now.strftime('%M')) - string = string.replace('\\P', now.strftime('%p')) - string = string.replace('\\R', now.strftime('%H')) - string = string.replace('\\r', now.strftime('%I')) - string = string.replace('\\s', now.strftime('%S')) - string = string.replace('\\p', str(sqlexecute.port)) - string = string.replace('\\A', self.dsn_alias or '(none)') - string = string.replace('\\_', ' ') + string = string.replace("\\u", sqlexecute.user or "(none)") + string = string.replace("\\h", host or "(none)") + string = string.replace("\\d", sqlexecute.dbname or "(none)") + string = string.replace("\\t", sqlexecute.server_info.species.name) + string = string.replace("\\n", "\n") + string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) + string = string.replace("\\m", now.strftime("%M")) + string = string.replace("\\P", now.strftime("%p")) + string = string.replace("\\R", now.strftime("%H")) + string = string.replace("\\r", now.strftime("%I")) + string = string.replace("\\s", now.strftime("%S")) + string = string.replace("\\p", str(sqlexecute.port)) + string = string.replace("\\A", self.dsn_alias or "(none)") + string = string.replace("\\_", " ") return string def run_query(self, query, new_line=True): @@ -1044,49 +1037,45 @@ class MyCli(object): for line in output: click.echo(line, nl=new_line) - def format_output(self, title, cur, headers, expanded=False, - max_width=None): - expanded = expanded or self.formatter.format_name == 'vertical' + def format_output(self, title, cur, headers, expanded=False, max_width=None): + expanded = expanded or self.formatter.format_name == "vertical" output = [] - output_kwargs = { - 'dialect': 'unix', - 'disable_numparse': True, - 'preserve_whitespace': True, - 'style': self.output_style - } + output_kwargs = {"dialect": "unix", "disable_numparse": True, "preserve_whitespace": True, "style": self.output_style} - if not self.formatter.format_name in sql_format.supported_formats: - output_kwargs["preprocessors"] = (preprocessors.align_decimals, ) + if self.formatter.format_name not in sql_format.supported_formats: + output_kwargs["preprocessors"] = (preprocessors.align_decimals,) if title: # Only print the title if it's not None. output = itertools.chain(output, [title]) if cur: column_types = None - if hasattr(cur, 'description'): + if hasattr(cur, "description"): + def get_col_type(col): col_type = FIELD_TYPES.get(col[1], str) return col_type if type(col_type) is type else str + column_types = [get_col_type(col) for col in cur.description] if max_width is not None: cur = list(cur) formatted = self.formatter.format_output( - cur, headers, format_name='vertical' if expanded else None, - column_types=column_types, - **output_kwargs) + cur, headers, format_name="vertical" if expanded else None, column_types=column_types, **output_kwargs + ) if isinstance(formatted, str): formatted = formatted.splitlines() formatted = iter(formatted) - if (not expanded and max_width and headers and cur): + if not expanded and max_width and headers and cur: first_line = next(formatted) if len(strip_ansi(first_line)) > max_width: formatted = self.formatter.format_output( - cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) + cur, headers, format_name="vertical", column_types=column_types, **output_kwargs + ) if isinstance(formatted, str): formatted = iter(formatted.splitlines()) else: @@ -1094,12 +1083,11 @@ class MyCli(object): output = itertools.chain(output, formatted) - return output def get_reserved_space(self): """Get the number of lines to reserve for the completion menu.""" - reserved_space_ratio = .45 + reserved_space_ratio = 0.45 max_reserved_space = 8 _, height = shutil.get_terminal_size() return min(int(round(height * reserved_space_ratio)), max_reserved_space) @@ -1110,91 +1098,108 @@ class MyCli(object): @click.command() -@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.') -@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors ' - '$MYSQL_TCP_PORT.') -@click.option('-u', '--user', help='User name to connect to the database.') -@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.') -@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database.') -@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database.') -@click.option('--ssh-user', help='User name to connect to ssh server.') -@click.option('--ssh-host', help='Host name to connect to ssh server.') -@click.option('--ssh-port', default=22, help='Port to connect to ssh server.') -@click.option('--ssh-password', help='Password to connect to ssh server.') -@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') -@click.option('--ssh-config-path', help='Path to ssh configuration.', - default=os.path.expanduser('~') + '/.ssh/config') -@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') -@click.option('--ssl', 'ssl_enable', is_flag=True, - help='Enable SSL for connection (automatically enabled with other flags).') -@click.option('--ssl-ca', help='CA file in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-capath', help='CA directory.') -@click.option('--ssl-cert', help='X509 cert in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-key', help='X509 key in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-cipher', help='SSL cipher to use.') -@click.option('--tls-version', - type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), - help='TLS protocol version for secure connection.') -@click.option('--ssl-verify-server-cert', is_flag=True, - help=('Verify server\'s "Common Name" in its cert against ' - 'hostname used when connecting. This option is disabled ' - 'by default.')) +@click.option("-h", "--host", envvar="MYSQL_HOST", help="Host address of the database.") +@click.option("-P", "--port", envvar="MYSQL_TCP_PORT", type=int, help="Port number to use for connection. Honors " "$MYSQL_TCP_PORT.") +@click.option("-u", "--user", help="User name to connect to the database.") +@click.option("-S", "--socket", envvar="MYSQL_UNIX_PORT", help="The socket file to use for connection.") +@click.option("-p", "--password", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") +@click.option("--pass", "password", envvar="MYSQL_PWD", type=str, help="Password to connect to the database.") +@click.option("--ssh-user", help="User name to connect to ssh server.") +@click.option("--ssh-host", help="Host name to connect to ssh server.") +@click.option("--ssh-port", default=22, help="Port to connect to ssh server.") +@click.option("--ssh-password", help="Password to connect to ssh server.") +@click.option("--ssh-key-filename", help="Private key filename (identify file) for the ssh connection.") +@click.option("--ssh-config-path", help="Path to ssh configuration.", default=os.path.expanduser("~") + "/.ssh/config") +@click.option("--ssh-config-host", help="Host to connect to ssh server reading from ssh configuration.") +@click.option("--ssl", "ssl_enable", is_flag=True, help="Enable SSL for connection (automatically enabled with other flags).") +@click.option("--ssl-ca", help="CA file in PEM format.", type=click.Path(exists=True)) +@click.option("--ssl-capath", help="CA directory.") +@click.option("--ssl-cert", help="X509 cert in PEM format.", type=click.Path(exists=True)) +@click.option("--ssl-key", help="X509 key in PEM format.", type=click.Path(exists=True)) +@click.option("--ssl-cipher", help="SSL cipher to use.") +@click.option( + "--tls-version", + type=click.Choice(["TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3"], case_sensitive=False), + help="TLS protocol version for secure connection.", +) +@click.option( + "--ssl-verify-server-cert", + is_flag=True, + help=('Verify server\'s "Common Name" in its cert against ' "hostname used when connecting. This option is disabled " "by default."), +) # as of 2016-02-15 revocation list is not supported by underling PyMySQL # library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) -@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.') -@click.option('-v', '--verbose', is_flag=True, help='Verbose output.') -@click.option('-D', '--database', 'dbname', help='Database to use.') -@click.option('-d', '--dsn', default='', envvar='DSN', - help='Use DSN configured into the [alias_dsn] section of myclirc file.') -@click.option('--list-dsn', 'list_dsn', is_flag=True, - help='list of DSN configured into the [alias_dsn] section of myclirc file.') -@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, - help='list ssh configurations in the ssh config (requires paramiko).') -@click.option('-R', '--prompt', 'prompt', - help='Prompt format (Default: "{0}").'.format( - MyCli.default_prompt)) -@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'), - help='Log every query and its results to a file.') -@click.option('--defaults-group-suffix', type=str, - help='Read MySQL config groups with the specified suffix.') -@click.option('--defaults-file', type=click.Path(), - help='Only read MySQL options from the given file.') -@click.option('--myclirc', type=click.Path(), default="~/.myclirc", - help='Location of myclirc file.') -@click.option('--auto-vertical-output', is_flag=True, - help='Automatically switch to vertical output mode if the result is wider than the terminal width.') -@click.option('-t', '--table', is_flag=True, - help='Display batch output in table format.') -@click.option('--csv', is_flag=True, - help='Display batch output in CSV format.') -@click.option('--warn/--no-warn', default=None, - help='Warn before running a destructive query.') -@click.option('--local-infile', type=bool, - help='Enable/disable LOAD DATA LOCAL INFILE.') -@click.option('-g', '--login-path', type=str, - help='Read this path from the login file.') -@click.option('-e', '--execute', type=str, - help='Execute command and quit.') -@click.option('--init-command', type=str, - help='SQL statement to execute after connecting.') -@click.option('--charset', type=str, - help='Character set for MySQL session.') -@click.option('--password-file', type=click.Path(), - help='File or FIFO path containing the password to connect to the db if not specified otherwise.') -@click.argument('database', default='', nargs=1) -def cli(database, user, host, port, socket, password, dbname, - version, verbose, prompt, logfile, defaults_group_suffix, - defaults_file, login_path, auto_vertical_output, local_infile, - ssl_enable, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, - tls_version, ssl_verify_server_cert, table, csv, warn, execute, - myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, - init_command, charset, password_file): +@click.version_option(__version__, "-V", "--version", help="Output mycli's version.") +@click.option("-v", "--verbose", is_flag=True, help="Verbose output.") +@click.option("-D", "--database", "dbname", help="Database to use.") +@click.option("-d", "--dsn", default="", envvar="DSN", help="Use DSN configured into the [alias_dsn] section of myclirc file.") +@click.option("--list-dsn", "list_dsn", is_flag=True, help="list of DSN configured into the [alias_dsn] section of myclirc file.") +@click.option("--list-ssh-config", "list_ssh_config", is_flag=True, help="list ssh configurations in the ssh config (requires paramiko).") +@click.option("-R", "--prompt", "prompt", help='Prompt format (Default: "{0}").'.format(MyCli.default_prompt)) +@click.option("-l", "--logfile", type=click.File(mode="a", encoding="utf-8"), help="Log every query and its results to a file.") +@click.option("--defaults-group-suffix", type=str, help="Read MySQL config groups with the specified suffix.") +@click.option("--defaults-file", type=click.Path(), help="Only read MySQL options from the given file.") +@click.option("--myclirc", type=click.Path(), default="~/.myclirc", help="Location of myclirc file.") +@click.option( + "--auto-vertical-output", + is_flag=True, + help="Automatically switch to vertical output mode if the result is wider than the terminal width.", +) +@click.option("-t", "--table", is_flag=True, help="Display batch output in table format.") +@click.option("--csv", is_flag=True, help="Display batch output in CSV format.") +@click.option("--warn/--no-warn", default=None, help="Warn before running a destructive query.") +@click.option("--local-infile", type=bool, help="Enable/disable LOAD DATA LOCAL INFILE.") +@click.option("-g", "--login-path", type=str, help="Read this path from the login file.") +@click.option("-e", "--execute", type=str, help="Execute command and quit.") +@click.option("--init-command", type=str, help="SQL statement to execute after connecting.") +@click.option("--charset", type=str, help="Character set for MySQL session.") +@click.option( + "--password-file", type=click.Path(), help="File or FIFO path containing the password to connect to the db if not specified otherwise." +) +@click.argument("database", default="", nargs=1) +def cli( + database, + user, + host, + port, + socket, + password, + dbname, + verbose, + prompt, + logfile, + defaults_group_suffix, + defaults_file, + login_path, + auto_vertical_output, + local_infile, + ssl_enable, + ssl_ca, + ssl_capath, + ssl_cert, + ssl_key, + ssl_cipher, + tls_version, + ssl_verify_server_cert, + table, + csv, + warn, + execute, + myclirc, + dsn, + list_dsn, + ssh_user, + ssh_host, + ssh_port, + ssh_password, + ssh_key_filename, + list_ssh_config, + ssh_config_path, + ssh_config_host, + init_command, + charset, + password_file, +): """A MySQL terminal client with auto-completion and syntax highlighting. \b @@ -1204,26 +1209,24 @@ def cli(database, user, host, port, socket, password, dbname, - mycli mysql://my_user@my_host.com:3306/my_database """ - - if version: - print('Version:', __version__) - sys.exit(0) - - mycli = MyCli(prompt=prompt, logfile=logfile, - defaults_suffix=defaults_group_suffix, - defaults_file=defaults_file, login_path=login_path, - auto_vertical_output=auto_vertical_output, warn=warn, - myclirc=myclirc) + mycli = MyCli( + prompt=prompt, + logfile=logfile, + defaults_suffix=defaults_group_suffix, + defaults_file=defaults_file, + login_path=login_path, + auto_vertical_output=auto_vertical_output, + warn=warn, + myclirc=myclirc, + ) if list_dsn: try: - alias_dsn = mycli.config['alias_dsn'] - except KeyError as err: - click.secho('Invalid DSNs found in the config file. '\ - 'Please check the "[alias_dsn]" section in myclirc.', - err=True, fg='red') + alias_dsn = mycli.config["alias_dsn"] + except KeyError: + click.secho("Invalid DSNs found in the config file. " 'Please check the "[alias_dsn]" section in myclirc.', err=True, fg="red") exit(1) except Exception as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") exit(1) for alias, value in alias_dsn.items(): if verbose: @@ -1236,8 +1239,7 @@ def cli(database, user, host, port, socket, password, dbname, for host in ssh_config.get_hostnames(): if verbose: host_config = ssh_config.lookup(host) - click.secho("{} : {}".format( - host, host_config.get('hostname'))) + click.secho("{} : {}".format(host, host_config.get("hostname"))) else: click.secho(host) sys.exit(0) @@ -1245,15 +1247,15 @@ def cli(database, user, host, port, socket, password, dbname, database = dbname or database ssl = { - 'enable': ssl_enable, - 'ca': ssl_ca and os.path.expanduser(ssl_ca), - 'cert': ssl_cert and os.path.expanduser(ssl_cert), - 'key': ssl_key and os.path.expanduser(ssl_key), - 'capath': ssl_capath, - 'cipher': ssl_cipher, - 'tls_version': tls_version, - 'check_hostname': ssl_verify_server_cert, - } + "enable": ssl_enable, + "ca": ssl_ca and os.path.expanduser(ssl_ca), + "cert": ssl_cert and os.path.expanduser(ssl_cert), + "key": ssl_key and os.path.expanduser(ssl_key), + "capath": ssl_capath, + "cipher": ssl_cipher, + "tls_version": tls_version, + "check_hostname": ssl_verify_server_cert, + } # remove empty ssl options ssl = {k: v for k, v in ssl.items() if v is not None} @@ -1262,20 +1264,21 @@ def cli(database, user, host, port, socket, password, dbname, # Treat the database argument as a DSN alias if we're missing # other connection information. - if (mycli.config['alias_dsn'] and database and '://' not in database - and not any([user, password, host, port, login_path])): - dsn, database = database, '' + if mycli.config["alias_dsn"] and database and "://" not in database and not any([user, password, host, port, login_path]): + dsn, database = database, "" - if database and '://' in database: - dsn_uri, database = database, '' + if database and "://" in database: + dsn_uri, database = database, "" if dsn: try: - dsn_uri = mycli.config['alias_dsn'][dsn] + dsn_uri = mycli.config["alias_dsn"][dsn] except KeyError: - click.secho('Could not find the specified DSN in the config file. ' - 'Please check the "[alias_dsn]" section in your ' - 'myclirc.', err=True, fg='red') + click.secho( + "Could not find the specified DSN in the config file. " 'Please check the "[alias_dsn]" section in your ' "myclirc.", + err=True, + fg="red", + ) exit(1) else: mycli.dsn_alias = dsn @@ -1294,16 +1297,13 @@ def cli(database, user, host, port, socket, password, dbname, port = uri.port if ssh_config_host: - ssh_config = read_ssh_config( - ssh_config_path - ).lookup(ssh_config_host) - ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') - ssh_user = ssh_user if ssh_user else ssh_config.get('user') - if ssh_config.get('port') and ssh_port == 22: + ssh_config = read_ssh_config(ssh_config_path).lookup(ssh_config_host) + ssh_host = ssh_host if ssh_host else ssh_config.get("hostname") + ssh_user = ssh_user if ssh_user else ssh_config.get("user") + if ssh_config.get("port") and ssh_port == 22: # port has a default value, overwrite it if it's in the config - ssh_port = int(ssh_config.get('port')) - ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( - 'identityfile', [None])[0] + ssh_port = int(ssh_config.get("port")) + ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get("identityfile", [None])[0] ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) @@ -1323,52 +1323,48 @@ def cli(database, user, host, port, socket, password, dbname, ssh_key_filename=ssh_key_filename, init_command=init_command, charset=charset, - password_file=password_file + password_file=password_file, ) - mycli.logger.debug('Launch Params: \n' - '\tdatabase: %r' - '\tuser: %r' - '\thost: %r' - '\tport: %r', database, user, host, port) + mycli.logger.debug("Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r", database, user, host, port) # --execute argument if execute: try: if csv: - mycli.formatter.format_name = 'csv' - if execute.endswith(r'\G'): + mycli.formatter.format_name = "csv" + if execute.endswith(r"\G"): execute = execute[:-2] elif table: - if execute.endswith(r'\G'): + if execute.endswith(r"\G"): execute = execute[:-2] else: - mycli.formatter.format_name = 'tsv' + mycli.formatter.format_name = "tsv" mycli.run_query(execute) exit(0) except Exception as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") exit(1) if sys.stdin.isatty(): mycli.run_cli() else: - stdin = click.get_text_stream('stdin') + stdin = click.get_text_stream("stdin") try: stdin_text = stdin.read() except MemoryError: - click.secho('Failed! Ran out of memory.', err=True, fg='red') - click.secho('You might want to try the official mysql client.', err=True, fg='red') - click.secho('Sorry... :(', err=True, fg='red') + click.secho("Failed! Ran out of memory.", err=True, fg="red") + click.secho("You might want to try the official mysql client.", err=True, fg="red") + click.secho("Sorry... :(", err=True, fg="red") exit(1) if mycli.destructive_warning and is_destructive(stdin_text): try: - sys.stdin = open('/dev/tty') + sys.stdin = open("/dev/tty") warn_confirmed = confirm_destructive_query(stdin_text) except (IOError, OSError): - mycli.logger.warning('Unable to open TTY as stdin.') + mycli.logger.warning("Unable to open TTY as stdin.") if not warn_confirmed: exit(0) @@ -1376,14 +1372,14 @@ def cli(database, user, host, port, socket, password, dbname, new_line = True if csv: - mycli.formatter.format_name = 'csv' + mycli.formatter.format_name = "csv" elif not table: - mycli.formatter.format_name = 'tsv' + mycli.formatter.format_name = "tsv" mycli.run_query(stdin_text, new_line=new_line) exit(0) except Exception as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") exit(1) @@ -1393,8 +1389,7 @@ def need_completion_refresh(queries): for query in sqlparse.split(queries): try: first_token = query.split()[0] - if first_token.lower() in ('alter', 'create', 'use', '\\r', - '\\u', 'connect', 'drop', 'rename'): + if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"): return True except Exception: return False @@ -1408,7 +1403,7 @@ def need_completion_reset(queries): for query in sqlparse.split(queries): try: first_token = query.split()[0] - if first_token.lower() in ('use', '\\u'): + if first_token.lower() in ("use", "\\u"): return True except Exception: return False @@ -1419,8 +1414,7 @@ def is_mutating(status): if not status: return False - mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', - 'replace', 'truncate', 'load', 'rename']) + mutating = set(["insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"]) return status.split(None, 1)[0].lower() in mutating @@ -1428,25 +1422,23 @@ def is_select(status): """Returns true if the first word in status is 'select'.""" if not status: return False - return status.split(None, 1)[0].lower() == 'select' + return status.split(None, 1)[0].lower() == "select" def thanks_picker(): import mycli - lines = ( - resources.read_text(mycli, 'AUTHORS') + - resources.read_text(mycli, 'SPONSORS') - ).split('\n') + + lines = (resources.read_text(mycli, "AUTHORS") + resources.read_text(mycli, "SPONSORS")).split("\n") contents = [] for line in lines: - m = re.match(r'^ *\* (.*)', line) + m = re.match(r"^ *\* (.*)", line) if m: contents.append(m.group(1)) return choice(contents) -@prompt_register('edit-and-execute-command') +@prompt_register("edit-and-execute-command") def edit_and_execute(event): """Different from the prompt-toolkit default, we want to have a choice not to execute a query after editing, hence validate_and_handle=False.""" @@ -1460,16 +1452,13 @@ def read_ssh_config(ssh_config_path): with open(ssh_config_path) as f: ssh_config.parse(f) except FileNotFoundError as e: - click.secho(str(e), err=True, fg='red') + click.secho(str(e), err=True, fg="red") sys.exit(1) # Paramiko prior to version 2.7 raises Exception on parse errors. # In 2.7 it has become paramiko.ssh_exception.SSHException, # but let's catch everything for compatibility except Exception as err: - click.secho( - f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', - err=True, fg='red' - ) + click.secho(f"Could not parse SSH configuration file {ssh_config_path}:\n{err} ", err=True, fg="red") sys.exit(1) else: return ssh_config diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 6d5709a..a2cd63a 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -12,8 +12,7 @@ def suggest_type(full_text, text_before_cursor): A scope for a column category will be a list of tables. """ - word_before_cursor = last_word(text_before_cursor, - include='many_punctuations') + word_before_cursor = last_word(text_before_cursor, include="many_punctuations") identifier = None @@ -25,12 +24,10 @@ def suggest_type(full_text, text_before_cursor): # partially typed string which renders the smart completion useless because # it will always return the list of keywords as completion. if word_before_cursor: - if word_before_cursor.endswith( - '(') or word_before_cursor.startswith('\\'): + if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"): parsed = sqlparse.parse(text_before_cursor) else: - parsed = sqlparse.parse( - text_before_cursor[:-len(word_before_cursor)]) + parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)]) # word_before_cursor may include a schema qualification, like # "schema_name.partial_name" or "schema_name.", so parse it @@ -42,7 +39,7 @@ def suggest_type(full_text, text_before_cursor): else: parsed = sqlparse.parse(text_before_cursor) except (TypeError, AttributeError): - return [{'type': 'keyword'}] + return [{"type": "keyword"}] if len(parsed) > 1: # Multiple statements being edited -- isolate the current one by @@ -72,13 +69,12 @@ def suggest_type(full_text, text_before_cursor): # Be careful here because trivial whitespace is parsed as a statement, # but the statement won't have a first token tok1 = statement.token_first() - if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')): + if tok1 and (tok1.value == "source" or tok1.value.startswith("\\")): return suggest_special(text_before_cursor) - last_token = statement and statement.token_prev(len(statement.tokens))[1] or '' + last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" - return suggest_based_on_last_token(last_token, text_before_cursor, - full_text, identifier) + return suggest_based_on_last_token(last_token, text_before_cursor, full_text, identifier) def suggest_special(text): @@ -87,27 +83,27 @@ def suggest_special(text): if cmd == text: # Trying to complete the special command itself - return [{'type': 'special'}] + return [{"type": "special"}] - if cmd in ('\\u', '\\r'): - return [{'type': 'database'}] + if cmd in ("\\u", "\\r"): + return [{"type": "database"}] - if cmd in ('\\T'): - return [{'type': 'table_format'}] + if cmd in ("\\T"): + return [{"type": "table_format"}] - if cmd in ['\\f', '\\fs', '\\fd']: - return [{'type': 'favoritequery'}] + if cmd in ["\\f", "\\fs", "\\fd"]: + return [{"type": "favoritequery"}] - if cmd in ['\\dt', '\\dt+']: + if cmd in ["\\dt", "\\dt+"]: return [ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, ] - elif cmd in ['\\.', 'source']: - return[{'type': 'file_name'}] + elif cmd in ["\\.", "source"]: + return [{"type": "file_name"}] - return [{'type': 'keyword'}, {'type': 'special'}] + return [{"type": "keyword"}, {"type": "special"}] def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): @@ -127,20 +123,19 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # 'where foo > 5 and '. We need to look "inside" token.tokens to handle # suggestions in complicated where clauses correctly prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) - return suggest_based_on_last_token(prev_keyword, text_before_cursor, - full_text, identifier) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) elif token is None: - return [{'type': 'keyword'}] + return [{"type": "keyword"}] else: token_v = token.value.lower() - is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']]) + is_operand = lambda x: x and any([x.endswith(op) for op in ["+", "-", "*", "/"]]) # noqa: E731 if not token: - return [{'type': 'keyword'}, {'type': 'special'}] + return [{"type": "keyword"}, {"type": "special"}] elif token_v == "*": - return [{'type': 'keyword'}] - elif token_v.endswith('('): + return [{"type": "keyword"}] + elif token_v.endswith("("): p = sqlparse.parse(text_before_cursor)[0] if p.tokens and isinstance(p.tokens[-1], Where): @@ -155,8 +150,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier # Suggest columns/functions AND keywords. (If we wanted to be # really fancy, we could suggest only array-typed columns) - column_suggestions = suggest_based_on_last_token('where', - text_before_cursor, full_text, identifier) + column_suggestions = suggest_based_on_last_token("where", text_before_cursor, full_text, identifier) # Check for a subquery expression (cases 3 & 4) where = p.tokens[-1] @@ -167,130 +161,133 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier prev_tok = prev_tok.tokens[-1] prev_tok = prev_tok.value.lower() - if prev_tok == 'exists': - return [{'type': 'keyword'}] + if prev_tok == "exists": + return [{"type": "keyword"}] else: return column_suggestions # Get the token before the parens idx, prev_tok = p.token_prev(len(p.tokens) - 1) - if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': + if prev_tok and prev_tok.value and prev_tok.value.lower() == "using": # tbl1 INNER JOIN tbl2 USING (col1, col2) tables = extract_tables(full_text) # suggest columns that are present in more than one table - return [{'type': 'column', 'tables': tables, 'drop_unique': True}] - elif p.token_first().value.lower() == 'select': + return [{"type": "column", "tables": tables, "drop_unique": True}] + elif p.token_first().value.lower() == "select": # If the lparen is preceeded by a space chances are we're about to # do a sub-select. - if last_word(text_before_cursor, - 'all_punctuations').startswith('('): - return [{'type': 'keyword'}] - elif p.token_first().value.lower() == 'show': - return [{'type': 'show'}] + if last_word(text_before_cursor, "all_punctuations").startswith("("): + return [{"type": "keyword"}] + elif p.token_first().value.lower() == "show": + return [{"type": "show"}] # We're probably in a function argument list - return [{'type': 'column', 'tables': extract_tables(full_text)}] - elif token_v in ('set', 'order by', 'distinct'): - return [{'type': 'column', 'tables': extract_tables(full_text)}] - elif token_v == 'as': + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v in ("set", "order by", "distinct"): + return [{"type": "column", "tables": extract_tables(full_text)}] + elif token_v == "as": # Don't suggest anything for an alias return [] - elif token_v in ('show'): - return [{'type': 'show'}] - elif token_v in ('to',): + elif token_v in ("show"): + return [{"type": "show"}] + elif token_v in ("to",): p = sqlparse.parse(text_before_cursor)[0] - if p.token_first().value.lower() == 'change': - return [{'type': 'change'}] + if p.token_first().value.lower() == "change": + return [{"type": "change"}] else: - return [{'type': 'user'}] - elif token_v in ('user', 'for'): - return [{'type': 'user'}] - elif token_v in ('select', 'where', 'having'): + return [{"type": "user"}] + elif token_v in ("user", "for"): + return [{"type": "user"}] + elif token_v in ("select", "where", "having"): # Check for a table alias or schema qualification parent = (identifier and identifier.get_parent_name()) or [] tables = extract_tables(full_text) if parent: tables = [t for t in tables if identifies(parent, *t)] - return [{'type': 'column', 'tables': tables}, - {'type': 'table', 'schema': parent}, - {'type': 'view', 'schema': parent}, - {'type': 'function', 'schema': parent}] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] else: aliases = [alias or table for (schema, table, alias) in tables] - return [{'type': 'column', 'tables': tables}, - {'type': 'function', 'schema': []}, - {'type': 'alias', 'aliases': aliases}, - {'type': 'keyword'}] - elif (token_v.endswith('join') and token.is_keyword) or (token_v in - ('copy', 'from', 'update', 'into', 'describe', 'truncate', - 'desc', 'explain')): + return [ + {"type": "column", "tables": tables}, + {"type": "function", "schema": []}, + {"type": "alias", "aliases": aliases}, + {"type": "keyword"}, + ] + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate", "desc", "explain") + ): schema = (identifier and identifier.get_parent_name()) or [] # Suggest tables from either the currently-selected schema or the # public schema if no schema has been specified - suggest = [{'type': 'table', 'schema': schema}] + suggest = [{"type": "table", "schema": schema}] if not schema: # Suggest schemas - suggest.insert(0, {'type': 'schema'}) + suggest.insert(0, {"type": "schema"}) # Only tables can be TRUNCATED, otherwise suggest views - if token_v != 'truncate': - suggest.append({'type': 'view', 'schema': schema}) + if token_v != "truncate": + suggest.append({"type": "view", "schema": schema}) return suggest - elif token_v in ('table', 'view', 'function'): + elif token_v in ("table", "view", "function"): # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>' rel_type = token_v schema = (identifier and identifier.get_parent_name()) or [] if schema: - return [{'type': rel_type, 'schema': schema}] + return [{"type": rel_type, "schema": schema}] else: - return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] - elif token_v == 'on': + return [{"type": "schema"}, {"type": rel_type, "schema": []}] + elif token_v == "on": tables = extract_tables(full_text) # [(schema, table, alias), ...] parent = (identifier and identifier.get_parent_name()) or [] if parent: # "ON parent.<suggestion>" # parent can be either a schema name or table alias tables = [t for t in tables if identifies(parent, *t)] - return [{'type': 'column', 'tables': tables}, - {'type': 'table', 'schema': parent}, - {'type': 'view', 'schema': parent}, - {'type': 'function', 'schema': parent}] + return [ + {"type": "column", "tables": tables}, + {"type": "table", "schema": parent}, + {"type": "view", "schema": parent}, + {"type": "function", "schema": parent}, + ] else: # ON <suggestion> # Use table alias if there is one, otherwise the table name aliases = [alias or table for (schema, table, alias) in tables] - suggest = [{'type': 'alias', 'aliases': aliases}] + suggest = [{"type": "alias", "aliases": aliases}] # The lists of 'aliases' could be empty if we're trying to complete # a GRANT query. eg: GRANT SELECT, INSERT ON <tab> # In that case we just suggest all tables. if not aliases: - suggest.append({'type': 'table', 'schema': parent}) + suggest.append({"type": "table", "schema": parent}) return suggest - elif token_v in ('use', 'database', 'template', 'connect'): + elif token_v in ("use", "database", "template", "connect"): # "\c <db", "use <db>", "DROP DATABASE <db>", # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" - return [{'type': 'database'}] - elif token_v == 'tableformat': - return [{'type': 'table_format'}] - elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']: + return [{"type": "database"}] + elif token_v == "tableformat": + return [{"type": "table_format"}] + elif token_v.endswith(",") or is_operand(token_v) or token_v in ["=", "and", "or"]: prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) if prev_keyword: - return suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier) + return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text, identifier) else: return [] else: - return [{'type': 'keyword'}] + return [{"type": "keyword"}] def identifies(id, schema, table, alias): - return id == alias or id == table or ( - schema and (id == schema + '.' + table)) + return id == alias or id == table or (schema and (id == schema + "." + table)) diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py index a91055d..12d9286 100644 --- a/mycli/packages/filepaths.py +++ b/mycli/packages/filepaths.py @@ -38,7 +38,7 @@ def complete_path(curr_dir, last_dir): """ if not last_dir or curr_dir.startswith(last_dir): return curr_dir - elif last_dir == '~': + elif last_dir == "~": return os.path.join(last_dir, curr_dir) @@ -51,7 +51,7 @@ def parse_path(root_dir): :return: tuple of (string, string, int) """ - base_dir, last_dir, position = '', '', 0 + base_dir, last_dir, position = "", "", 0 if root_dir: base_dir, last_dir = os.path.split(root_dir) position = -len(last_dir) if last_dir else 0 @@ -69,9 +69,9 @@ def suggest_path(root_dir): """ if not root_dir: - return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] + return [os.path.abspath(os.sep), "~", os.curdir, os.pardir] - if '~' in root_dir: + if "~" in root_dir: root_dir = os.path.expanduser(root_dir) if not os.path.exists(root_dir): @@ -100,7 +100,7 @@ def guess_socket_location(): for r, dirs, files in os.walk(directory, topdown=True): for filename in files: name, ext = os.path.splitext(filename) - if name.startswith("mysql") and name != "mysqlx" and ext in ('.socket', '.sock'): + if name.startswith("mysql") and name != "mysqlx" and ext in (".socket", ".sock"): return os.path.join(r, filename) dirs[:] = [d for d in dirs if d.startswith("mysql")] return None diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py index 045b00e..154c72c 100644 --- a/mycli/packages/paramiko_stub/__init__.py +++ b/mycli/packages/paramiko_stub/__init__.py @@ -12,16 +12,19 @@ class Paramiko: def __getattr__(self, name): import sys from textwrap import dedent - print(dedent(""" - To enable certain SSH features you need to install paramiko: + + print( + dedent(""" + To enable certain SSH features you need to install paramiko and sshtunnel: - pip install paramiko + pip install paramiko sshtunnel It is required for the following configuration options: --list-ssh-config --ssh-config-host --ssh-host - """)) + """) + ) sys.exit(1) diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 3090530..9acbcd5 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -4,18 +4,18 @@ from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.tokens import Keyword, DML, Punctuation cleanup_regex = { - # This matches only alphanumerics and underscores. - 'alphanum_underscore': re.compile(r'(\w+)$'), - # This matches everything except spaces, parens, colon, and comma - 'many_punctuations': re.compile(r'([^():,\s]+)$'), - # This matches everything except spaces, parens, colon, comma, and period - 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), - # This matches everything except a space. - 'all_punctuations': re.compile(r'([^\s]+)$'), + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile(r"([^\s]+)$"), } -def last_word(text, include='alphanum_underscore'): +def last_word(text, include="alphanum_underscore"): r""" Find the last word in a sentence. @@ -47,18 +47,18 @@ def last_word(text, include='alphanum_underscore'): 'def' """ - if not text: # Empty string - return '' + if not text: # Empty string + return "" if text[-1].isspace(): - return '' + return "" else: regex = cleanup_regex[include] matches = regex.search(text) if matches: return matches.group(0) else: - return '' + return "" # This code is borrowed from sqlparse example script. @@ -67,11 +67,11 @@ def is_subselect(parsed): if not parsed.is_group: return False for item in parsed.tokens: - if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', - 'UPDATE', 'CREATE', 'DELETE'): + if item.ttype is DML and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE"): return True return False + def extract_from_part(parsed, stop_at_punctuation=True): tbl_prefix_seen = False for item in parsed.tokens: @@ -85,7 +85,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): # "ON" is a keyword and will trigger the next elif condition. # So instead of stooping the loop when finding an "ON" skip it # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi' - elif item.ttype is Keyword and item.value.upper() == 'ON': + elif item.ttype is Keyword and item.value.upper() == "ON": tbl_prefix_seen = False continue # An incomplete nested select won't be recognized correctly as a @@ -96,24 +96,28 @@ def extract_from_part(parsed, stop_at_punctuation=True): # Also 'SELECT * FROM abc JOIN def' will trigger this elif # condition. So we need to ignore the keyword JOIN and its variants # INNER JOIN, FULL OUTER JOIN, etc. - elif item.ttype is Keyword and ( - not item.value.upper() == 'FROM') and ( - not item.value.upper().endswith('JOIN')): + elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")): return else: yield item - elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and - item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)): + elif (item.ttype is Keyword or item.ttype is Keyword.DML) and item.value.upper() in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + "JOIN", + ): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. elif isinstance(item, IdentifierList): for identifier in item.get_identifiers(): - if (identifier.ttype is Keyword and - identifier.value.upper() == 'FROM'): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": tbl_prefix_seen = True break + def extract_table_identifiers(token_stream): """yields tuples of (schema_name, table_name, table_alias)""" @@ -141,6 +145,7 @@ def extract_table_identifiers(token_stream): elif isinstance(item, Function): yield (None, item.get_name(), item.get_name()) + # extract_tables is inspired from examples in the sqlparse lib. def extract_tables(sql): """Extract the table names from an SQL statement. @@ -156,27 +161,27 @@ def extract_tables(sql): # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) # abc is the table name, but if we don't stop at the first lparen, then # we'll identify abc, col1 and col2 as table names. - insert_stmt = parsed[0].token_first().value.lower() == 'insert' + insert_stmt = parsed[0].token_first().value.lower() == "insert" stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) return list(extract_table_identifiers(stream)) + def find_prev_keyword(sql): - """ Find the last sql keyword in an SQL statement + """Find the last sql keyword in an SQL statement Returns the value of the last keyword, and the text of the query with everything after the last keyword stripped """ if not sql.strip(): - return None, '' + return None, "" parsed = sqlparse.parse(sql)[0] flattened = list(parsed.flatten()) - logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + logical_operators = ("AND", "OR", "NOT", "BETWEEN") for t in reversed(flattened): - if t.value == '(' or (t.is_keyword and ( - t.value.upper() not in logical_operators)): + if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)): # Find the location of token t in the original parsed statement # We can't use parsed.token_index(t) because t may be a child token # inside a TokenList, in which case token_index thows an error @@ -189,10 +194,10 @@ def find_prev_keyword(sql): # Combine the string values of all tokens in the original list # up to and including the target keyword token t, to produce a # query string with everything after the keyword token removed - text = ''.join(tok.value for tok in flattened[:idx+1]) + text = "".join(tok.value for tok in flattened[: idx + 1]) return t, text - return None, '' + return None, "" def query_starts_with(query, prefixes): @@ -212,31 +217,25 @@ def queries_start_with(queries, prefixes): def query_has_where_clause(query): """Check if the query contains a where-clause.""" - return any( - isinstance(token, sqlparse.sql.Where) - for token_list in sqlparse.parse(query) - for token in token_list - ) + return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list) def is_destructive(queries): """Returns if any of the queries in *queries* is destructive.""" - keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') + keywords = ("drop", "shutdown", "delete", "truncate", "alter") for query in sqlparse.split(queries): if query: if query_starts_with(query, keywords) is True: return True - elif query_starts_with( - query, ['update'] - ) is True and not query_has_where_clause(query): + elif query_starts_with(query, ["update"]) is True and not query_has_where_clause(query): return True return False -if __name__ == '__main__': - sql = 'select * from (select t. from tabl t' - print (extract_tables(sql)) +if __name__ == "__main__": + sql = "select * from (select t. from tabl t" + print(extract_tables(sql)) def is_dropping_database(queries, dbname): @@ -258,9 +257,7 @@ def is_dropping_database(queries, dbname): "database", "schema", ): - database_token = next( - (t for t in query.tokens if isinstance(t, Identifier)), None - ) + database_token = next((t for t in query.tokens if isinstance(t, Identifier)), None) if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: result = keywords[0].normalized == "DROP" return result diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py index fb1e431..2cbca5e 100644 --- a/mycli/packages/prompt_utils.py +++ b/mycli/packages/prompt_utils.py @@ -4,20 +4,20 @@ from .parseutils import is_destructive class ConfirmBoolParamType(click.ParamType): - name = 'confirmation' + name = "confirmation" def convert(self, value, param, ctx): if isinstance(value, bool): return bool(value) value = value.lower() - if value in ('yes', 'y'): + if value in ("yes", "y"): return True - elif value in ('no', 'n'): + elif value in ("no", "n"): return False - self.fail('%s is not a valid boolean' % value, param, ctx) + self.fail("%s is not a valid boolean" % value, param, ctx) def __repr__(self): - return 'BOOL' + return "BOOL" BOOLEAN_TYPE = ConfirmBoolParamType() @@ -32,8 +32,7 @@ def confirm_destructive_query(queries): * False if the query is destructive and the user doesn't want to proceed. """ - prompt_text = ("You're about to run a destructive command.\n" - "Do you want to proceed? (y/n)") + prompt_text = "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" if is_destructive(queries) and sys.stdin.isatty(): return prompt(prompt_text, type=BOOLEAN_TYPE) diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py index 92bcca6..0c8c909 100644 --- a/mycli/packages/special/__init__.py +++ b/mycli/packages/special/__init__.py @@ -1,10 +1,12 @@ __all__ = [] + def export(defn): """Decorator to explicitly mark functions that are exposed in a lib.""" globals()[defn.__name__] = defn __all__.append(defn.__name__) return defn -from . import dbcommands -from . import iocommands + +from . import dbcommands # noqa: E402 F401 +from . import iocommands # noqa: E402 F401 diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index 5c29c55..4432a22 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -10,24 +10,23 @@ from pymysql import ProgrammingError log = logging.getLogger(__name__) -@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.', - arg_type=PARSED_QUERY, case_sensitive=True) +@special_command("\\dt", "\\dt[+] [table]", "List or describe tables.", arg_type=PARSED_QUERY, case_sensitive=True) def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): if arg: - query = 'SHOW FIELDS FROM {0}'.format(arg) + query = "SHOW FIELDS FROM {0}".format(arg) else: - query = 'SHOW TABLES' + query = "SHOW TABLES" log.debug(query) cur.execute(query) tables = cur.fetchall() - status = '' + status = "" if cur.description: headers = [x[0] for x in cur.description] else: - return [(None, None, None, '')] + return [(None, None, None, "")] if verbose and arg: - query = 'SHOW CREATE TABLE {0}'.format(arg) + query = "SHOW CREATE TABLE {0}".format(arg) log.debug(query) cur.execute(query) status = cur.fetchone()[1] @@ -35,128 +34,121 @@ def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): return [(None, tables, headers, status)] -@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True) +@special_command("\\l", "\\l", "List databases.", arg_type=RAW_QUERY, case_sensitive=True) def list_databases(cur, **_): - query = 'SHOW DATABASES' + query = "SHOW DATABASES" log.debug(query) cur.execute(query) if cur.description: headers = [x[0] for x in cur.description] - return [(None, cur, headers, '')] + return [(None, cur, headers, "")] else: - return [(None, None, None, '')] + return [(None, None, None, "")] -@special_command('status', '\\s', 'Get status information from the server.', - arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True) +@special_command("status", "\\s", "Get status information from the server.", arg_type=RAW_QUERY, aliases=("\\s",), case_sensitive=True) def status(cur, **_): - query = 'SHOW GLOBAL STATUS;' + query = "SHOW GLOBAL STATUS;" log.debug(query) try: cur.execute(query) except ProgrammingError: # Fallback in case query fail, as it does with Mysql 4 - query = 'SHOW STATUS;' + query = "SHOW STATUS;" log.debug(query) cur.execute(query) status = dict(cur.fetchall()) - query = 'SHOW GLOBAL VARIABLES;' + query = "SHOW GLOBAL VARIABLES;" log.debug(query) cur.execute(query) variables = dict(cur.fetchall()) # prepare in case keys are bytes, as with Python 3 and Mysql 4 - if (isinstance(list(variables)[0], bytes) and - isinstance(list(status)[0], bytes)): - variables = {k.decode('utf-8'): v.decode('utf-8') for k, v - in variables.items()} - status = {k.decode('utf-8'): v.decode('utf-8') for k, v - in status.items()} + if isinstance(list(variables)[0], bytes) and isinstance(list(status)[0], bytes): + variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in variables.items()} + status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()} # Create output buffers. title = [] output = [] footer = [] - title.append('--------------') + title.append("--------------") # Output the mycli client information. implementation = platform.python_implementation() version = platform.python_version() client_info = [] - client_info.append('mycli {0},'.format(__version__)) - client_info.append('running on {0} {1}'.format(implementation, version)) - title.append(' '.join(client_info) + '\n') + client_info.append("mycli {0},".format(__version__)) + client_info.append("running on {0} {1}".format(implementation, version)) + title.append(" ".join(client_info) + "\n") # Build the output that will be displayed as a table. - output.append(('Connection id:', cur.connection.thread_id())) + output.append(("Connection id:", cur.connection.thread_id())) - query = 'SELECT DATABASE(), USER();' + query = "SELECT DATABASE(), USER();" log.debug(query) cur.execute(query) db, user = cur.fetchone() if db is None: - db = '' + db = "" - output.append(('Current database:', db)) - output.append(('Current user:', user)) + output.append(("Current database:", db)) + output.append(("Current user:", user)) if iocommands.is_pager_enabled(): - if 'PAGER' in os.environ: - pager = os.environ['PAGER'] + if "PAGER" in os.environ: + pager = os.environ["PAGER"] else: - pager = 'System default' + pager = "System default" else: - pager = 'stdout' - output.append(('Current pager:', pager)) + pager = "stdout" + output.append(("Current pager:", pager)) - output.append(('Server version:', '{0} {1}'.format( - variables['version'], variables['version_comment']))) - output.append(('Protocol version:', variables['protocol_version'])) + output.append(("Server version:", "{0} {1}".format(variables["version"], variables["version_comment"]))) + output.append(("Protocol version:", variables["protocol_version"])) - if 'unix' in cur.connection.host_info.lower(): + if "unix" in cur.connection.host_info.lower(): host_info = cur.connection.host_info else: - host_info = '{0} via TCP/IP'.format(cur.connection.host) + host_info = "{0} via TCP/IP".format(cur.connection.host) - output.append(('Connection:', host_info)) + output.append(("Connection:", host_info)) - query = ('SELECT @@character_set_server, @@character_set_database, ' - '@@character_set_client, @@character_set_connection LIMIT 1;') + query = "SELECT @@character_set_server, @@character_set_database, " "@@character_set_client, @@character_set_connection LIMIT 1;" log.debug(query) cur.execute(query) charset = cur.fetchone() - output.append(('Server characterset:', charset[0])) - output.append(('Db characterset:', charset[1])) - output.append(('Client characterset:', charset[2])) - output.append(('Conn. characterset:', charset[3])) + output.append(("Server characterset:", charset[0])) + output.append(("Db characterset:", charset[1])) + output.append(("Client characterset:", charset[2])) + output.append(("Conn. characterset:", charset[3])) - if 'TCP/IP' in host_info: - output.append(('TCP port:', cur.connection.port)) + if "TCP/IP" in host_info: + output.append(("TCP port:", cur.connection.port)) else: - output.append(('UNIX socket:', variables['socket'])) + output.append(("UNIX socket:", variables["socket"])) - if 'Uptime' in status: - output.append(('Uptime:', format_uptime(status['Uptime']))) + if "Uptime" in status: + output.append(("Uptime:", format_uptime(status["Uptime"]))) - if 'Threads_connected' in status: + if "Threads_connected" in status: # Print the current server statistics. stats = [] - stats.append('Connections: {0}'.format(status['Threads_connected'])) - if 'Queries' in status: - stats.append('Queries: {0}'.format(status['Queries'])) - stats.append('Slow queries: {0}'.format(status['Slow_queries'])) - stats.append('Opens: {0}'.format(status['Opened_tables'])) - if 'Flush_commands' in status: - stats.append('Flush tables: {0}'.format(status['Flush_commands'])) - stats.append('Open tables: {0}'.format(status['Open_tables'])) - if 'Queries' in status: - queries_per_second = int(status['Queries']) / int(status['Uptime']) - stats.append('Queries per second avg: {:.3f}'.format( - queries_per_second)) - stats = ' '.join(stats) - footer.append('\n' + stats) - - footer.append('--------------') - return [('\n'.join(title), output, '', '\n'.join(footer))] + stats.append("Connections: {0}".format(status["Threads_connected"])) + if "Queries" in status: + stats.append("Queries: {0}".format(status["Queries"])) + stats.append("Slow queries: {0}".format(status["Slow_queries"])) + stats.append("Opens: {0}".format(status["Opened_tables"])) + if "Flush_commands" in status: + stats.append("Flush tables: {0}".format(status["Flush_commands"])) + stats.append("Open tables: {0}".format(status["Open_tables"])) + if "Queries" in status: + queries_per_second = int(status["Queries"]) / int(status["Uptime"]) + stats.append("Queries per second avg: {:.3f}".format(queries_per_second)) + stats = " ".join(stats) + footer.append("\n" + stats) + + footer.append("--------------") + return [("\n".join(title), output, "", "\n".join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py index 994b134..530bf1a 100644 --- a/mycli/packages/special/delimitercommand.py +++ b/mycli/packages/special/delimitercommand.py @@ -4,7 +4,7 @@ import sqlparse class DelimiterCommand(object): def __init__(self): - self._delimiter = ';' + self._delimiter = ";" def _split(self, sql): """Temporary workaround until sqlparse.split() learns about custom @@ -12,22 +12,19 @@ class DelimiterCommand(object): placeholder = "\ufffc" # unicode object replacement character - if self._delimiter == ';': + if self._delimiter == ";": return sqlparse.split(sql) # We must find a string that original sql does not contain. # Most likely, our placeholder is enough, but if not, keep looking while placeholder in sql: placeholder += placeholder[0] - sql = sql.replace(';', placeholder) - sql = sql.replace(self._delimiter, ';') + sql = sql.replace(";", placeholder) + sql = sql.replace(self._delimiter, ";") split = sqlparse.split(sql) - return [ - stmt.replace(';', self._delimiter).replace(placeholder, ';') - for stmt in split - ] + return [stmt.replace(";", self._delimiter).replace(placeholder, ";") for stmt in split] def queries_iter(self, input): """Iterate over queries in the input string.""" @@ -49,7 +46,7 @@ class DelimiterCommand(object): # re-split everything, and if we previously stripped # the delimiter, append it to the end if self._delimiter != delimiter: - combined_statement = ' '.join([sql] + queries) + combined_statement = " ".join([sql] + queries) if trailing_delimiter: combined_statement += delimiter queries = self._split(combined_statement)[1:] @@ -63,13 +60,13 @@ class DelimiterCommand(object): word of it. """ - match = arg and re.search(r'[^\s]+', arg) + match = arg and re.search(r"[^\s]+", arg) if not match: - message = 'Missing required argument, delimiter' + message = "Missing required argument, delimiter" return [(None, None, None, message)] delimiter = match.group() - if delimiter.lower() == 'delimiter': + if delimiter.lower() == "delimiter": return [(None, None, None, 'Invalid delimiter "delimiter"')] self._delimiter = delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py index 0b91400..3f8648c 100644 --- a/mycli/packages/special/favoritequeries.py +++ b/mycli/packages/special/favoritequeries.py @@ -1,8 +1,7 @@ class FavoriteQueries(object): + section_name = "favorite_queries" - section_name = 'favorite_queries' - - usage = ''' + usage = """ Favorite Queries are a way to save frequently used queries with a short name. Examples: @@ -29,7 +28,7 @@ Examples: # Delete a favorite query. > \\fd simple simple: Deleted -''' +""" # Class-level variable, for convenience to use as a singleton. instance = None @@ -48,7 +47,7 @@ Examples: return self.config.get(self.section_name, {}).get(name, None) def save(self, name, query): - self.config.encoding = 'utf-8' + self.config.encoding = "utf-8" if self.section_name not in self.config: self.config[self.section_name] = {} self.config[self.section_name][name] = query @@ -58,6 +57,6 @@ Examples: try: del self.config[self.section_name][name] except KeyError: - return '%s: Not Found.' % name + return "%s: Not Found." % name self.config.write() - return '%s: Deleted' % name + return "%s: Deleted" % name diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 01f3c7b..87b5366 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -34,6 +34,7 @@ def set_timing_enabled(val): global TIMING_ENABLED TIMING_ENABLED = val + @export def set_pager_enabled(val): global PAGER_ENABLED @@ -44,33 +45,35 @@ def set_pager_enabled(val): def is_pager_enabled(): return PAGER_ENABLED + @export -@special_command('pager', '\\P [command]', - 'Set PAGER. Print the query results via PAGER.', - arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) +@special_command( + "pager", "\\P [command]", "Set PAGER. Print the query results via PAGER.", arg_type=PARSED_QUERY, aliases=("\\P",), case_sensitive=True +) def set_pager(arg, **_): if arg: - os.environ['PAGER'] = arg - msg = 'PAGER set to %s.' % arg + os.environ["PAGER"] = arg + msg = "PAGER set to %s." % arg set_pager_enabled(True) else: - if 'PAGER' in os.environ: - msg = 'PAGER set to %s.' % os.environ['PAGER'] + if "PAGER" in os.environ: + msg = "PAGER set to %s." % os.environ["PAGER"] else: # This uses click's default per echo_via_pager. - msg = 'Pager enabled.' + msg = "Pager enabled." set_pager_enabled(True) return [(None, None, None, msg)] + @export -@special_command('nopager', '\\n', 'Disable pager, print to stdout.', - arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True) +@special_command("nopager", "\\n", "Disable pager, print to stdout.", arg_type=NO_QUERY, aliases=("\\n",), case_sensitive=True) def disable_pager(): set_pager_enabled(False) - return [(None, None, None, 'Pager disabled.')] + return [(None, None, None, "Pager disabled.")] -@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True) + +@special_command("\\timing", "\\t", "Toggle timing of commands.", arg_type=NO_QUERY, aliases=("\\t",), case_sensitive=True) def toggle_timing(): global TIMING_ENABLED TIMING_ENABLED = not TIMING_ENABLED @@ -78,21 +81,26 @@ def toggle_timing(): message += "on." if TIMING_ENABLED else "off." return [(None, None, None, message)] + @export def is_timing_enabled(): return TIMING_ENABLED + @export def set_expanded_output(val): global use_expanded_output use_expanded_output = val + @export def is_expanded_output(): return use_expanded_output + _logger = logging.getLogger(__name__) + @export def editor_command(command): """ @@ -101,12 +109,13 @@ def editor_command(command): """ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check # for both conditions. - return command.strip().endswith('\\e') or command.strip().startswith('\\e') + return command.strip().endswith("\\e") or command.strip().startswith("\\e") + @export def get_filename(sql): - if sql.strip().startswith('\\e'): - command, _, filename = sql.partition(' ') + if sql.strip().startswith("\\e"): + command, _, filename = sql.partition(" ") return filename.strip() or None @@ -118,9 +127,9 @@ def get_editor_query(sql): # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". - pattern = re.compile(r'(^\\e|\\e$)') + pattern = re.compile(r"(^\\e|\\e$)") while pattern.search(sql): - sql = pattern.sub('', sql) + sql = pattern.sub("", sql) return sql @@ -135,25 +144,24 @@ def open_external_editor(filename=None, sql=None): """ message = None - filename = filename.strip().split(' ', 1)[0] if filename else None + filename = filename.strip().split(" ", 1)[0] if filename else None - sql = sql or '' - MARKER = '# Type your query above this line.\n' + sql = sql or "" + MARKER = "# Type your query above this line.\n" # Populate the editor buffer with the partial sql (if available) and a # placeholder comment. - query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), - filename=filename, extension='.sql') + query = click.edit("{sql}\n\n{marker}".format(sql=sql, marker=MARKER), filename=filename, extension=".sql") if filename: try: with open(filename) as f: query = f.read() except IOError: - message = 'Error reading file: %s.' % filename + message = "Error reading file: %s." % filename if query is not None: - query = query.split(MARKER, 1)[0].rstrip('\n') + query = query.split(MARKER, 1)[0].rstrip("\n") else: # Don't return None for the caller to deal with. # Empty string is ok. @@ -171,7 +179,7 @@ def clip_command(command): """ # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check # for both conditions. - return command.strip().endswith('\\clip') or command.strip().startswith('\\clip') + return command.strip().endswith("\\clip") or command.strip().startswith("\\clip") @export @@ -181,9 +189,9 @@ def get_clip_query(sql): # The reason we can't simply do .strip('\clip') is that it strips characters, # not a substring. So it'll strip "c" in the end of the sql also! - pattern = re.compile(r'(^\\clip|\\clip$)') + pattern = re.compile(r"(^\\clip|\\clip$)") while pattern.search(sql): - sql = pattern.sub('', sql) + sql = pattern.sub("", sql) return sql @@ -192,26 +200,26 @@ def get_clip_query(sql): def copy_query_to_clipboard(sql=None): """Send query to the clipboard.""" - sql = sql or '' + sql = sql or "" message = None try: - pyperclip.copy(u'{sql}'.format(sql=sql)) + pyperclip.copy("{sql}".format(sql=sql)) except RuntimeError as e: - message = 'Error clipping query: %s.' % e.strerror + message = "Error clipping query: %s." % e.strerror return message -@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) +@special_command("\\f", "\\f [name [args..]]", "List or execute favorite queries.", arg_type=PARSED_QUERY, case_sensitive=True) def execute_favorite_query(cur, arg, **_): """Returns (title, rows, headers, status)""" - if arg == '': + if arg == "": for result in list_favorite_queries(): yield result """Parse out favorite name and optional substitution parameters""" - name, _, arg_str = arg.partition(' ') + name, _, arg_str = arg.partition(" ") args = shlex.split(arg_str) query = FavoriteQueries.instance.get(name) @@ -224,8 +232,8 @@ def execute_favorite_query(cur, arg, **_): yield (None, None, None, arg_error) else: for sql in sqlparse.split(query): - sql = sql.rstrip(';') - title = '> %s' % (sql) + sql = sql.rstrip(";") + title = "> %s" % (sql) cur.execute(sql) if cur.description: headers = [x[0] for x in cur.description] @@ -233,60 +241,60 @@ def execute_favorite_query(cur, arg, **_): else: yield (title, None, None, None) + def list_favorite_queries(): """List of all favorite queries. Returns (title, rows, headers, status)""" headers = ["Name", "Query"] - rows = [(r, FavoriteQueries.instance.get(r)) - for r in FavoriteQueries.instance.list()] + rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()] if not rows: - status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage + status = "\nNo favorite queries found." + FavoriteQueries.instance.usage else: - status = '' - return [('', rows, headers, status)] + status = "" + return [("", rows, headers, status)] def subst_favorite_query_args(query, args): """replace positional parameters ($1...$N) in query.""" for idx, val in enumerate(args): - subst_var = '$' + str(idx + 1) + subst_var = "$" + str(idx + 1) if subst_var not in query: - return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + return [None, "query does not have substitution parameter " + subst_var + ":\n " + query] query = query.replace(subst_var, val) - match = re.search(r'\$\d+', query) + match = re.search(r"\$\d+", query) if match: - return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + return [None, "missing substitution for " + match.group(0) + " in query:\n " + query] return [query, None] -@special_command('\\fs', '\\fs name query', 'Save a favorite query.') + +@special_command("\\fs", "\\fs name query", "Save a favorite query.") def save_favorite_query(arg, **_): """Save a new favorite query. Returns (title, rows, headers, status)""" - usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage + usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] - name, _, query = arg.partition(' ') + name, _, query = arg.partition(" ") # If either name or query is missing then print the usage and complain. if (not name) or (not query): - return [(None, None, None, - usage + 'Err: Both name and query are required.')] + return [(None, None, None, usage + "Err: Both name and query are required.")] FavoriteQueries.instance.save(name, query) return [(None, None, None, "Saved.")] -@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') +@special_command("\\fd", "\\fd [name]", "Delete a favorite query.") def delete_favorite_query(arg, **_): """Delete an existing favorite query.""" - usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage + usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage if not arg: return [(None, None, None, usage)] @@ -295,8 +303,7 @@ def delete_favorite_query(arg, **_): return [(None, None, None, status)] -@special_command('system', 'system [command]', - 'Execute a system shell commmand.') +@special_command("system", "system [command]", "Execute a system shell commmand.") def execute_system_command(arg, **_): """Execute a system shell command.""" usage = "Syntax: system [command].\n" @@ -306,13 +313,13 @@ def execute_system_command(arg, **_): try: command = arg.strip() - if command.startswith('cd'): + if command.startswith("cd"): ok, error_message = handle_cd_command(arg) if not ok: return [(None, None, None, error_message)] - return [(None, None, None, '')] + return [(None, None, None, "")] - args = arg.split(' ') + args = arg.split(" ") process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) output, error = process.communicate() response = output if not error else error @@ -324,25 +331,24 @@ def execute_system_command(arg, **_): return [(None, None, None, response)] except OSError as e: - return [(None, None, None, 'OSError: %s' % e.strerror)] + return [(None, None, None, "OSError: %s" % e.strerror)] def parseargfile(arg): - if arg.startswith('-o '): + if arg.startswith("-o "): mode = "w" filename = arg[3:] else: - mode = 'a' + mode = "a" filename = arg if not filename: - raise TypeError('You must provide a filename.') + raise TypeError("You must provide a filename.") - return {'file': os.path.expanduser(filename), 'mode': mode} + return {"file": os.path.expanduser(filename), "mode": mode} -@special_command('tee', 'tee [-o] filename', - 'Append all results to an output file (overwrite using -o).') +@special_command("tee", "tee [-o] filename", "Append all results to an output file (overwrite using -o).") def set_tee(arg, **_): global tee_file @@ -353,6 +359,7 @@ def set_tee(arg, **_): return [(None, None, None, "")] + @export def close_tee(): global tee_file @@ -361,31 +368,29 @@ def close_tee(): tee_file = None -@special_command('notee', 'notee', 'Stop writing results to an output file.') +@special_command("notee", "notee", "Stop writing results to an output file.") def no_tee(arg, **_): close_tee() return [(None, None, None, "")] + @export def write_tee(output): global tee_file if tee_file: click.echo(output, file=tee_file, nl=False) - click.echo(u'\n', file=tee_file, nl=False) + click.echo("\n", file=tee_file, nl=False) tee_file.flush() -@special_command('\\once', '\\o [-o] filename', - 'Append next result to an output file (overwrite using -o).', - aliases=('\\o', )) +@special_command("\\once", "\\o [-o] filename", "Append next result to an output file (overwrite using -o).", aliases=("\\o",)) def set_once(arg, **_): global once_file, written_to_once_file try: once_file = open(**parseargfile(arg)) except (IOError, OSError) as e: - raise OSError("Cannot write to file '{}': {}".format( - e.filename, e.strerror)) + raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) written_to_once_file = False return [(None, None, None, "")] @@ -396,7 +401,7 @@ def write_once(output): global once_file, written_to_once_file if output and once_file: click.echo(output, file=once_file, nl=False) - click.echo(u"\n", file=once_file, nl=False) + click.echo("\n", file=once_file, nl=False) once_file.flush() written_to_once_file = True @@ -410,22 +415,22 @@ def unset_once_if_written(): once_file = None -@special_command('\\pipe_once', '\\| command', - 'Send next result to a subprocess.', - aliases=('\\|', )) +@special_command("\\pipe_once", "\\| command", "Send next result to a subprocess.", aliases=("\\|",)) def set_pipe_once(arg, **_): global pipe_once_process, written_to_pipe_once_process pipe_once_cmd = shlex.split(arg) if len(pipe_once_cmd) == 0: raise OSError("pipe_once requires a command") written_to_pipe_once_process = False - pipe_once_process = subprocess.Popen(pipe_once_cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - bufsize=1, - encoding='UTF-8', - universal_newlines=True) + pipe_once_process = subprocess.Popen( + pipe_once_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + bufsize=1, + encoding="UTF-8", + universal_newlines=True, + ) return [(None, None, None, "")] @@ -435,11 +440,10 @@ def write_pipe_once(output): if output and pipe_once_process: try: click.echo(output, file=pipe_once_process.stdin, nl=False) - click.echo(u"\n", file=pipe_once_process.stdin, nl=False) + click.echo("\n", file=pipe_once_process.stdin, nl=False) except (IOError, OSError) as e: pipe_once_process.terminate() - raise OSError( - "Failed writing to pipe_once subprocess: {}".format(e.strerror)) + raise OSError("Failed writing to pipe_once subprocess: {}".format(e.strerror)) written_to_pipe_once_process = True @@ -450,18 +454,14 @@ def unset_pipe_once_if_written(): if written_to_pipe_once_process: (stdout_data, stderr_data) = pipe_once_process.communicate() if len(stdout_data) > 0: - print(stdout_data.rstrip(u"\n")) + print(stdout_data.rstrip("\n")) if len(stderr_data) > 0: - print(stderr_data.rstrip(u"\n")) + print(stderr_data.rstrip("\n")) pipe_once_process = None written_to_pipe_once_process = False -@special_command( - 'watch', - 'watch [seconds] [-c] query', - 'Executes the query every [seconds] seconds (by default 5).' -) +@special_command("watch", "watch [seconds] [-c] query", "Executes the query every [seconds] seconds (by default 5).") def watch_query(arg, **kwargs): usage = """Syntax: watch [seconds] [-c] query. * seconds: The interval at the query will be repeated, in seconds. @@ -480,27 +480,24 @@ def watch_query(arg, **kwargs): # Oops, we parsed all the arguments without finding a statement yield (None, None, None, usage) return - (current_arg, _, arg) = arg.partition(' ') + (current_arg, _, arg) = arg.partition(" ") try: seconds = float(current_arg) continue except ValueError: pass - if current_arg == '-c': + if current_arg == "-c": clear_screen = True continue - statement = '{0!s} {1!s}'.format(current_arg, arg) + statement = "{0!s} {1!s}".format(current_arg, arg) destructive_prompt = confirm_destructive_query(statement) if destructive_prompt is False: click.secho("Wise choice!") return elif destructive_prompt is True: click.secho("Your call!") - cur = kwargs['cur'] - sql_list = [ - (sql.rstrip(';'), "> {0!s}".format(sql)) - for sql in sqlparse.split(statement) - ] + cur = kwargs["cur"] + sql_list = [(sql.rstrip(";"), "> {0!s}".format(sql)) for sql in sqlparse.split(statement)] old_pager_enabled = is_pager_enabled() while True: if clear_screen: @@ -509,7 +506,7 @@ def watch_query(arg, **kwargs): # Somewhere in the code the pager its activated after every yield, # so we disable it in every iteration set_pager_enabled(False) - for (sql, title) in sql_list: + for sql, title in sql_list: cur.execute(sql) if cur.description: headers = [x[0] for x in cur.description] @@ -527,7 +524,7 @@ def watch_query(arg, **kwargs): @export -@special_command('delimiter', None, 'Change SQL delimiter.') +@special_command("delimiter", None, "Change SQL delimiter.") def set_delimiter(arg, **_): return delimiter_command.set(arg) diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index ab04f30..4d1c941 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -9,43 +9,43 @@ NO_QUERY = 0 PARSED_QUERY = 1 RAW_QUERY = 2 -SpecialCommand = namedtuple('SpecialCommand', - ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden', - 'case_sensitive']) +SpecialCommand = namedtuple("SpecialCommand", ["handler", "command", "shortcut", "description", "arg_type", "hidden", "case_sensitive"]) COMMANDS = {} + @export class CommandNotFound(Exception): pass + @export def parse_special_command(sql): - command, _, arg = sql.partition(' ') - verbose = '+' in command - command = command.strip().replace('+', '') + command, _, arg = sql.partition(" ") + verbose = "+" in command + command = command.strip().replace("+", "") return (command, verbose, arg.strip()) + @export -def special_command(command, shortcut, description, arg_type=PARSED_QUERY, - hidden=False, case_sensitive=False, aliases=()): +def special_command(command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()): def wrapper(wrapped): - register_special_command(wrapped, command, shortcut, description, - arg_type, hidden, case_sensitive, aliases) + register_special_command(wrapped, command, shortcut, description, arg_type, hidden, case_sensitive, aliases) return wrapped + return wrapper + @export -def register_special_command(handler, command, shortcut, description, - arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()): +def register_special_command( + handler, command, shortcut, description, arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=() +): cmd = command.lower() if not case_sensitive else command - COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, - arg_type, hidden, case_sensitive) + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, hidden, case_sensitive) for alias in aliases: cmd = alias.lower() if not case_sensitive else alias - COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, - arg_type, case_sensitive=case_sensitive, - hidden=True) + COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, arg_type, case_sensitive=case_sensitive, hidden=True) + @export def execute(cur, sql): @@ -62,11 +62,11 @@ def execute(cur, sql): except KeyError: special_cmd = COMMANDS[command.lower()] if special_cmd.case_sensitive: - raise CommandNotFound('Command not found: %s' % command) + raise CommandNotFound("Command not found: %s" % command) # "help <SQL KEYWORD> is a special case. We want built-in help, not # mycli help here. - if command == 'help' and arg: + if command == "help" and arg: return show_keyword_help(cur=cur, arg=arg) if special_cmd.arg_type == NO_QUERY: @@ -76,9 +76,10 @@ def execute(cur, sql): elif special_cmd.arg_type == RAW_QUERY: return special_cmd.handler(cur=cur, query=sql) -@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?')) + +@special_command("help", "\\?", "Show this help.", arg_type=NO_QUERY, aliases=("\\?", "?")) def show_help(): # All the parameters are ignored. - headers = ['Command', 'Shortcut', 'Description'] + headers = ["Command", "Shortcut", "Description"] result = [] for _, value in sorted(COMMANDS.items()): @@ -86,6 +87,7 @@ def show_help(): # All the parameters are ignored. result.append((value.command, value.shortcut, value.description)) return [(None, result, headers, None)] + def show_keyword_help(cur, arg): """ Call the built-in "show <command>", to display help for an SQL keyword. @@ -99,22 +101,19 @@ def show_keyword_help(cur, arg): cur.execute(query) if cur.description and cur.rowcount > 0: headers = [x[0] for x in cur.description] - return [(None, cur, headers, '')] + return [(None, cur, headers, "")] else: - return [(None, None, None, 'No help found for {0}.'.format(keyword))] + return [(None, None, None, "No help found for {0}.".format(keyword))] -@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', )) -@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY) +@special_command("exit", "\\q", "Exit.", arg_type=NO_QUERY, aliases=("\\q",)) +@special_command("quit", "\\q", "Quit.", arg_type=NO_QUERY) def quit(*_args): raise EOFError -@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', - arg_type=NO_QUERY, case_sensitive=True) -@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.', - arg_type=NO_QUERY, case_sensitive=True) -@special_command('\\G', '\\G', 'Display current query results vertically.', - arg_type=NO_QUERY, case_sensitive=True) +@special_command("\\e", "\\e", "Edit command with editor (uses $EDITOR).", arg_type=NO_QUERY, case_sensitive=True) +@special_command("\\clip", "\\clip", "Copy query to the system clipboard.", arg_type=NO_QUERY, case_sensitive=True) +@special_command("\\G", "\\G", "Display current query results vertically.", arg_type=NO_QUERY, case_sensitive=True) def stub(): raise NotImplementedError diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py index ef96093..eed9306 100644 --- a/mycli/packages/special/utils.py +++ b/mycli/packages/special/utils.py @@ -1,20 +1,22 @@ import os import subprocess + def handle_cd_command(arg): """Handles a `cd` shell command by calling python's os.chdir.""" - CD_CMD = 'cd' - tokens = arg.split(CD_CMD + ' ') + CD_CMD = "cd" + tokens = arg.split(CD_CMD + " ") directory = tokens[-1] if len(tokens) > 1 else None if not directory: return False, "No folder name was provided." try: os.chdir(directory) - subprocess.call(['pwd']) + subprocess.call(["pwd"]) return True, None except OSError as e: return False, e.strerror + def format_uptime(uptime_in_seconds): """Format number of seconds into human-readable string. @@ -32,15 +34,15 @@ def format_uptime(uptime_in_seconds): uptime_values = [] - for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')): + for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")): if value == 0 and not uptime_values: # Don't include a value/unit if the unit isn't applicable to # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. continue - elif value == 1 and unit.endswith('s'): + elif value == 1 and unit.endswith("s"): # Remove the "s" if the unit is singular. unit = unit[:-1] - uptime_values.append('{0} {1}'.format(value, unit)) + uptime_values.append("{0} {1}".format(value, unit)) - uptime = ' '.join(uptime_values) + uptime = " ".join(uptime_values) return uptime diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index e6587bd..828a4b3 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -2,8 +2,12 @@ from mycli.packages.parseutils import extract_tables -supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', - 'sql-update-2', ) +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) preprocessors = () @@ -25,19 +29,18 @@ def adapter(data, headers, table_format=None, **kwargs): table_name = table[1] else: table_name = "`DUAL`" - if table_format == 'sql-insert': + if table_format == "sql-insert": h = "`, `".join(headers) yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) prefix = " " for d in data: - values = ", ".join(escape_for_sql_statement(v) - for i, v in enumerate(d)) + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) yield "{}({})".format(prefix, values) if prefix == " ": prefix = ", " yield ";" - if table_format.startswith('sql-update'): - s = table_format.split('-') + if table_format.startswith("sql-update"): + s = table_format.split("-") keys = 1 if len(s) > 2: keys = int(s[-1]) @@ -49,8 +52,7 @@ def adapter(data, headers, table_format=None, **kwargs): if prefix == " ": prefix = ", " f = "`{}` = {}" - where = (f.format(headers[i], escape_for_sql_statement( - d[i])) for i in range(keys)) + where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys)) yield "WHERE {};".format(" AND ".join(where)) @@ -58,5 +60,4 @@ def register_new_formatter(TabularOutputFormatter): global formatter formatter = TabularOutputFormatter for sql_format in supported_formats: - TabularOutputFormatter.register_new_formatter( - sql_format, adapter, preprocessors, {'table_format': sql_format}) + TabularOutputFormatter.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format}) diff --git a/mycli/packages/toolkit/fzf.py b/mycli/packages/toolkit/fzf.py index 36cb347..5aeebe3 100644 --- a/mycli/packages/toolkit/fzf.py +++ b/mycli/packages/toolkit/fzf.py @@ -29,7 +29,7 @@ def search_history(event: KeyPressEvent): formatted_history_items = [] original_history_items = [] for item, timestamp in history_items_with_timestamp: - formatted_item = item.replace('\n', ' ') + formatted_item = item.replace("\n", " ") timestamp = timestamp.split(".")[0] if "." in timestamp else timestamp formatted_history_items.append(f"{timestamp} {formatted_item}") original_history_items.append(item) diff --git a/mycli/packages/toolkit/history.py b/mycli/packages/toolkit/history.py index 75f4a5a..237317f 100644 --- a/mycli/packages/toolkit/history.py +++ b/mycli/packages/toolkit/history.py @@ -1,5 +1,5 @@ import os -from typing import Iterable, Union, List, Tuple +from typing import Union, List, Tuple from prompt_toolkit.history import FileHistory diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index b0eecea..1636289 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -14,191 +14,887 @@ _logger = logging.getLogger(__name__) class SQLCompleter(Completer): keywords = [ - 'SELECT', 'FROM', 'WHERE', 'UPDATE', 'DELETE FROM', 'GROUP BY', - 'JOIN', 'INSERT INTO', 'LIKE', 'LIMIT', 'ACCESS', 'ADD', 'ALL', - 'ALTER TABLE', 'AND', 'ANY', 'AS', 'ASC', 'AUTO_INCREMENT', - 'BEFORE', 'BEGIN', 'BETWEEN', 'BIGINT', 'BINARY', 'BY', 'CASE', - 'CHANGE MASTER TO', 'CHAR', 'CHARACTER SET', 'CHECK', 'COLLATE', - 'COLUMN', 'COMMENT', 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT', - 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT', - 'DESC', 'DESCRIBE', 'DROP', 'ELSE', 'END', 'ENGINE', 'ESCAPE', - 'EXISTS', 'FILE', 'FLOAT', 'FOR', 'FOREIGN KEY', 'FORMAT', 'FULL', - 'FUNCTION', 'GRANT', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', - 'INCREMENT', 'INDEX', 'INT', 'INTEGER', 'INTERVAL', 'INTO', 'IS', - 'KEY', 'LEFT', 'LEVEL', 'LOCK', 'LOGS', 'LONG', 'MASTER', - 'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER', 'OFFSET', - 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER', 'PASSWORD', - 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST', 'PURGE', - 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET', 'REVOKE', - 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT', 'SAVEPOINT', - 'SESSION', 'SET', 'SHARE', 'SHOW', 'SLAVE', 'SMALLINT', - 'START', 'STOP', 'TABLE', 'THEN', 'TINYINT', 'TO', 'TRANSACTION', - 'TRIGGER', 'TRUNCATE', 'UNION', 'UNIQUE', 'UNSIGNED', 'USE', - 'USER', 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WITH' - ] + "SELECT", + "FROM", + "WHERE", + "UPDATE", + "DELETE FROM", + "GROUP BY", + "JOIN", + "INSERT INTO", + "LIKE", + "LIMIT", + "ACCESS", + "ADD", + "ALL", + "ALTER TABLE", + "AND", + "ANY", + "AS", + "ASC", + "AUTO_INCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BIGINT", + "BINARY", + "BY", + "CASE", + "CHANGE MASTER TO", + "CHAR", + "CHARACTER SET", + "CHECK", + "COLLATE", + "COLUMN", + "COMMENT", + "COMMIT", + "CONSTRAINT", + "CREATE", + "CURRENT", + "CURRENT_TIMESTAMP", + "DATABASE", + "DATE", + "DECIMAL", + "DEFAULT", + "DESC", + "DESCRIBE", + "DROP", + "ELSE", + "END", + "ENGINE", + "ESCAPE", + "EXISTS", + "FILE", + "FLOAT", + "FOR", + "FOREIGN KEY", + "FORMAT", + "FULL", + "FUNCTION", + "GRANT", + "HAVING", + "HOST", + "IDENTIFIED", + "IN", + "INCREMENT", + "INDEX", + "INT", + "INTEGER", + "INTERVAL", + "INTO", + "IS", + "KEY", + "LEFT", + "LEVEL", + "LOCK", + "LOGS", + "LONG", + "MASTER", + "MEDIUMINT", + "MODE", + "MODIFY", + "NOT", + "NULL", + "NUMBER", + "OFFSET", + "ON", + "OPTION", + "OR", + "ORDER BY", + "OUTER", + "OWNER", + "PASSWORD", + "PORT", + "PRIMARY", + "PRIVILEGES", + "PROCESSLIST", + "PURGE", + "REFERENCES", + "REGEXP", + "RENAME", + "REPAIR", + "RESET", + "REVOKE", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "ROW_FORMAT", + "SAVEPOINT", + "SESSION", + "SET", + "SHARE", + "SHOW", + "SLAVE", + "SMALLINT", + "START", + "STOP", + "TABLE", + "THEN", + "TINYINT", + "TO", + "TRANSACTION", + "TRIGGER", + "TRUNCATE", + "UNION", + "UNIQUE", + "UNSIGNED", + "USE", + "USER", + "USING", + "VALUES", + "VARCHAR", + "VIEW", + "WHEN", + "WITH", + ] tidb_keywords = [ - "SELECT", "FROM", "WHERE", "DELETE FROM", "UPDATE", "GROUP BY", - "JOIN", "INSERT INTO", "LIKE", "LIMIT", "ACCOUNT", "ACTION", "ADD", - "ADDDATE", "ADMIN", "ADVISE", "AFTER", "AGAINST", "AGO", - "ALGORITHM", "ALL", "ALTER", "ALWAYS", "ANALYZE", "AND", "ANY", - "APPROX_COUNT_DISTINCT", "APPROX_PERCENTILE", "AS", "ASC", "ASCII", - "ATTRIBUTES", "AUTO_ID_CACHE", "AUTO_INCREMENT", "AUTO_RANDOM", - "AUTO_RANDOM_BASE", "AVG", "AVG_ROW_LENGTH", "BACKEND", "BACKUP", - "BACKUPS", "BATCH", "BEGIN", "BERNOULLI", "BETWEEN", "BIGINT", - "BINARY", "BINDING", "BINDINGS", "BINDING_CACHE", "BINLOG", "BIT", - "BIT_AND", "BIT_OR", "BIT_XOR", "BLOB", "BLOCK", "BOOL", "BOOLEAN", - "BOTH", "BOUND", "BRIEF", "BTREE", "BUCKETS", "BUILTINS", "BY", - "BYTE", "CACHE", "CALL", "CANCEL", "CAPTURE", "CARDINALITY", - "CASCADE", "CASCADED", "CASE", "CAST", "CAUSAL", "CHAIN", "CHANGE", - "CHAR", "CHARACTER", "CHARSET", "CHECK", "CHECKPOINT", "CHECKSUM", - "CIPHER", "CLEANUP", "CLIENT", "CLIENT_ERRORS_SUMMARY", - "CLUSTERED", "CMSKETCH", "COALESCE", "COLLATE", "COLLATION", - "COLUMN", "COLUMNS", "COLUMN_FORMAT", "COLUMN_STATS_USAGE", - "COMMENT", "COMMIT", "COMMITTED", "COMPACT", "COMPRESSED", - "COMPRESSION", "CONCURRENCY", "CONFIG", "CONNECTION", - "CONSISTENCY", "CONSISTENT", "CONSTRAINT", "CONSTRAINTS", - "CONTEXT", "CONVERT", "COPY", "CORRELATION", "CPU", "CREATE", - "CROSS", "CSV_BACKSLASH_ESCAPE", "CSV_DELIMITER", "CSV_HEADER", - "CSV_NOT_NULL", "CSV_NULL", "CSV_SEPARATOR", - "CSV_TRIM_LAST_SEPARATORS", "CUME_DIST", "CURRENT", "CURRENT_DATE", - "CURRENT_ROLE", "CURRENT_TIME", "CURRENT_TIMESTAMP", - "CURRENT_USER", "CURTIME", "CYCLE", "DATA", "DATABASE", - "DATABASES", "DATE", "DATETIME", "DATE_ADD", "DATE_SUB", "DAY", - "DAY_HOUR", "DAY_MICROSECOND", "DAY_MINUTE", "DAY_SECOND", "DDL", - "DEALLOCATE", "DECIMAL", "DEFAULT", "DEFINER", "DELAYED", - "DELAY_KEY_WRITE", "DENSE_RANK", "DEPENDENCY", "DEPTH", "DESC", - "DESCRIBE", "DIRECTORY", "DISABLE", "DISABLED", "DISCARD", "DISK", - "DISTINCT", "DISTINCTROW", "DIV", "DO", "DOT", "DOUBLE", "DRAINER", - "DROP", "DRY", "DUAL", "DUMP", "DUPLICATE", "DYNAMIC", "ELSE", - "ENABLE", "ENABLED", "ENCLOSED", "ENCRYPTION", "END", "ENFORCED", - "ENGINE", "ENGINES", "ENUM", "ERROR", "ERRORS", "ESCAPE", - "ESCAPED", "EVENT", "EVENTS", "EVOLVE", "EXACT", "EXCEPT", - "EXCHANGE", "EXCLUSIVE", "EXECUTE", "EXISTS", "EXPANSION", - "EXPIRE", "EXPLAIN", "EXPR_PUSHDOWN_BLACKLIST", "EXTENDED", - "EXTRACT", "FALSE", "FAST", "FAULTS", "FETCH", "FIELDS", "FILE", - "FIRST", "FIRST_VALUE", "FIXED", "FLASHBACK", "FLOAT", "FLUSH", - "FOLLOWER", "FOLLOWERS", "FOLLOWER_CONSTRAINTS", "FOLLOWING", - "FOR", "FORCE", "FOREIGN", "FORMAT", "FULL", "FULLTEXT", - "FUNCTION", "GENERAL", "GENERATED", "GET_FORMAT", "GLOBAL", - "GRANT", "GRANTS", "GROUPS", "GROUP_CONCAT", "HASH", "HAVING", - "HELP", "HIGH_PRIORITY", "HISTOGRAM", "HISTOGRAMS_IN_FLIGHT", - "HISTORY", "HOSTS", "HOUR", "HOUR_MICROSECOND", "HOUR_MINUTE", - "HOUR_SECOND", "IDENTIFIED", "IF", "IGNORE", "IMPORT", "IMPORTS", - "IN", "INCREMENT", "INCREMENTAL", "INDEX", "INDEXES", "INFILE", - "INNER", "INPLACE", "INSERT_METHOD", "INSTANCE", - "INSTANT", "INT", "INT1", "INT2", "INT3", "INT4", "INT8", - "INTEGER", "INTERNAL", "INTERSECT", "INTERVAL", "INTO", - "INVISIBLE", "INVOKER", "IO", "IPC", "IS", "ISOLATION", "ISSUER", - "JOB", "JOBS", "JSON", "JSON_ARRAYAGG", "JSON_OBJECTAGG", "KEY", - "KEYS", "KEY_BLOCK_SIZE", "KILL", "LABELS", "LAG", "LANGUAGE", - "LAST", "LASTVAL", "LAST_BACKUP", "LAST_VALUE", "LEAD", "LEADER", - "LEADER_CONSTRAINTS", "LEADING", "LEARNER", "LEARNERS", - "LEARNER_CONSTRAINTS", "LEFT", "LESS", "LEVEL", "LINEAR", "LINES", - "LIST", "LOAD", "LOCAL", "LOCALTIME", "LOCALTIMESTAMP", "LOCATION", - "LOCK", "LOCKED", "LOGS", "LONG", "LONGBLOB", "LONGTEXT", - "LOW_PRIORITY", "MASTER", "MATCH", "MAX", "MAXVALUE", - "MAX_CONNECTIONS_PER_HOUR", "MAX_IDXNUM", "MAX_MINUTES", - "MAX_QUERIES_PER_HOUR", "MAX_ROWS", "MAX_UPDATES_PER_HOUR", - "MAX_USER_CONNECTIONS", "MB", "MEDIUMBLOB", "MEDIUMINT", - "MEDIUMTEXT", "MEMORY", "MERGE", "MICROSECOND", "MIN", "MINUTE", - "MINUTE_MICROSECOND", "MINUTE_SECOND", "MINVALUE", "MIN_ROWS", - "MOD", "MODE", "MODIFY", "MONTH", "NAMES", "NATIONAL", "NATURAL", - "NCHAR", "NEVER", "NEXT", "NEXTVAL", "NEXT_ROW_ID", "NO", - "NOCACHE", "NOCYCLE", "NODEGROUP", "NODE_ID", "NODE_STATE", - "NOMAXVALUE", "NOMINVALUE", "NONCLUSTERED", "NONE", "NORMAL", - "NOT", "NOW", "NOWAIT", "NO_WRITE_TO_BINLOG", "NTH_VALUE", "NTILE", - "NULL", "NULLS", "NUMERIC", "NVARCHAR", "OF", "OFF", "OFFSET", - "ON", "ONLINE", "ONLY", "ON_DUPLICATE", "OPEN", "OPTIMISTIC", - "OPTIMIZE", "OPTION", "OPTIONAL", "OPTIONALLY", - "OPT_RULE_BLACKLIST", "OR", "ORDER", "OUTER", "OUTFILE", "OVER", - "PACK_KEYS", "PAGE", "PARSER", "PARTIAL", "PARTITION", - "PARTITIONING", "PARTITIONS", "PASSWORD", "PERCENT", - "PERCENT_RANK", "PER_DB", "PER_TABLE", "PESSIMISTIC", "PLACEMENT", - "PLAN", "PLAN_CACHE", "PLUGINS", "POLICY", "POSITION", "PRECEDING", - "PRECISION", "PREDICATE", "PREPARE", "PRESERVE", - "PRE_SPLIT_REGIONS", "PRIMARY", "PRIMARY_REGION", "PRIVILEGES", - "PROCEDURE", "PROCESS", "PROCESSLIST", "PROFILE", "PROFILES", - "PROXY", "PUMP", "PURGE", "QUARTER", "QUERIES", "QUERY", "QUICK", - "RANGE", "RANK", "RATE_LIMIT", "READ", "REAL", "REBUILD", "RECENT", - "RECOVER", "RECURSIVE", "REDUNDANT", "REFERENCES", "REGEXP", - "REGION", "REGIONS", "RELEASE", "RELOAD", "REMOVE", "RENAME", - "REORGANIZE", "REPAIR", "REPEAT", "REPEATABLE", "REPLACE", - "REPLAYER", "REPLICA", "REPLICAS", "REPLICATION", "REQUIRE", - "REQUIRED", "RESET", "RESPECT", "RESTART", "RESTORE", "RESTORES", - "RESTRICT", "RESUME", "REVERSE", "REVOKE", "RIGHT", "RLIKE", - "ROLE", "ROLLBACK", "ROUTINE", "ROW", "ROWS", "ROW_COUNT", - "ROW_FORMAT", "ROW_NUMBER", "RTREE", "RUN", "RUNNING", "S3", - "SAMPLERATE", "SAMPLES", "SAN", "SAVEPOINT", "SCHEDULE", "SECOND", - "SECONDARY_ENGINE", "SECONDARY_LOAD", "SECONDARY_UNLOAD", - "SECOND_MICROSECOND", "SECURITY", "SEND_CREDENTIALS_TO_TIKV", - "SEPARATOR", "SEQUENCE", "SERIAL", "SERIALIZABLE", "SESSION", - "SESSION_STATES", "SET", "SETVAL", "SHARD_ROW_ID_BITS", "SHARE", - "SHARED", "SHOW", "SHUTDOWN", "SIGNED", "SIMPLE", "SKIP", - "SKIP_SCHEMA_FILES", "SLAVE", "SLOW", "SMALLINT", "SNAPSHOT", - "SOME", "SOURCE", "SPATIAL", "SPLIT", "SQL", "SQL_BIG_RESULT", - "SQL_BUFFER_RESULT", "SQL_CACHE", "SQL_CALC_FOUND_ROWS", - "SQL_NO_CACHE", "SQL_SMALL_RESULT", "SQL_TSI_DAY", "SQL_TSI_HOUR", - "SQL_TSI_MINUTE", "SQL_TSI_MONTH", "SQL_TSI_QUARTER", - "SQL_TSI_SECOND", "SQL_TSI_WEEK", "SQL_TSI_YEAR", "SSL", - "STALENESS", "START", "STARTING", "STATISTICS", "STATS", - "STATS_AUTO_RECALC", "STATS_BUCKETS", "STATS_COL_CHOICE", - "STATS_COL_LIST", "STATS_EXTENDED", "STATS_HEALTHY", - "STATS_HISTOGRAMS", "STATS_META", "STATS_OPTIONS", - "STATS_PERSISTENT", "STATS_SAMPLE_PAGES", "STATS_SAMPLE_RATE", - "STATS_TOPN", "STATUS", "STD", "STDDEV", "STDDEV_POP", - "STDDEV_SAMP", "STOP", "STORAGE", "STORED", "STRAIGHT_JOIN", - "STRICT", "STRICT_FORMAT", "STRONG", "SUBDATE", "SUBJECT", - "SUBPARTITION", "SUBPARTITIONS", "SUBSTRING", "SUM", "SUPER", - "SWAPS", "SWITCHES", "SYSTEM", "SYSTEM_TIME", "TABLE", "TABLES", - "TABLESAMPLE", "TABLESPACE", "TABLE_CHECKSUM", "TARGET", - "TELEMETRY", "TELEMETRY_ID", "TEMPORARY", "TEMPTABLE", - "TERMINATED", "TEXT", "THAN", "THEN", "TIDB", "TIFLASH", - "TIKV_IMPORTER", "TIME", "TIMESTAMP", "TIMESTAMPADD", - "TIMESTAMPDIFF", "TINYBLOB", "TINYINT", "TINYTEXT", "TLS", "TO", - "TOKUDB_DEFAULT", "TOKUDB_FAST", "TOKUDB_LZMA", "TOKUDB_QUICKLZ", - "TOKUDB_SMALL", "TOKUDB_SNAPPY", "TOKUDB_UNCOMPRESSED", - "TOKUDB_ZLIB", "TOP", "TOPN", "TRACE", "TRADITIONAL", "TRAILING", - "TRANSACTION", "TRIGGER", "TRIGGERS", "TRIM", "TRUE", - "TRUE_CARD_COST", "TRUNCATE", "TYPE", "UNBOUNDED", "UNCOMMITTED", - "UNDEFINED", "UNICODE", "UNION", "UNIQUE", "UNKNOWN", "UNLOCK", - "UNSIGNED", "USAGE", "USE", "USER", "USING", "UTC_DATE", - "UTC_TIME", "UTC_TIMESTAMP", "VALIDATION", "VALUE", "VALUES", - "VARBINARY", "VARCHAR", "VARCHARACTER", "VARIABLES", "VARIANCE", - "VARYING", "VAR_POP", "VAR_SAMP", "VERBOSE", "VIEW", "VIRTUAL", - "VISIBLE", "VOTER", "VOTERS", "VOTER_CONSTRAINTS", "WAIT", - "WARNINGS", "WEEK", "WEIGHT_STRING", "WHEN", "WIDTH", "WINDOW", - "WITH", "WITHOUT", "WRITE", "X509", "XOR", "YEAR", "YEAR_MONTH", - "ZEROFILL" - ] - - functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT', - 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID', - 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', - 'UNIX_TIMESTAMP' - ] + "SELECT", + "FROM", + "WHERE", + "DELETE FROM", + "UPDATE", + "GROUP BY", + "JOIN", + "INSERT INTO", + "LIKE", + "LIMIT", + "ACCOUNT", + "ACTION", + "ADD", + "ADDDATE", + "ADMIN", + "ADVISE", + "AFTER", + "AGAINST", + "AGO", + "ALGORITHM", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "ANY", + "APPROX_COUNT_DISTINCT", + "APPROX_PERCENTILE", + "AS", + "ASC", + "ASCII", + "ATTRIBUTES", + "AUTO_ID_CACHE", + "AUTO_INCREMENT", + "AUTO_RANDOM", + "AUTO_RANDOM_BASE", + "AVG", + "AVG_ROW_LENGTH", + "BACKEND", + "BACKUP", + "BACKUPS", + "BATCH", + "BEGIN", + "BERNOULLI", + "BETWEEN", + "BIGINT", + "BINARY", + "BINDING", + "BINDINGS", + "BINDING_CACHE", + "BINLOG", + "BIT", + "BIT_AND", + "BIT_OR", + "BIT_XOR", + "BLOB", + "BLOCK", + "BOOL", + "BOOLEAN", + "BOTH", + "BOUND", + "BRIEF", + "BTREE", + "BUCKETS", + "BUILTINS", + "BY", + "BYTE", + "CACHE", + "CALL", + "CANCEL", + "CAPTURE", + "CARDINALITY", + "CASCADE", + "CASCADED", + "CASE", + "CAST", + "CAUSAL", + "CHAIN", + "CHANGE", + "CHAR", + "CHARACTER", + "CHARSET", + "CHECK", + "CHECKPOINT", + "CHECKSUM", + "CIPHER", + "CLEANUP", + "CLIENT", + "CLIENT_ERRORS_SUMMARY", + "CLUSTERED", + "CMSKETCH", + "COALESCE", + "COLLATE", + "COLLATION", + "COLUMN", + "COLUMNS", + "COLUMN_FORMAT", + "COLUMN_STATS_USAGE", + "COMMENT", + "COMMIT", + "COMMITTED", + "COMPACT", + "COMPRESSED", + "COMPRESSION", + "CONCURRENCY", + "CONFIG", + "CONNECTION", + "CONSISTENCY", + "CONSISTENT", + "CONSTRAINT", + "CONSTRAINTS", + "CONTEXT", + "CONVERT", + "COPY", + "CORRELATION", + "CPU", + "CREATE", + "CROSS", + "CSV_BACKSLASH_ESCAPE", + "CSV_DELIMITER", + "CSV_HEADER", + "CSV_NOT_NULL", + "CSV_NULL", + "CSV_SEPARATOR", + "CSV_TRIM_LAST_SEPARATORS", + "CUME_DIST", + "CURRENT", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURTIME", + "CYCLE", + "DATA", + "DATABASE", + "DATABASES", + "DATE", + "DATETIME", + "DATE_ADD", + "DATE_SUB", + "DAY", + "DAY_HOUR", + "DAY_MICROSECOND", + "DAY_MINUTE", + "DAY_SECOND", + "DDL", + "DEALLOCATE", + "DECIMAL", + "DEFAULT", + "DEFINER", + "DELAYED", + "DELAY_KEY_WRITE", + "DENSE_RANK", + "DEPENDENCY", + "DEPTH", + "DESC", + "DESCRIBE", + "DIRECTORY", + "DISABLE", + "DISABLED", + "DISCARD", + "DISK", + "DISTINCT", + "DISTINCTROW", + "DIV", + "DO", + "DOT", + "DOUBLE", + "DRAINER", + "DROP", + "DRY", + "DUAL", + "DUMP", + "DUPLICATE", + "DYNAMIC", + "ELSE", + "ENABLE", + "ENABLED", + "ENCLOSED", + "ENCRYPTION", + "END", + "ENFORCED", + "ENGINE", + "ENGINES", + "ENUM", + "ERROR", + "ERRORS", + "ESCAPE", + "ESCAPED", + "EVENT", + "EVENTS", + "EVOLVE", + "EXACT", + "EXCEPT", + "EXCHANGE", + "EXCLUSIVE", + "EXECUTE", + "EXISTS", + "EXPANSION", + "EXPIRE", + "EXPLAIN", + "EXPR_PUSHDOWN_BLACKLIST", + "EXTENDED", + "EXTRACT", + "FALSE", + "FAST", + "FAULTS", + "FETCH", + "FIELDS", + "FILE", + "FIRST", + "FIRST_VALUE", + "FIXED", + "FLASHBACK", + "FLOAT", + "FLUSH", + "FOLLOWER", + "FOLLOWERS", + "FOLLOWER_CONSTRAINTS", + "FOLLOWING", + "FOR", + "FORCE", + "FOREIGN", + "FORMAT", + "FULL", + "FULLTEXT", + "FUNCTION", + "GENERAL", + "GENERATED", + "GET_FORMAT", + "GLOBAL", + "GRANT", + "GRANTS", + "GROUPS", + "GROUP_CONCAT", + "HASH", + "HAVING", + "HELP", + "HIGH_PRIORITY", + "HISTOGRAM", + "HISTOGRAMS_IN_FLIGHT", + "HISTORY", + "HOSTS", + "HOUR", + "HOUR_MICROSECOND", + "HOUR_MINUTE", + "HOUR_SECOND", + "IDENTIFIED", + "IF", + "IGNORE", + "IMPORT", + "IMPORTS", + "IN", + "INCREMENT", + "INCREMENTAL", + "INDEX", + "INDEXES", + "INFILE", + "INNER", + "INPLACE", + "INSERT_METHOD", + "INSTANCE", + "INSTANT", + "INT", + "INT1", + "INT2", + "INT3", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERSECT", + "INTERVAL", + "INTO", + "INVISIBLE", + "INVOKER", + "IO", + "IPC", + "IS", + "ISOLATION", + "ISSUER", + "JOB", + "JOBS", + "JSON", + "JSON_ARRAYAGG", + "JSON_OBJECTAGG", + "KEY", + "KEYS", + "KEY_BLOCK_SIZE", + "KILL", + "LABELS", + "LAG", + "LANGUAGE", + "LAST", + "LASTVAL", + "LAST_BACKUP", + "LAST_VALUE", + "LEAD", + "LEADER", + "LEADER_CONSTRAINTS", + "LEADING", + "LEARNER", + "LEARNERS", + "LEARNER_CONSTRAINTS", + "LEFT", + "LESS", + "LEVEL", + "LINEAR", + "LINES", + "LIST", + "LOAD", + "LOCAL", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOCATION", + "LOCK", + "LOCKED", + "LOGS", + "LONG", + "LONGBLOB", + "LONGTEXT", + "LOW_PRIORITY", + "MASTER", + "MATCH", + "MAX", + "MAXVALUE", + "MAX_CONNECTIONS_PER_HOUR", + "MAX_IDXNUM", + "MAX_MINUTES", + "MAX_QUERIES_PER_HOUR", + "MAX_ROWS", + "MAX_UPDATES_PER_HOUR", + "MAX_USER_CONNECTIONS", + "MB", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "MEMORY", + "MERGE", + "MICROSECOND", + "MIN", + "MINUTE", + "MINUTE_MICROSECOND", + "MINUTE_SECOND", + "MINVALUE", + "MIN_ROWS", + "MOD", + "MODE", + "MODIFY", + "MONTH", + "NAMES", + "NATIONAL", + "NATURAL", + "NCHAR", + "NEVER", + "NEXT", + "NEXTVAL", + "NEXT_ROW_ID", + "NO", + "NOCACHE", + "NOCYCLE", + "NODEGROUP", + "NODE_ID", + "NODE_STATE", + "NOMAXVALUE", + "NOMINVALUE", + "NONCLUSTERED", + "NONE", + "NORMAL", + "NOT", + "NOW", + "NOWAIT", + "NO_WRITE_TO_BINLOG", + "NTH_VALUE", + "NTILE", + "NULL", + "NULLS", + "NUMERIC", + "NVARCHAR", + "OF", + "OFF", + "OFFSET", + "ON", + "ONLINE", + "ONLY", + "ON_DUPLICATE", + "OPEN", + "OPTIMISTIC", + "OPTIMIZE", + "OPTION", + "OPTIONAL", + "OPTIONALLY", + "OPT_RULE_BLACKLIST", + "OR", + "ORDER", + "OUTER", + "OUTFILE", + "OVER", + "PACK_KEYS", + "PAGE", + "PARSER", + "PARTIAL", + "PARTITION", + "PARTITIONING", + "PARTITIONS", + "PASSWORD", + "PERCENT", + "PERCENT_RANK", + "PER_DB", + "PER_TABLE", + "PESSIMISTIC", + "PLACEMENT", + "PLAN", + "PLAN_CACHE", + "PLUGINS", + "POLICY", + "POSITION", + "PRECEDING", + "PRECISION", + "PREDICATE", + "PREPARE", + "PRESERVE", + "PRE_SPLIT_REGIONS", + "PRIMARY", + "PRIMARY_REGION", + "PRIVILEGES", + "PROCEDURE", + "PROCESS", + "PROCESSLIST", + "PROFILE", + "PROFILES", + "PROXY", + "PUMP", + "PURGE", + "QUARTER", + "QUERIES", + "QUERY", + "QUICK", + "RANGE", + "RANK", + "RATE_LIMIT", + "READ", + "REAL", + "REBUILD", + "RECENT", + "RECOVER", + "RECURSIVE", + "REDUNDANT", + "REFERENCES", + "REGEXP", + "REGION", + "REGIONS", + "RELEASE", + "RELOAD", + "REMOVE", + "RENAME", + "REORGANIZE", + "REPAIR", + "REPEAT", + "REPEATABLE", + "REPLACE", + "REPLAYER", + "REPLICA", + "REPLICAS", + "REPLICATION", + "REQUIRE", + "REQUIRED", + "RESET", + "RESPECT", + "RESTART", + "RESTORE", + "RESTORES", + "RESTRICT", + "RESUME", + "REVERSE", + "REVOKE", + "RIGHT", + "RLIKE", + "ROLE", + "ROLLBACK", + "ROUTINE", + "ROW", + "ROWS", + "ROW_COUNT", + "ROW_FORMAT", + "ROW_NUMBER", + "RTREE", + "RUN", + "RUNNING", + "S3", + "SAMPLERATE", + "SAMPLES", + "SAN", + "SAVEPOINT", + "SCHEDULE", + "SECOND", + "SECONDARY_ENGINE", + "SECONDARY_LOAD", + "SECONDARY_UNLOAD", + "SECOND_MICROSECOND", + "SECURITY", + "SEND_CREDENTIALS_TO_TIKV", + "SEPARATOR", + "SEQUENCE", + "SERIAL", + "SERIALIZABLE", + "SESSION", + "SESSION_STATES", + "SET", + "SETVAL", + "SHARD_ROW_ID_BITS", + "SHARE", + "SHARED", + "SHOW", + "SHUTDOWN", + "SIGNED", + "SIMPLE", + "SKIP", + "SKIP_SCHEMA_FILES", + "SLAVE", + "SLOW", + "SMALLINT", + "SNAPSHOT", + "SOME", + "SOURCE", + "SPATIAL", + "SPLIT", + "SQL", + "SQL_BIG_RESULT", + "SQL_BUFFER_RESULT", + "SQL_CACHE", + "SQL_CALC_FOUND_ROWS", + "SQL_NO_CACHE", + "SQL_SMALL_RESULT", + "SQL_TSI_DAY", + "SQL_TSI_HOUR", + "SQL_TSI_MINUTE", + "SQL_TSI_MONTH", + "SQL_TSI_QUARTER", + "SQL_TSI_SECOND", + "SQL_TSI_WEEK", + "SQL_TSI_YEAR", + "SSL", + "STALENESS", + "START", + "STARTING", + "STATISTICS", + "STATS", + "STATS_AUTO_RECALC", + "STATS_BUCKETS", + "STATS_COL_CHOICE", + "STATS_COL_LIST", + "STATS_EXTENDED", + "STATS_HEALTHY", + "STATS_HISTOGRAMS", + "STATS_META", + "STATS_OPTIONS", + "STATS_PERSISTENT", + "STATS_SAMPLE_PAGES", + "STATS_SAMPLE_RATE", + "STATS_TOPN", + "STATUS", + "STD", + "STDDEV", + "STDDEV_POP", + "STDDEV_SAMP", + "STOP", + "STORAGE", + "STORED", + "STRAIGHT_JOIN", + "STRICT", + "STRICT_FORMAT", + "STRONG", + "SUBDATE", + "SUBJECT", + "SUBPARTITION", + "SUBPARTITIONS", + "SUBSTRING", + "SUM", + "SUPER", + "SWAPS", + "SWITCHES", + "SYSTEM", + "SYSTEM_TIME", + "TABLE", + "TABLES", + "TABLESAMPLE", + "TABLESPACE", + "TABLE_CHECKSUM", + "TARGET", + "TELEMETRY", + "TELEMETRY_ID", + "TEMPORARY", + "TEMPTABLE", + "TERMINATED", + "TEXT", + "THAN", + "THEN", + "TIDB", + "TIFLASH", + "TIKV_IMPORTER", + "TIME", + "TIMESTAMP", + "TIMESTAMPADD", + "TIMESTAMPDIFF", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "TLS", + "TO", + "TOKUDB_DEFAULT", + "TOKUDB_FAST", + "TOKUDB_LZMA", + "TOKUDB_QUICKLZ", + "TOKUDB_SMALL", + "TOKUDB_SNAPPY", + "TOKUDB_UNCOMPRESSED", + "TOKUDB_ZLIB", + "TOP", + "TOPN", + "TRACE", + "TRADITIONAL", + "TRAILING", + "TRANSACTION", + "TRIGGER", + "TRIGGERS", + "TRIM", + "TRUE", + "TRUE_CARD_COST", + "TRUNCATE", + "TYPE", + "UNBOUNDED", + "UNCOMMITTED", + "UNDEFINED", + "UNICODE", + "UNION", + "UNIQUE", + "UNKNOWN", + "UNLOCK", + "UNSIGNED", + "USAGE", + "USE", + "USER", + "USING", + "UTC_DATE", + "UTC_TIME", + "UTC_TIMESTAMP", + "VALIDATION", + "VALUE", + "VALUES", + "VARBINARY", + "VARCHAR", + "VARCHARACTER", + "VARIABLES", + "VARIANCE", + "VARYING", + "VAR_POP", + "VAR_SAMP", + "VERBOSE", + "VIEW", + "VIRTUAL", + "VISIBLE", + "VOTER", + "VOTERS", + "VOTER_CONSTRAINTS", + "WAIT", + "WARNINGS", + "WEEK", + "WEIGHT_STRING", + "WHEN", + "WIDTH", + "WINDOW", + "WITH", + "WITHOUT", + "WRITE", + "X509", + "XOR", + "YEAR", + "YEAR_MONTH", + "ZEROFILL", + ] + + functions = [ + "AVG", + "CONCAT", + "COUNT", + "DISTINCT", + "FIRST", + "FORMAT", + "FROM_UNIXTIME", + "LAST", + "LCASE", + "LEN", + "MAX", + "MID", + "MIN", + "NOW", + "ROUND", + "SUM", + "TOP", + "UCASE", + "UNIX_TIMESTAMP", + ] # https://docs.pingcap.com/tidb/dev/tidb-functions tidb_functions = [ - 'TIDB_BOUNDED_STALENESS', 'TIDB_DECODE_KEY', 'TIDB_DECODE_PLAN', - 'TIDB_IS_DDL_OWNER', 'TIDB_PARSE_TSO', 'TIDB_VERSION', - 'TIDB_DECODE_SQL_DIGESTS', 'VITESS_HASH', 'TIDB_SHARD' - ] - + "TIDB_BOUNDED_STALENESS", + "TIDB_DECODE_KEY", + "TIDB_DECODE_PLAN", + "TIDB_IS_DDL_OWNER", + "TIDB_PARSE_TSO", + "TIDB_VERSION", + "TIDB_DECODE_SQL_DIGESTS", + "VITESS_HASH", + "TIDB_SHARD", + ] show_items = [] - change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER', - 'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY', - 'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE', - 'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS', - 'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH', - 'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER', - 'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS'] + change_items = [ + "MASTER_BIND", + "MASTER_HOST", + "MASTER_USER", + "MASTER_PASSWORD", + "MASTER_PORT", + "MASTER_CONNECT_RETRY", + "MASTER_HEARTBEAT_PERIOD", + "MASTER_LOG_FILE", + "MASTER_LOG_POS", + "RELAY_LOG_FILE", + "RELAY_LOG_POS", + "MASTER_SSL", + "MASTER_SSL_CA", + "MASTER_SSL_CAPATH", + "MASTER_SSL_CERT", + "MASTER_SSL_KEY", + "MASTER_SSL_CIPHER", + "MASTER_SSL_VERIFY_SERVER_CERT", + "IGNORE_SERVER_IDS", + ] users = [] - def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'): + def __init__(self, smart_completion=True, supported_formats=(), keyword_casing="auto"): super(self.__class__, self).__init__() self.smart_completion = smart_completion self.reserved_words = set() @@ -208,16 +904,14 @@ class SQLCompleter(Completer): self.special_commands = [] self.table_formats = supported_formats - if keyword_casing not in ('upper', 'lower', 'auto'): - keyword_casing = 'auto' + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "auto" self.keyword_casing = keyword_casing self.reset_completions() def escape_name(self, name): - if name and ((not self.name_pattern.match(name)) - or (name.upper() in self.reserved_words) - or (name.upper() in self.functions)): - name = '`%s`' % name + if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): + name = "`%s`" % name return name @@ -264,7 +958,7 @@ class SQLCompleter(Completer): def extend_schemata(self, schema): if schema is None: return - metadata = self.dbmetadata['tables'] + metadata = self.dbmetadata["tables"] metadata[schema] = {} # dbmetadata.values() are the 'tables' and 'functions' dicts @@ -293,10 +987,9 @@ class SQLCompleter(Completer): metadata = self.dbmetadata[kind] for relname in data: try: - metadata[self.dbname][relname[0]] = ['*'] + metadata[self.dbname][relname[0]] = ["*"] except KeyError: - _logger.error('%r %r listed in unrecognized schema %r', - kind, relname[0], self.dbname) + _logger.error("%r %r listed in unrecognized schema %r", kind, relname[0], self.dbname) self.all_completions.add(relname[0]) def extend_columns(self, column_data, kind): @@ -317,6 +1010,13 @@ class SQLCompleter(Completer): metadata = self.dbmetadata[kind] for relname, column in column_data: + if relname not in metadata[self.dbname]: + _logger.error("relname '%s' was not found in db '%s'", relname, self.dbname) + # this could happen back when the completer populated via two calls: + # SHOW TABLES then SELECT table_name, column_name from information_schema.columns + # it's a slight race, but much more likely on Vitess picking random shards for each. + # see discussion in https://github.com/dbcli/mycli/pull/1182 (tl;dr - let's keep it) + continue metadata[self.dbname][relname].append(column) self.all_completions.add(column) @@ -337,7 +1037,7 @@ class SQLCompleter(Completer): # dbmetadata['functions'][$schema_name][$function_name] should return # function metadata. - metadata = self.dbmetadata['functions'] + metadata = self.dbmetadata["functions"] for func in func_data: metadata[self.dbname][func[0]] = None @@ -350,8 +1050,8 @@ class SQLCompleter(Completer): self.databases = [] self.users = [] self.show_items = [] - self.dbname = '' - self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}} + self.dbname = "" + self.dbmetadata = {"tables": {}, "views": {}, "functions": {}} self.all_completions = set(self.keywords + self.functions) @staticmethod @@ -369,14 +1069,14 @@ class SQLCompleter(Completer): yields prompt_toolkit Completion instances for any matches found in the collection of available completions. """ - last = last_word(text, include='most_punctuations') + last = last_word(text, include="most_punctuations") text = last.lower() completions = [] if fuzzy: - regex = '.*?'.join(map(escape, text)) - pat = compile('(%s)' % regex) + regex = ".*?".join(map(escape, text)) + pat = compile("(%s)" % regex) for item in collection: r = pat.search(item.lower()) if r: @@ -388,16 +1088,15 @@ class SQLCompleter(Completer): if match_point >= 0: completions.append((len(text), match_point, item)) - if casing == 'auto': - casing = 'lower' if last and last[-1].islower() else 'upper' + if casing == "auto": + casing = "lower" if last and last[-1].islower() else "upper" def apply_case(kw): - if casing == 'upper': + if casing == "upper": return kw.upper() return kw.lower() - return (Completion(z if casing is None else apply_case(z), -len(text)) - for x, y, z in completions) + return (Completion(z if casing is None else apply_case(z), -len(text)) for x, y, z in completions) def get_completions(self, document, complete_event, smart_completion=None): word_before_cursor = document.get_word_before_cursor(WORD=True) @@ -407,36 +1106,30 @@ class SQLCompleter(Completer): # If smart_completion is off then match any word that starts with # 'word_before_cursor'. if not smart_completion: - return self.find_matches(word_before_cursor, self.all_completions, - start_only=True, fuzzy=False) + return self.find_matches(word_before_cursor, self.all_completions, start_only=True, fuzzy=False) completions = [] suggestions = suggest_type(document.text, document.text_before_cursor) for suggestion in suggestions: + _logger.debug("Suggestion type: %r", suggestion["type"]) - _logger.debug('Suggestion type: %r', suggestion['type']) - - if suggestion['type'] == 'column': - tables = suggestion['tables'] + if suggestion["type"] == "column": + tables = suggestion["tables"] _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables) - if suggestion.get('drop_unique'): + if suggestion.get("drop_unique"): # drop_unique is used for 'tb11 JOIN tbl2 USING (...' # which should suggest only columns that appear in more than # one table - scoped_cols = [ - col for (col, count) in Counter(scoped_cols).items() - if count > 1 and col != '*' - ] + scoped_cols = [col for (col, count) in Counter(scoped_cols).items() if count > 1 and col != "*"] cols = self.find_matches(word_before_cursor, scoped_cols) completions.extend(cols) - elif suggestion['type'] == 'function': + elif suggestion["type"] == "function": # suggest user-defined functions using substring matching - funcs = self.populate_schema_objects(suggestion['schema'], - 'functions') + funcs = self.populate_schema_objects(suggestion["schema"], "functions") user_funcs = self.find_matches(word_before_cursor, funcs) completions.extend(user_funcs) @@ -444,77 +1137,59 @@ class SQLCompleter(Completer): # there is no schema qualifier. If a schema qualifier is # present it probably denotes a table. # eg: SELECT * FROM users u WHERE u. - if not suggestion['schema']: - predefined_funcs = self.find_matches(word_before_cursor, - self.functions, - start_only=True, - fuzzy=False, - casing=self.keyword_casing) + if not suggestion["schema"]: + predefined_funcs = self.find_matches( + word_before_cursor, self.functions, start_only=True, fuzzy=False, casing=self.keyword_casing + ) completions.extend(predefined_funcs) - elif suggestion['type'] == 'table': - tables = self.populate_schema_objects(suggestion['schema'], - 'tables') + elif suggestion["type"] == "table": + tables = self.populate_schema_objects(suggestion["schema"], "tables") tables = self.find_matches(word_before_cursor, tables) completions.extend(tables) - elif suggestion['type'] == 'view': - views = self.populate_schema_objects(suggestion['schema'], - 'views') + elif suggestion["type"] == "view": + views = self.populate_schema_objects(suggestion["schema"], "views") views = self.find_matches(word_before_cursor, views) completions.extend(views) - elif suggestion['type'] == 'alias': - aliases = suggestion['aliases'] + elif suggestion["type"] == "alias": + aliases = suggestion["aliases"] aliases = self.find_matches(word_before_cursor, aliases) completions.extend(aliases) - elif suggestion['type'] == 'database': + elif suggestion["type"] == "database": dbs = self.find_matches(word_before_cursor, self.databases) completions.extend(dbs) - elif suggestion['type'] == 'keyword': - keywords = self.find_matches(word_before_cursor, self.keywords, - casing=self.keyword_casing) + elif suggestion["type"] == "keyword": + keywords = self.find_matches(word_before_cursor, self.keywords, casing=self.keyword_casing) completions.extend(keywords) - elif suggestion['type'] == 'show': - show_items = self.find_matches(word_before_cursor, - self.show_items, - start_only=False, - fuzzy=True, - casing=self.keyword_casing) + elif suggestion["type"] == "show": + show_items = self.find_matches( + word_before_cursor, self.show_items, start_only=False, fuzzy=True, casing=self.keyword_casing + ) completions.extend(show_items) - elif suggestion['type'] == 'change': - change_items = self.find_matches(word_before_cursor, - self.change_items, - start_only=False, - fuzzy=True) + elif suggestion["type"] == "change": + change_items = self.find_matches(word_before_cursor, self.change_items, start_only=False, fuzzy=True) completions.extend(change_items) - elif suggestion['type'] == 'user': - users = self.find_matches(word_before_cursor, self.users, - start_only=False, - fuzzy=True) + elif suggestion["type"] == "user": + users = self.find_matches(word_before_cursor, self.users, start_only=False, fuzzy=True) completions.extend(users) - elif suggestion['type'] == 'special': - special = self.find_matches(word_before_cursor, - self.special_commands, - start_only=True, - fuzzy=False) + elif suggestion["type"] == "special": + special = self.find_matches(word_before_cursor, self.special_commands, start_only=True, fuzzy=False) completions.extend(special) - elif suggestion['type'] == 'favoritequery': - queries = self.find_matches(word_before_cursor, - FavoriteQueries.instance.list(), - start_only=False, fuzzy=True) + elif suggestion["type"] == "favoritequery": + queries = self.find_matches(word_before_cursor, FavoriteQueries.instance.list(), start_only=False, fuzzy=True) completions.extend(queries) - elif suggestion['type'] == 'table_format': - formats = self.find_matches(word_before_cursor, - self.table_formats) + elif suggestion["type"] == "table_format": + formats = self.find_matches(word_before_cursor, self.table_formats) completions.extend(formats) - elif suggestion['type'] == 'file_name': + elif suggestion["type"] == "file_name": file_names = self.find_files(word_before_cursor) completions.extend(file_names) @@ -553,20 +1228,20 @@ class SQLCompleter(Completer): # tables and views cannot share the same name, we can check one # at a time try: - columns.extend(meta['tables'][schema][relname]) + columns.extend(meta["tables"][schema][relname]) # Table exists, so don't bother checking for a view continue except KeyError: try: - columns.extend(meta['tables'][schema][escaped_relname]) + columns.extend(meta["tables"][schema][escaped_relname]) # Table exists, so don't bother checking for a view continue except KeyError: pass try: - columns.extend(meta['views'][schema][relname]) + columns.extend(meta["views"][schema][relname]) except KeyError: pass diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index bd5f5d9..d5b6db6 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -5,31 +5,29 @@ import re import pymysql from .packages import special from pymysql.constants import FIELD_TYPE -from pymysql.converters import (convert_datetime, - convert_timedelta, convert_date, conversions, - decoders) +from pymysql.converters import convert_datetime, convert_timedelta, convert_date, conversions, decoders + try: - import paramiko + import paramiko # noqa: F401 + import sshtunnel except ImportError: - from mycli.packages.paramiko_stub import paramiko + pass _logger = logging.getLogger(__name__) FIELD_TYPES = decoders.copy() -FIELD_TYPES.update({ - FIELD_TYPE.NULL: type(None) -}) +FIELD_TYPES.update({FIELD_TYPE.NULL: type(None)}) ERROR_CODE_ACCESS_DENIED = 1045 class ServerSpecies(enum.Enum): - MySQL = 'MySQL' - MariaDB = 'MariaDB' - Percona = 'Percona' - TiDB = 'TiDB' - Unknown = 'MySQL' + MySQL = "MySQL" + MariaDB = "MariaDB" + Percona = "Percona" + TiDB = "TiDB" + Unknown = "MySQL" class ServerInfo: @@ -43,7 +41,7 @@ class ServerInfo: if not version_str or not isinstance(version_str, str): return 0 try: - major, minor, patch = version_str.split('.') + major, minor, patch = version_str.split(".") except ValueError: return 0 else: @@ -52,55 +50,67 @@ class ServerInfo: @classmethod def from_version_string(cls, version_string): if not version_string: - return cls(ServerSpecies.Unknown, '') + return cls(ServerSpecies.Unknown, "") re_species = ( - (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), - (r'[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)', ServerSpecies.TiDB), - (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)', - ServerSpecies.Percona), - (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)', - ServerSpecies.MySQL), + (r"(?P<version>[0-9\.]+)-MariaDB", ServerSpecies.MariaDB), + (r"[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)", ServerSpecies.TiDB), + (r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)", ServerSpecies.Percona), + (r"(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)", ServerSpecies.MySQL), ) for regexp, species in re_species: match = re.search(regexp, version_string) if match is not None: - parsed_version = match.group('version') + parsed_version = match.group("version") detected_species = species break else: detected_species = ServerSpecies.Unknown - parsed_version = '' + parsed_version = "" return cls(detected_species, parsed_version) def __str__(self): if self.species: - return f'{self.species.value} {self.version_str}' + return f"{self.species.value} {self.version_str}" else: return self.version_str class SQLExecute(object): + databases_query = """SHOW DATABASES""" - databases_query = '''SHOW DATABASES''' - - tables_query = '''SHOW TABLES''' + tables_query = """SHOW TABLES""" show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' - users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' + users_query = """SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user""" functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' - table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns + table_columns_query = """select TABLE_NAME, COLUMN_NAME from information_schema.columns where table_schema = '%s' - order by table_name,ordinal_position''' - - def __init__(self, database, user, password, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename, init_command=None): + order by table_name,ordinal_position""" + + def __init__( + self, + database, + user, + password, + host, + port, + socket, + charset, + local_infile, + ssl, + ssh_user, + ssh_host, + ssh_port, + ssh_password, + ssh_key_filename, + init_command=None, + ): self.dbname = database self.user = user self.password = password @@ -120,52 +130,79 @@ class SQLExecute(object): self.init_command = init_command self.connect() - def connect(self, database=None, user=None, password=None, host=None, - port=None, socket=None, charset=None, local_infile=None, - ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, - ssh_password=None, ssh_key_filename=None, init_command=None): - db = (database or self.dbname) - user = (user or self.user) - password = (password or self.password) - host = (host or self.host) - port = (port or self.port) - socket = (socket or self.socket) - charset = (charset or self.charset) - local_infile = (local_infile or self.local_infile) - ssl = (ssl or self.ssl) - ssh_user = (ssh_user or self.ssh_user) - ssh_host = (ssh_host or self.ssh_host) - ssh_port = (ssh_port or self.ssh_port) - ssh_password = (ssh_password or self.ssh_password) - ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) - init_command = (init_command or self.init_command) + def connect( + self, + database=None, + user=None, + password=None, + host=None, + port=None, + socket=None, + charset=None, + local_infile=None, + ssl=None, + ssh_host=None, + ssh_port=None, + ssh_user=None, + ssh_password=None, + ssh_key_filename=None, + init_command=None, + ): + db = database or self.dbname + user = user or self.user + password = password or self.password + host = host or self.host + port = port or self.port + socket = socket or self.socket + charset = charset or self.charset + local_infile = local_infile or self.local_infile + ssl = ssl or self.ssl + ssh_user = ssh_user or self.ssh_user + ssh_host = ssh_host or self.ssh_host + ssh_port = ssh_port or self.ssh_port + ssh_password = ssh_password or self.ssh_password + ssh_key_filename = ssh_key_filename or self.ssh_key_filename + init_command = init_command or self.init_command _logger.debug( - 'Connection DB Params: \n' - '\tdatabase: %r' - '\tuser: %r' - '\thost: %r' - '\tport: %r' - '\tsocket: %r' - '\tcharset: %r' - '\tlocal_infile: %r' - '\tssl: %r' - '\tssh_user: %r' - '\tssh_host: %r' - '\tssh_port: %r' - '\tssh_password: %r' - '\tssh_key_filename: %r' - '\tinit_command: %r', - db, user, host, port, socket, charset, local_infile, ssl, - ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, - init_command + "Connection DB Params: \n" + "\tdatabase: %r" + "\tuser: %r" + "\thost: %r" + "\tport: %r" + "\tsocket: %r" + "\tcharset: %r" + "\tlocal_infile: %r" + "\tssl: %r" + "\tssh_user: %r" + "\tssh_host: %r" + "\tssh_port: %r" + "\tssh_password: %r" + "\tssh_key_filename: %r" + "\tinit_command: %r", + db, + user, + host, + port, + socket, + charset, + local_infile, + ssl, + ssh_user, + ssh_host, + ssh_port, + ssh_password, + ssh_key_filename, + init_command, ) conv = conversions.copy() - conv.update({ - FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), - FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), - }) + conv.update( + { + FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), + FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), + } + ) defer_connect = False @@ -181,29 +218,45 @@ class SQLExecute(object): ssl_context = self._create_ssl_ctx(ssl) conn = pymysql.connect( - database=db, user=user, password=password, host=host, port=port, - unix_socket=socket, use_unicode=True, charset=charset, - autocommit=True, client_flag=client_flag, - local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli", - defer_connect=defer_connect, init_command=init_command + database=db, + user=user, + password=password, + host=host, + port=port, + unix_socket=socket, + use_unicode=True, + charset=charset, + autocommit=True, + client_flag=client_flag, + local_infile=local_infile, + conv=conv, + ssl=ssl_context, + program_name="mycli", + defer_connect=defer_connect, + init_command=init_command, ) if ssh_host: - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.WarningPolicy()) - client.connect( - ssh_host, ssh_port, ssh_user, ssh_password, - key_filename=ssh_key_filename - ) - chan = client.get_transport().open_channel( - 'direct-tcpip', - (host, port), - ('0.0.0.0', 0), - ) - conn.connect(chan) - - if hasattr(self, 'conn'): + ##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel + ##### + # instead let's open a tunnel and rewrite host:port to local bind + try: + chan = sshtunnel.SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_pkey=ssh_key_filename, + ssh_password=ssh_password, + remote_bind_address=(host, port), + ) + chan.start() + + conn.host = chan.local_bind_host + conn.port = chan.local_bind_port + conn.connect() + except Exception as e: + raise e + + if hasattr(self, "conn"): self.conn.close() self.conn = conn # Update them after the connection is made to ensure that it was a @@ -235,24 +288,24 @@ class SQLExecute(object): # Split the sql into separate queries and run each one. # Unless it's saving a favorite query, in which case we # want to save them all together. - if statement.startswith('\\fs'): + if statement.startswith("\\fs"): components = [statement] else: components = special.split_queries(statement) for sql in components: # \G is treated specially since we have to set the expanded output. - if sql.endswith('\\G'): + if sql.endswith("\\G"): special.set_expanded_output(True) sql = sql[:-2].strip() cur = self.conn.cursor() - try: # Special command - _logger.debug('Trying a dbspecial command. sql: %r', sql) + try: # Special command + _logger.debug("Trying a dbspecial command. sql: %r", sql) for result in special.execute(cur, sql): yield result except special.CommandNotFound: # Regular SQL - _logger.debug('Regular sql statement. sql: %r', sql) + _logger.debug("Regular sql statement. sql: %r", sql) cur.execute(sql) while True: yield self.get_result(cur) @@ -271,12 +324,11 @@ class SQLExecute(object): # e.g. SELECT or SHOW. if cursor.description is not None: headers = [x[0] for x in cursor.description] - status = '{0} row{1} in set' + status = "{0} row{1} in set" else: - _logger.debug('No rows in result.') - status = 'Query OK, {0} row{1} affected' - status = status.format(cursor.rowcount, - '' if cursor.rowcount == 1 else 's') + _logger.debug("No rows in result.") + status = "Query OK, {0} row{1} affected" + status = status.format(cursor.rowcount, "" if cursor.rowcount == 1 else "s") return (title, cursor if cursor.description else None, headers, status) @@ -284,7 +336,7 @@ class SQLExecute(object): """Yields table names""" with self.conn.cursor() as cur: - _logger.debug('Tables Query. sql: %r', self.tables_query) + _logger.debug("Tables Query. sql: %r", self.tables_query) cur.execute(self.tables_query) for row in cur: yield row @@ -292,14 +344,14 @@ class SQLExecute(object): def table_columns(self): """Yields (table name, column name) pairs""" with self.conn.cursor() as cur: - _logger.debug('Columns Query. sql: %r', self.table_columns_query) + _logger.debug("Columns Query. sql: %r", self.table_columns_query) cur.execute(self.table_columns_query % self.dbname) for row in cur: yield row def databases(self): with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', self.databases_query) + _logger.debug("Databases Query. sql: %r", self.databases_query) cur.execute(self.databases_query) return [x[0] for x in cur.fetchall()] @@ -307,31 +359,31 @@ class SQLExecute(object): """Yields tuples of (schema_name, function_name)""" with self.conn.cursor() as cur: - _logger.debug('Functions Query. sql: %r', self.functions_query) + _logger.debug("Functions Query. sql: %r", self.functions_query) cur.execute(self.functions_query % self.dbname) for row in cur: yield row def show_candidates(self): with self.conn.cursor() as cur: - _logger.debug('Show Query. sql: %r', self.show_candidates_query) + _logger.debug("Show Query. sql: %r", self.show_candidates_query) try: cur.execute(self.show_candidates_query) except pymysql.DatabaseError as e: - _logger.error('No show completions due to %r', e) - yield '' + _logger.error("No show completions due to %r", e) + yield "" else: for row in cur: - yield (row[0].split(None, 1)[-1], ) + yield (row[0].split(None, 1)[-1],) def users(self): with self.conn.cursor() as cur: - _logger.debug('Users Query. sql: %r', self.users_query) + _logger.debug("Users Query. sql: %r", self.users_query) try: cur.execute(self.users_query) except pymysql.DatabaseError as e: - _logger.error('No user completions due to %r', e) - yield '' + _logger.error("No user completions due to %r", e) + yield "" else: for row in cur: yield row @@ -343,17 +395,17 @@ class SQLExecute(object): def reset_connection_id(self): # Remember current connection id - _logger.debug('Get current connection id') + _logger.debug("Get current connection id") try: - res = self.run('select connection_id()') + res = self.run("select connection_id()") for title, cur, headers, status in res: self.connection_id = cur.fetchone()[0] except Exception as e: # See #1054 self.connection_id = -1 - _logger.error('Failed to get connection id: %s', e) + _logger.error("Failed to get connection id: %s", e) else: - _logger.debug('Current connection id: %s', self.connection_id) + _logger.debug("Current connection id: %s", self.connection_id) def change_db(self, db): self.conn.select_db(db) @@ -392,6 +444,6 @@ class SQLExecute(object): ctx.minimum_version = ssl.TLSVersion.TLSv1_3 ctx.maximum_version = ssl.TLSVersion.TLSv1_3 else: - _logger.error('Invalid tls version: %s', tls_version) + _logger.error("Invalid tls version: %s", tls_version) return ctx diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..107e85b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[project] +name = "mycli" +dynamic = ["version"] +description = "CLI for MySQL Database. With auto-completion and syntax highlighting." +readme = "README.md" +requires-python = ">=3.7" +license = { text = "BSD" } +authors = [{ name = "Mycli Core Team", email = "mycli-dev@googlegroups.com" }] +urls = { homepage = "http://mycli.net" } + +dependencies = [ + "click >= 7.0", + "cryptography >= 1.0.0", + "Pygments>=1.6", + "prompt_toolkit>=3.0.6,<4.0.0", + "PyMySQL >= 0.9.2", + "sqlparse>=0.3.0,<0.5.0", + "sqlglot>=5.1.3", + "configobj >= 5.0.5", + "cli_helpers[styles] >= 2.2.1", + "pyperclip >= 1.8.1", + "pyaes >= 1.6.1", + "pyfzf >= 0.3.1", + "importlib_resources >= 5.0.0; python_version<'3.9'", +] + +[build-system] +requires = [ + "setuptools>=64.0", + "setuptools-scm>=8;python_version>='3.8'", + "setuptools-scm<8;python_version<'3.8'", +] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] + +[project.optional-dependencies] +ssh = ["paramiko", "sshtunnel"] +dev = [ + "behave>=1.2.6", + "coverage>=7.2.7", + "pexpect>=4.9.0", + "pytest>=7.4.4", + "pytest-cov>=4.1.0", + "tox>=4.8.0", + "pdbpp>=0.10.3", +] + +[project.scripts] +mycli = "mycli.main:cli" + +[tool.setuptools.package-data] +mycli = ["myclirc", "AUTHORS", "SPONSORS"] + +[tool.setuptools.packages.find] +include = ["mycli*"] + +[tool.ruff] +line-length = 140 diff --git a/release.py b/release.py deleted file mode 100755 index 62daa80..0000000 --- a/release.py +++ /dev/null @@ -1,119 +0,0 @@ -"""A script to publish a release of mycli to PyPI.""" - -from optparse import OptionParser -import re -import subprocess -import sys - -import click - -DEBUG = False -CONFIRM_STEPS = False -DRY_RUN = False - - -def skip_step(): - """ - Asks for user's response whether to run a step. Default is yes. - :return: boolean - """ - global CONFIRM_STEPS - - if CONFIRM_STEPS: - return not click.confirm('--- Run this step?', default=True) - return False - - -def run_step(*args): - """ - Prints out the command and asks if it should be run. - If yes (default), runs it. - :param args: list of strings (command and args) - """ - global DRY_RUN - - cmd = args - print(' '.join(cmd)) - if skip_step(): - print('--- Skipping...') - elif DRY_RUN: - print('--- Pretending to run...') - else: - subprocess.check_output(cmd) - - -def version(version_file): - _version_re = re.compile( - r'__version__\s+=\s+(?P<quote>[\'"])(?P<version>.*)(?P=quote)') - - with open(version_file) as f: - ver = _version_re.search(f.read()).group('version') - - return ver - - -def commit_for_release(version_file, ver): - run_step('git', 'reset') - run_step('git', 'add', version_file) - run_step('git', 'commit', '--message', - 'Releasing version {}'.format(ver)) - - -def create_git_tag(tag_name): - run_step('git', 'tag', tag_name) - - -def create_distribution_files(): - run_step('python', 'setup.py', 'sdist', 'bdist_wheel') - - -def upload_distribution_files(): - run_step('twine', 'upload', 'dist/*') - - -def push_to_github(): - run_step('git', 'push', 'origin', 'main') - - -def push_tags_to_github(): - run_step('git', 'push', '--tags', 'origin') - - -def checklist(questions): - for question in questions: - if not click.confirm('--- {}'.format(question), default=False): - sys.exit(1) - - -if __name__ == '__main__': - if DEBUG: - subprocess.check_output = lambda x: x - - ver = version('mycli/__init__.py') - - parser = OptionParser() - parser.add_option( - "-c", "--confirm-steps", action="store_true", dest="confirm_steps", - default=False, help=("Confirm every step. If the step is not " - "confirmed, it will be skipped.") - ) - parser.add_option( - "-d", "--dry-run", action="store_true", dest="dry_run", - default=False, help="Print out, but not actually run any steps." - ) - - popts, pargs = parser.parse_args() - CONFIRM_STEPS = popts.confirm_steps - DRY_RUN = popts.dry_run - - print('Releasing Version:', ver) - - if not click.confirm('Are you sure?', default=False): - sys.exit(1) - - commit_for_release('mycli/__init__.py', ver) - create_git_tag('v{}'.format(ver)) - create_distribution_files() - push_to_github() - push_tags_to_github() - upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt index 603efa2..abf92d3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,7 @@ colorama>=0.4.1 git+https://github.com/hayd/pep8radius.git # --error-status option not released click>=7.0 paramiko==2.11.0 +sshtunnel==0.4.0 pyperclip>=1.8.1 importlib_resources>=5.0.0 pyaes>=1.6.1 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index e533c7b..0000000 --- a/setup.cfg +++ /dev/null @@ -1,18 +0,0 @@ -[bdist_wheel] -universal = 1 - -[tool:pytest] -addopts = --capture=sys - --showlocals - --doctest-modules - --doctest-ignore-import-errors - --ignore=setup.py - --ignore=mycli/magic.py - --ignore=mycli/packages/parseutils.py - --ignore=test/features - -[pep8] -rev = master -docformatter = True -diff = True -error-status = True diff --git a/setup.py b/setup.py deleted file mode 100755 index c7f9333..0000000 --- a/setup.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python - -import ast -import re -import subprocess -import sys - -from setuptools import Command, find_packages, setup -from setuptools.command.test import test as TestCommand - -_version_re = re.compile(r'__version__\s+=\s+(.*)') - -with open('mycli/__init__.py') as f: - version = ast.literal_eval(_version_re.search( - f.read()).group(1)) - -description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.' - -install_requirements = [ - 'click >= 7.0', - # Pinning cryptography is not needed after paramiko 2.11.0. Correct it - 'cryptography >= 1.0.0', - # 'Pygments>=1.6,<=2.11.1', - 'Pygments>=1.6', - 'prompt_toolkit>=3.0.6,<4.0.0', - 'PyMySQL >= 0.9.2', - 'sqlparse>=0.3.0,<0.5.0', - 'sqlglot>=5.1.3', - 'configobj >= 5.0.5', - 'cli_helpers[styles] >= 2.2.1', - 'pyperclip >= 1.8.1', - 'pyaes >= 1.6.1', - 'pyfzf >= 0.3.1', -] - -if sys.version_info.minor < 9: - install_requirements.append('importlib_resources >= 5.0.0') - - -class lint(Command): - description = 'check code against PEP 8 (and fix violations)' - - user_options = [ - ('branch=', 'b', 'branch/revision to compare against (e.g. main)'), - ('fix', 'f', 'fix the violations in place'), - ('error-status', 'e', 'return an error code on failed PEP check'), - ] - - def initialize_options(self): - """Set the default options.""" - self.branch = 'main' - self.fix = False - self.error_status = True - - def finalize_options(self): - pass - - def run(self): - cmd = 'pep8radius {}'.format(self.branch) - if self.fix: - cmd += ' --in-place' - if self.error_status: - cmd += ' --error-status' - sys.exit(subprocess.call(cmd, shell=True)) - - -class test(TestCommand): - - user_options = [ - ('pytest-args=', 'a', 'Arguments to pass to pytest'), - ('behave-args=', 'b', 'Arguments to pass to pytest') - ] - - def initialize_options(self): - TestCommand.initialize_options(self) - self.pytest_args = '' - self.behave_args = '--no-capture' - - def run_tests(self): - unit_test_errno = subprocess.call( - 'pytest test/ ' + self.pytest_args, - shell=True - ) - cli_errno = subprocess.call( - 'behave test/features ' + self.behave_args, - shell=True - ) - subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False) - sys.exit(unit_test_errno or cli_errno) - - -setup( - name='mycli', - author='Mycli Core Team', - author_email='mycli-dev@googlegroups.com', - version=version, - url='http://mycli.net', - packages=find_packages(exclude=['test*']), - package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']}, - description=description, - long_description=description, - install_requires=install_requirements, - entry_points={ - 'console_scripts': ['mycli = mycli.main:cli'], - }, - cmdclass={'lint': lint, 'test': test}, - python_requires=">=3.7", - classifiers=[ - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: Unix', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: SQL', - 'Topic :: Database', - 'Topic :: Database :: Front-Ends', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries :: Python Modules', - ], - extras_require={ - 'ssh': ['paramiko'], - }, -) diff --git a/test/conftest.py b/test/conftest.py index 1325596..5575b40 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,13 +1,12 @@ import pytest -from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, - db_connection, SSH_USER, SSH_HOST, SSH_PORT) +from .utils import HOST, USER, PASSWORD, PORT, CHARSET, create_db, db_connection, SSH_USER, SSH_HOST, SSH_PORT import mycli.sqlexecute @pytest.fixture(scope="function") def connection(): - create_db('mycli_test_db') - connection = db_connection('mycli_test_db') + create_db("mycli_test_db") + connection = db_connection("mycli_test_db") yield connection connection.close() @@ -22,8 +21,18 @@ def cursor(connection): @pytest.fixture def executor(connection): return mycli.sqlexecute.SQLExecute( - database='mycli_test_db', user=USER, - host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, - ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None + database="mycli_test_db", + user=USER, + host=HOST, + password=PASSWORD, + port=PORT, + socket=None, + charset=CHARSET, + local_infile=False, + ssl=None, + ssh_user=SSH_USER, + ssh_host=SSH_HOST, + ssh_port=SSH_PORT, + ssh_password=None, + ssh_key_filename=None, ) diff --git a/test/features/db_utils.py b/test/features/db_utils.py index be550e9..175cc1b 100644 --- a/test/features/db_utils.py +++ b/test/features/db_utils.py @@ -1,8 +1,7 @@ import pymysql -def create_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def create_db(hostname="localhost", port=3306, username=None, password=None, dbname=None): """Create test database. :param hostname: string @@ -14,17 +13,12 @@ def create_db(hostname='localhost', port=3306, username=None, """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) - cr.execute('create database ' + dbname) + cr.execute("drop database if exists " + dbname) + cr.execute("create database " + dbname) cn.close() @@ -44,20 +38,13 @@ def create_cn(hostname, port, password, username, dbname): """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) return cn -def drop_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def drop_db(hostname="localhost", port=3306, username=None, password=None, dbname=None): """Drop database. :param hostname: string @@ -68,17 +55,11 @@ def drop_db(hostname='localhost', port=3306, username=None, """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, db=dbname, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) + cr.execute("drop database if exists " + dbname) close_cn(cn) diff --git a/test/features/environment.py b/test/features/environment.py index 1ea0f08..a3d3764 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -9,96 +9,72 @@ import pexpect from steps.wrappers import run_cli, wait_prompt -test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log") -SELF_CONNECTING_FEATURES = ( - 'test/features/connection.feature', -) +SELF_CONNECTING_FEATURES = ("test/features/connection.feature",) -MY_CNF_PATH = os.path.expanduser('~/.my.cnf') -MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' -MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') -MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' +MY_CNF_PATH = os.path.expanduser("~/.my.cnf") +MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup" +MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf") +MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup" def get_db_name_from_context(context): - return context.config.userdata.get( - 'my_test_db', None - ) or "mycli_behave_tests" - + return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests" def before_all(context): """Set env parameters.""" - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['EDITOR'] = 'ex' - os.environ['LC_ALL'] = 'en_US.UTF-8' - os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' - os.environ['MYCLI_HISTFILE'] = os.devnull + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["EDITOR"] = "ex" + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + os.environ["MYCLI_HISTFILE"] = os.devnull - test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - login_path_file = os.path.join(test_dir, 'mylogin.cnf') -# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + # test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + # login_path_file = os.path.join(test_dir, "mylogin.cnf") + # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file - context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, - '.coveragerc') + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc") context.exit_sent = False - vi = '_'.join([str(x) for x in sys.version_info[:3]]) + vi = "_".join([str(x) for x in sys.version_info[:3]]) db_name = get_db_name_from_context(context) - db_name_full = '{0}_{1}'.format(db_name, vi) + db_name_full = "{0}_{1}".format(db_name, vi) # Store get params from config/environment variables context.conf = { - 'host': context.config.userdata.get( - 'my_test_host', - os.getenv('PYTEST_HOST', 'localhost') - ), - 'port': context.config.userdata.get( - 'my_test_port', - int(os.getenv('PYTEST_PORT', '3306')) - ), - 'user': context.config.userdata.get( - 'my_test_user', - os.getenv('PYTEST_USER', 'root') - ), - 'pass': context.config.userdata.get( - 'my_test_pass', - os.getenv('PYTEST_PASSWORD', None) - ), - 'cli_command': context.config.userdata.get( - 'my_cli_command', None) or - sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', - 'dbname': db_name, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - 'pager_boundary': '---boundary---', + "host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")), + "port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))), + "user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")), + "pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)), + "cli_command": context.config.userdata.get("my_cli_command", None) + or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + "dbname": db_name, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", } _, my_cnf = mkstemp() - with open(my_cnf, 'w') as f: + with open(my_cnf, "w") as f: f.write( - '[client]\n' - 'pager={0} {1} {2}\n'.format( - sys.executable, os.path.join(context.package_root, - 'test/features/wrappager.py'), - context.conf['pager_boundary']) + "[client]\n" "pager={0} {1} {2}\n".format( + sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"] + ) ) - context.conf['defaults-file'] = my_cnf - context.conf['myclirc'] = os.path.join(context.package_root, 'test', - 'myclirc') + context.conf["defaults-file"] = my_cnf + context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc") - context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], - context.conf['user'], - context.conf['pass'], - context.conf['dbname']) + context.cn = dbutils.create_db( + context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"] + ) context.fixture_data = fixutils.read_fixture_files() @@ -106,12 +82,10 @@ def before_all(context): def after_all(context): """Unset env parameters.""" dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['port'], - context.conf['user'], context.conf['pass'], - context.conf['dbname']) + dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]) # Restore env vars. - #for k, v in context.pgenv.items(): + # for k, v in context.pgenv.items(): # if k in os.environ and v is None: # del os.environ[k] # elif v: @@ -123,8 +97,8 @@ def before_step(context, _): def before_scenario(context, arg): - with open(test_log_file, 'w') as f: - f.write('') + with open(test_log_file, "w") as f: + f.write("") if arg.location.filename not in SELF_CONNECTING_FEATURES: run_cli(context) wait_prompt(context) @@ -140,23 +114,18 @@ def after_scenario(context, _): """Cleans up after each test complete.""" with open(test_log_file) as f: for line in f: - if 'error' in line.lower(): - raise RuntimeError(f'Error in log file: {line}') + if "error" in line.lower(): + raise RuntimeError(f"Error in log file: {line}") - if hasattr(context, 'cli') and not context.exit_sent: + if hasattr(context, "cli") and not context.exit_sent: # Quit nicely. if not context.atprompt: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - context.cli.expect_exact( - '{0}@{1}:{2}>'.format( - user, host, dbname - ), - timeout=5 - ) - context.cli.sendcontrol('c') - context.cli.sendcontrol('d') + context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5) + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") context.cli.expect_exact(pexpect.EOF, timeout=5) if os.path.exists(MY_CNF_BACKUP_PATH): diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py index f85e0f6..514e41f 100644 --- a/test/features/fixture_utils.py +++ b/test/features/fixture_utils.py @@ -1,5 +1,4 @@ import os -import io def read_fixture_lines(filename): @@ -20,9 +19,9 @@ def read_fixture_files(): fixture_dict = {} current_dir = os.path.dirname(__file__) - fixture_dir = os.path.join(current_dir, 'fixture_data/') + fixture_dir = os.path.join(current_dir, "fixture_data/") for filename in os.listdir(fixture_dir): - if filename not in ['.', '..']: + if filename not in [".", ".."]: fullname = os.path.join(fixture_dir, filename) fixture_dict[filename] = read_fixture_lines(fullname) diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index e1cb26f..ad20067 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -6,41 +6,42 @@ import wrappers from utils import parse_cli_args_to_dict -@when('we run dbcli with {arg}') +@when("we run dbcli with {arg}") def step_run_cli_with_arg(context, arg): wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) -@when('we execute a small query') +@when("we execute a small query") def step_execute_small_query(context): - context.cli.sendline('select 1') + context.cli.sendline("select 1") -@when('we execute a large query') +@when("we execute a large query") def step_execute_large_query(context): - context.cli.sendline( - 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) -@then('we see small results in horizontal format') +@then("we see small results in horizontal format") def step_see_small_results(context): - wrappers.expect_pager(context, dedent("""\ + wrappers.expect_pager( + context, + dedent("""\ +---+\r | 1 |\r +---+\r | 1 |\r +---+\r \r - """), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=5, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then('we see large results in vertical format') +@then("we see large results in vertical format") def step_see_large_results(context): - rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] - expected = ('***************************[ 1. row ]' - '***************************\r\n' + - '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + rows = ["{n:3}| {n}".format(n=str(n)) for n in range(1, 50)] + expected = "***************************[ 1. row ]" "***************************\r\n" + "{}\r\n".format("\r\n".join(rows) + "\r\n") wrappers.expect_pager(context, expected, timeout=10) - wrappers.expect_exact(context, '1 row in set', timeout=2) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py index 425ef67..ec1e47a 100644 --- a/test/features/steps/basic_commands.py +++ b/test/features/steps/basic_commands.py @@ -5,18 +5,18 @@ to call the step in "*.feature" file. """ -from behave import when +from behave import when, then from textwrap import dedent import tempfile import wrappers -@when('we run dbcli') +@when("we run dbcli") def step_run_cli(context): wrappers.run_cli(context) -@when('we wait for prompt') +@when("we wait for prompt") def step_wait_prompt(context): wrappers.wait_prompt(context) @@ -24,77 +24,75 @@ def step_wait_prompt(context): @when('we send "ctrl + d"') def step_ctrl_d(context): """Send Ctrl + D to hopefully exit.""" - context.cli.sendcontrol('d') + context.cli.sendcontrol("d") context.exit_sent = True -@when('we send "\?" command') +@when(r'we send "\?" command') def step_send_help(context): - """Send \? + r"""Send \? to see help. """ - context.cli.sendline('\\?') - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\\?") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we send source command') +@when("we send source command") def step_send_source_command(context): with tempfile.NamedTemporaryFile() as f: - f.write(b'\?') + f.write(b"\\?") f.flush() - context.cli.sendline('\. {0}'.format(f.name)) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\\. {0}".format(f.name)) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we run query to check application_name') +@when("we run query to check application_name") def step_check_application_name(context): context.cli.sendline( "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" ) -@then(u'we see found') +@then("we see found") def step_see_found(context): wrappers.expect_exact( context, - context.conf['pager_boundary'] + '\r' + dedent(''' + context.conf["pager_boundary"] + + "\r" + + dedent(""" +-------+\r | found |\r +-------+\r | found |\r +-------+\r \r - ''') + context.conf['pager_boundary'], - timeout=5 + """) + + context.conf["pager_boundary"], + timeout=5, ) -@then(u'we confirm the destructive warning') -def step_confirm_destructive_command(context): +@then("we confirm the destructive warning") +def step_confirm_destructive_command(context): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) + context.cli.sendline("y") -@when(u'we answer the destructive warning with "{confirmation}"') -def step_confirm_destructive_command(context, confirmation): +@when('we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) -@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') -def step_confirm_destructive_command(context, confirmation, text): +@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) wrappers.expect_exact(context, text, timeout=2) # we must exit the Click loop, or the feature will hang - context.cli.sendline('n') + context.cli.sendline("n") diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py index e16dd86..80d0653 100644 --- a/test/features/steps/connection.py +++ b/test/features/steps/connection.py @@ -1,9 +1,7 @@ import io import os -import shlex from behave import when, then -import pexpect import wrappers from test.features.steps.utils import parse_cli_args_to_dict @@ -12,60 +10,44 @@ from test.utils import HOST, PORT, USER, PASSWORD from mycli.config import encrypt_mylogin_cnf -TEST_LOGIN_PATH = 'test_login_path' +TEST_LOGIN_PATH = "test_login_path" @when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') @when('we run mycli without arguments "{excluded_args}"') -def step_run_cli_without_args(context, excluded_args, exact_args=''): - wrappers.run_cli( - context, - run_args=parse_cli_args_to_dict(exact_args), - exclude_args=parse_cli_args_to_dict(excluded_args).keys() - ) +def step_run_cli_without_args(context, excluded_args, exact_args=""): + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys()) @then('status contains "{expression}"') def status_contains(context, expression): - wrappers.expect_exact(context, f'{expression}', timeout=5) + wrappers.expect_exact(context, f"{expression}", timeout=5) # Normally, the shutdown after scenario waits for the prompt. # But we may have changed the prompt, depending on parameters, # so let's wait for its last character - context.cli.expect_exact('>') + context.cli.expect_exact(">") context.atprompt = True -@when('we create my.cnf file') +@when("we create my.cnf file") def step_create_my_cnf_file(context): - my_cnf = ( - '[client]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MY_CNF_PATH, 'w') as f: + my_cnf = "[client]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n" + with open(MY_CNF_PATH, "w") as f: f.write(my_cnf) -@when('we create mylogin.cnf file') +@when("we create mylogin.cnf file") def step_create_mylogin_cnf_file(context): - os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) - mylogin_cnf = ( - f'[{TEST_LOGIN_PATH}]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MYLOGIN_CNF_PATH, 'wb') as f: + os.environ.pop("MYSQL_TEST_LOGIN_FILE", None) + mylogin_cnf = f"[{TEST_LOGIN_PATH}]\n" f"host = {HOST}\n" f"port = {PORT}\n" f"user = {USER}\n" f"password = {PASSWORD}\n" + with open(MYLOGIN_CNF_PATH, "wb") as f: input_file = io.StringIO(mylogin_cnf) f.write(encrypt_mylogin_cnf(input_file).read()) -@then('we are logged in') +@then("we are logged in") def we_are_logged_in(context): db_name = get_db_name_from_context(context) - context.cli.expect_exact(f'{db_name}>', timeout=5) + context.cli.expect_exact(f"{db_name}>", timeout=5) context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py index 841f37d..56ff114 100644 --- a/test/features/steps/crud_database.py +++ b/test/features/steps/crud_database.py @@ -11,105 +11,99 @@ import wrappers from behave import when, then -@when('we create database') +@when("we create database") def step_db_create(context): """Send create database.""" - context.cli.sendline('create database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"])) - context.response = { - 'database_name': context.conf['dbname_tmp'] - } + context.response = {"database_name": context.conf["dbname_tmp"]} -@when('we drop database') +@when("we drop database") def step_db_drop(context): """Send drop database.""" - context.cli.sendline('drop database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"])) -@when('we connect to test database') +@when("we connect to test database") def step_db_connect_test(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use {0};'.format(db_name)) + context.cli.sendline("use {0};".format(db_name)) -@when('we connect to quoted test database') +@when("we connect to quoted test database") def step_db_connect_quoted_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use `{0}`;'.format(db_name)) + context.cli.sendline("use `{0}`;".format(db_name)) -@when('we connect to tmp database') +@when("we connect to tmp database") def step_db_connect_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname_tmp'] + db_name = context.conf["dbname_tmp"] context.currentdb = db_name - context.cli.sendline('use {0}'.format(db_name)) + context.cli.sendline("use {0}".format(db_name)) -@when('we connect to dbserver') +@when("we connect to dbserver") def step_db_connect_dbserver(context): """Send connect to database.""" - context.currentdb = 'mysql' - context.cli.sendline('use mysql') + context.currentdb = "mysql" + context.cli.sendline("use mysql") -@then('dbcli exits') +@then("dbcli exits") def step_wait_exit(context): """Make sure the cli exits.""" wrappers.expect_exact(context, pexpect.EOF, timeout=5) -@then('we see dbcli prompt') +@then("we see dbcli prompt") def step_see_prompt(context): """Wait to see the prompt.""" - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) + wrappers.wait_prompt(context, "{0}@{1}:{2}> ".format(user, host, dbname)) -@then('we see help output') +@then("we see help output") def step_see_help(context): - for expected_line in context.fixture_data['help_commands.txt']: + for expected_line in context.fixture_data["help_commands.txt"]: wrappers.expect_exact(context, expected_line, timeout=1) -@then('we see database created') +@then("we see database created") def step_see_db_created(context): """Wait to see create database output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see database dropped') +@then("we see database dropped") def step_see_db_dropped(context): """Wait to see drop database output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see database dropped and no default database') +@then("we see database dropped and no default database") def step_see_db_dropped_no_default(context): """Wait to see drop database output.""" - user = context.conf['user'] - host = context.conf['host'] - database = '(none)' + user = context.conf["user"] + host = context.conf["host"] + database = "(none)" context.currentdb = None - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) - wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) + wrappers.wait_prompt(context, "{0}@{1}:{2}>".format(user, host, database)) -@then('we see database connected') +@then("we see database connected") def step_see_db_connected(context): """Wait to see drop database output.""" - wrappers.expect_exact( - context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, 'You are now connected to database "', timeout=2) wrappers.expect_exact(context, '"', timeout=2) - wrappers.expect_exact(context, ' as user "{0}"'.format( - context.conf['user']), timeout=2) + wrappers.expect_exact(context, ' as user "{0}"'.format(context.conf["user"]), timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py index f715f0c..48a6408 100644 --- a/test/features/steps/crud_table.py +++ b/test/features/steps/crud_table.py @@ -10,103 +10,109 @@ from behave import when, then from textwrap import dedent -@when('we create table') +@when("we create table") def step_create_table(context): """Send create table.""" - context.cli.sendline('create table a(x text);') + context.cli.sendline("create table a(x text);") -@when('we insert into table') +@when("we insert into table") def step_insert_into_table(context): """Send insert into table.""" - context.cli.sendline('''insert into a(x) values('xxx');''') + context.cli.sendline("""insert into a(x) values('xxx');""") -@when('we update table') +@when("we update table") def step_update_table(context): """Send insert into table.""" - context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") -@when('we select from table') +@when("we select from table") def step_select_from_table(context): """Send select from table.""" - context.cli.sendline('select * from a;') + context.cli.sendline("select * from a;") -@when('we delete from table') +@when("we delete from table") def step_delete_from_table(context): """Send deete from table.""" - context.cli.sendline('''delete from a where x = 'yyy';''') + context.cli.sendline("""delete from a where x = 'yyy';""") -@when('we drop table') +@when("we drop table") def step_drop_table(context): """Send drop table.""" - context.cli.sendline('drop table a;') + context.cli.sendline("drop table a;") -@then('we see table created') +@then("we see table created") def step_see_table_created(context): """Wait to see create table output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see record inserted') +@then("we see record inserted") def step_see_record_inserted(context): """Wait to see insert output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see record updated') +@then("we see record updated") def step_see_record_updated(context): """Wait to see update output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see data selected') +@then("we see data selected") def step_see_data_selected(context): """Wait to see select output.""" wrappers.expect_pager( - context, dedent("""\ + context, + dedent("""\ +-----+\r | x |\r +-----+\r | yyy |\r +-----+\r \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then('we see record deleted') +@then("we see record deleted") def step_see_data_deleted(context): """Wait to see delete output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see table dropped') +@then("we see table dropped") def step_see_table_dropped(context): """Wait to see drop output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@when('we select null') +@when("we select null") def step_select_null(context): """Send select null.""" - context.cli.sendline('select null;') + context.cli.sendline("select null;") -@then('we see null selected') +@then("we see null selected") def step_see_null_selected(context): """Wait to see null output.""" wrappers.expect_pager( - context, dedent("""\ + context, + dedent("""\ +--------+\r | NULL |\r +--------+\r | <null> |\r +--------+\r \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """), + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py index bbabf43..07d5c77 100644 --- a/test/features/steps/iocommands.py +++ b/test/features/steps/iocommands.py @@ -5,101 +5,93 @@ from behave import when, then from textwrap import dedent -@when('we start external editor providing a file name') +@when("we start external editor providing a file name") def step_edit_file(context): """Edit file with external editor.""" - context.editor_file_name = os.path.join( - context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + context.editor_file_name = os.path.join(context.package_root, "test_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) - context.cli.sendline('\e {0}'.format( - os.path.basename(context.editor_file_name))) - wrappers.expect_exact( - context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline("\\e {0}".format(os.path.basename(context.editor_file_name))) + wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) + wrappers.expect_exact(context, "\r\n:", timeout=2) @when('we type "{query}" in the editor') def step_edit_type_sql(context, query): - context.cli.sendline('i') + context.cli.sendline("i") context.cli.sendline(query) - context.cli.sendline('.') - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline(".") + wrappers.expect_exact(context, "\r\n:", timeout=2) -@when('we exit the editor') +@when("we exit the editor") def step_edit_quit(context): - context.cli.sendline('x') + context.cli.sendline("x") wrappers.expect_exact(context, "written", timeout=2) @then('we see "{query}" in prompt') def step_edit_done_sql(context, query): - for match in query.split(' '): + for match in query.split(" "): wrappers.expect_exact(context, match, timeout=5) # Cleanup the command line. - context.cli.sendcontrol('c') + context.cli.sendcontrol("c") # Cleanup the edited file. if context.editor_file_name and os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) -@when(u'we tee output') +@when("we tee output") def step_tee_ouptut(context): - context.tee_file_name = os.path.join( - context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + context.tee_file_name = os.path.join(context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])) if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) - context.cli.sendline('tee {0}'.format( - os.path.basename(context.tee_file_name))) + context.cli.sendline("tee {0}".format(os.path.basename(context.tee_file_name))) -@when(u'we select "select {param}"') +@when('we select "select {param}"') def step_query_select_number(context, param): - context.cli.sendline(u'select {}'.format(param)) - wrappers.expect_pager(context, dedent(u"""\ + context.cli.sendline("select {}".format(param)) + wrappers.expect_pager( + context, + dedent( + """\ +{dashes}+\r | {param} |\r +{dashes}+\r | {param} |\r +{dashes}+\r \r - """.format(param=param, dashes='-' * (len(param) + 2)) - ), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) + """.format(param=param, dashes="-" * (len(param) + 2)) + ), + timeout=5, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then(u'we see result "{result}"') +@then('we see result "{result}"') def step_see_result(context, result): - wrappers.expect_exact( - context, - u"| {} |".format(result), - timeout=2 - ) + wrappers.expect_exact(context, "| {} |".format(result), timeout=2) -@when(u'we query "{query}"') +@when('we query "{query}"') def step_query(context, query): context.cli.sendline(query) -@when(u'we notee output') +@when("we notee output") def step_notee_output(context): - context.cli.sendline('notee') + context.cli.sendline("notee") -@then(u'we see 123456 in tee output') +@then("we see 123456 in tee output") def step_see_123456_in_ouput(context): with open(context.tee_file_name) as f: - assert '123456' in f.read() + assert "123456" in f.read() if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) -@then(u'delimiter is set to "{delimiter}"') +@then('delimiter is set to "{delimiter}"') def delimiter_is_set(context, delimiter): - wrappers.expect_exact( - context, - u'Changed delimiter to {}'.format(delimiter), - timeout=2 - ) + wrappers.expect_exact(context, "Changed delimiter to {}".format(delimiter), timeout=2) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py index bc1f866..93d68ba 100644 --- a/test/features/steps/named_queries.py +++ b/test/features/steps/named_queries.py @@ -9,82 +9,79 @@ import wrappers from behave import when, then -@when('we save a named query') +@when("we save a named query") def step_save_named_query(context): """Send \fs command.""" - context.cli.sendline('\\fs foo SELECT 12345') + context.cli.sendline("\\fs foo SELECT 12345") -@when('we use a named query') +@when("we use a named query") def step_use_named_query(context): """Send \f command.""" - context.cli.sendline('\\f foo') + context.cli.sendline("\\f foo") -@when('we delete a named query') +@when("we delete a named query") def step_delete_named_query(context): """Send \fd command.""" - context.cli.sendline('\\fd foo') + context.cli.sendline("\\fd foo") -@then('we see the named query saved') +@then("we see the named query saved") def step_see_named_query_saved(context): """Wait to see query saved.""" - wrappers.expect_exact(context, 'Saved.', timeout=2) + wrappers.expect_exact(context, "Saved.", timeout=2) -@then('we see the named query executed') +@then("we see the named query executed") def step_see_named_query_executed(context): """Wait to see select output.""" - wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + wrappers.expect_exact(context, "SELECT 12345", timeout=2) -@then('we see the named query deleted') +@then("we see the named query deleted") def step_see_named_query_deleted(context): """Wait to see query deleted.""" - wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + wrappers.expect_exact(context, "foo: Deleted", timeout=2) -@when('we save a named query with parameters') +@when("we save a named query with parameters") def step_save_named_query_with_parameters(context): """Send \fs command for query with parameters.""" context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') -@when('we use named query with parameters') +@when("we use named query with parameters") def step_use_named_query_with_parameters(context): """Send \f command with parameters.""" context.cli.sendline('\\f foo_args 101 second "third value"') -@then('we see the named query with parameters executed') +@then("we see the named query with parameters executed") def step_see_named_query_with_parameters_executed(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'SELECT 101, "second", "third value"', timeout=2) + wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2) -@when('we use named query with too few parameters') +@when("we use named query with too few parameters") def step_use_named_query_with_too_few_parameters(context): """Send \f command with missing parameters.""" - context.cli.sendline('\\f foo_args 101') + context.cli.sendline("\\f foo_args 101") -@then('we see the named query with parameters fail with missing parameters') +@then("we see the named query with parameters fail with missing parameters") def step_see_named_query_with_parameters_fail_with_missing_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'missing substitution for $2 in query:', timeout=2) + wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2) -@when('we use named query with too many parameters') +@when("we use named query with too many parameters") def step_use_named_query_with_too_many_parameters(context): """Send \f command with extra parameters.""" - context.cli.sendline('\\f foo_args 101 102 103 104') + context.cli.sendline("\\f foo_args 101 102 103 104") -@then('we see the named query with parameters fail with extra parameters') +@then("we see the named query with parameters fail with extra parameters") def step_see_named_query_with_parameters_fail_with_extra_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'query does not have substitution parameter $4:', timeout=2) + wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py index e8b99e3..1b50a00 100644 --- a/test/features/steps/specials.py +++ b/test/features/steps/specials.py @@ -9,10 +9,10 @@ import wrappers from behave import when, then -@when('we refresh completions') +@when("we refresh completions") def step_refresh_completions(context): """Send refresh command.""" - context.cli.sendline('rehash') + context.cli.sendline("rehash") @then('we see text "{text}"') @@ -20,8 +20,8 @@ def step_see_text(context, text): """Wait to see given text message.""" wrappers.expect_exact(context, text, timeout=2) -@then('we see completions refresh started') + +@then("we see completions refresh started") def step_see_refresh_started(context): """Wait to see refresh output.""" - wrappers.expect_exact( - context, 'Auto-completion refresh started in the background.', timeout=2) + wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2) diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py index 1ae63d2..873f9d4 100644 --- a/test/features/steps/utils.py +++ b/test/features/steps/utils.py @@ -4,8 +4,8 @@ import shlex def parse_cli_args_to_dict(cli_args: str): args_dict = {} for arg in shlex.split(cli_args): - if '=' in arg: - key, value = arg.split('=') + if "=" in arg: + key, value = arg.split("=") args_dict[key] = value else: args_dict[arg] = None diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index 6408f23..f9325c6 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -18,10 +18,9 @@ def expect_exact(context, expected, timeout): timedout = True if timedout: # Strip color codes out of the output. - actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', - '', context.cli.before) + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) raise Exception( - textwrap.dedent('''\ + textwrap.dedent("""\ Expected: --- {0!r} @@ -34,17 +33,12 @@ def expect_exact(context, expected, timeout): --- {2!r} --- - ''').format( - expected, - actual, - context.logfile.getvalue() - ) + """).format(expected, actual, context.logfile.getvalue()) ) def expect_pager(context, expected, timeout): - expect_exact(context, "{0}\r\n{1}{0}\r\n".format( - context.conf['pager_boundary'], expected), timeout=timeout) + expect_exact(context, "{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected), timeout=timeout) def run_cli(context, run_args=None, exclude_args=None): @@ -63,55 +57,49 @@ def run_cli(context, run_args=None, exclude_args=None): else: rendered_args.append(key) - if conf.get('host', None): - add_arg('host', '-h', conf['host']) - if conf.get('user', None): - add_arg('user', '-u', conf['user']) - if conf.get('pass', None): - add_arg('pass', '-p', conf['pass']) - if conf.get('port', None): - add_arg('port', '-P', str(conf['port'])) - if conf.get('dbname', None): - add_arg('dbname', '-D', conf['dbname']) - if conf.get('defaults-file', None): - add_arg('defaults_file', '--defaults-file', conf['defaults-file']) - if conf.get('myclirc', None): - add_arg('myclirc', '--myclirc', conf['myclirc']) - if conf.get('login_path'): - add_arg('login_path', '--login-path', conf['login_path']) + if conf.get("host", None): + add_arg("host", "-h", conf["host"]) + if conf.get("user", None): + add_arg("user", "-u", conf["user"]) + if conf.get("pass", None): + add_arg("pass", "-p", conf["pass"]) + if conf.get("port", None): + add_arg("port", "-P", str(conf["port"])) + if conf.get("dbname", None): + add_arg("dbname", "-D", conf["dbname"]) + if conf.get("defaults-file", None): + add_arg("defaults_file", "--defaults-file", conf["defaults-file"]) + if conf.get("myclirc", None): + add_arg("myclirc", "--myclirc", conf["myclirc"]) + if conf.get("login_path"): + add_arg("login_path", "--login-path", conf["login_path"]) for arg_name, arg_value in conf.items(): - if arg_name.startswith('-'): + if arg_name.startswith("-"): add_arg(arg_name, arg_name, arg_value) try: - cli_cmd = context.conf['cli_command'] + cli_cmd = context.conf["cli_command"] except KeyError: - cli_cmd = ( - '{0!s} -c "' - 'import coverage ; ' - 'coverage.process_startup(); ' - 'import mycli.main; ' - 'mycli.main.cli()' - '"' - ).format(sys.executable) + cli_cmd = ('{0!s} -c "' "import coverage ; " "coverage.process_startup(); " "import mycli.main; " "mycli.main.cli()" '"').format( + sys.executable + ) cmd_parts = [cli_cmd] + rendered_args - cmd = ' '.join(cmd_parts) + cmd = " ".join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() context.cli.logfile = context.logfile context.exit_sent = False - context.currentdb = context.conf['dbname'] + context.currentdb = context.conf["dbname"] def wait_prompt(context, prompt=None): """Make sure prompt is displayed.""" if prompt is None: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - prompt = '{0}@{1}:{2}>'.format( - user, host, dbname), + prompt = ("{0}@{1}:{2}>".format(user, host, dbname),) expect_exact(context, prompt, timeout=5) context.atprompt = True diff --git a/test/myclirc b/test/myclirc index 7d96c45..58f7279 100644 --- a/test/myclirc +++ b/test/myclirc @@ -153,6 +153,7 @@ output.null = "#808080" # Favorite queries. [favorite_queries] check = 'select "✔"' +foo_args = 'SELECT $1, "$2", "$3"' # Use the -d option to reference a DSN. # Special characters in passwords and other strings can be escaped with URL encoding. diff --git a/test/test_clistyle.py b/test/test_clistyle.py index f82cdf0..ab40444 100644 --- a/test/test_clistyle.py +++ b/test/test_clistyle.py @@ -1,4 +1,5 @@ """Test the mycli.clistyle module.""" + import pytest from pygments.style import Style @@ -10,9 +11,9 @@ from mycli.clistyle import style_factory @pytest.mark.skip(reason="incompatible with new prompt toolkit") def test_style_factory(): """Test that a Pygments Style class is created.""" - header = 'bold underline #ansired' - cli_style = {'Token.Output.Header': header} - style = style_factory('default', cli_style) + header = "bold underline #ansired" + cli_style = {"Token.Output.Header": header} + style = style_factory("default", cli_style) assert isinstance(style(), Style) assert Token.Output.Header in style.styles @@ -22,6 +23,6 @@ def test_style_factory(): @pytest.mark.skip(reason="incompatible with new prompt toolkit") def test_style_factory_unknown_name(): """Test that an unrecognized name will not throw an error.""" - style = style_factory('foobar', {}) + style = style_factory("foobar", {}) assert isinstance(style(), Style) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py index 318b632..3104065 100644 --- a/test/test_completion_engine.py +++ b/test/test_completion_engine.py @@ -8,494 +8,528 @@ def sorted_dicts(dicts): def test_select_suggests_cols_with_visible_table_scope(): - suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_select_suggests_cols_with_qualified_table_scope(): - suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [('sch', 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE ', - 'SELECT * FROM tabl WHERE (', - 'SELECT * FROM tabl WHERE foo = ', - 'SELECT * FROM tabl WHERE bar OR ', - 'SELECT * FROM tabl WHERE foo = 1 AND ', - 'SELECT * FROM tabl WHERE (bar > 10 AND ', - 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', - 'SELECT * FROM tabl WHERE 10 < ', - 'SELECT * FROM tabl WHERE foo BETWEEN ', - 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', -]) + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [("sch", "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) def test_where_suggests_columns_functions(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE foo IN (', - 'SELECT * FROM tabl WHERE foo IN (bar, ', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE foo IN (", + "SELECT * FROM tabl WHERE foo IN (bar, ", + ], +) def test_where_in_suggests_columns(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_where_equals_any_suggests_columns_or_keywords(): - text = 'SELECT * FROM tabl WHERE foo = ANY(' + text = "SELECT * FROM tabl WHERE foo = ANY(" suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_lparen_suggests_cols(): - suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_operand_inside_function_suggests_cols1(): - suggestion = suggest_type( - 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_operand_inside_function_suggests_cols2(): - suggestion = suggest_type( - 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] + suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] def test_select_suggests_cols_and_funcs(): - suggestions = suggest_type('SELECT ', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': []}, - {'type': 'column', 'tables': []}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM ', - 'INSERT INTO ', - 'COPY ', - 'UPDATE ', - 'DESCRIBE ', - 'DESC ', - 'EXPLAIN ', - 'SELECT * FROM foo JOIN ', -]) + suggestions = suggest_type("SELECT ", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": []}, + {"type": "column", "tables": []}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM ", + "INSERT INTO ", + "COPY ", + "UPDATE ", + "DESCRIBE ", + "DESC ", + "EXPLAIN ", + "SELECT * FROM foo JOIN ", + ], +) def test_expression_suggests_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM sch.', - 'INSERT INTO sch.', - 'COPY sch.', - 'UPDATE sch.', - 'DESCRIBE sch.', - 'DESC sch.', - 'EXPLAIN sch.', - 'SELECT * FROM foo JOIN sch.', -]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + "INSERT INTO sch.", + "COPY sch.", + "UPDATE sch.", + "DESCRIBE sch.", + "DESC sch.", + "EXPLAIN sch.", + "SELECT * FROM foo JOIN sch.", + ], +) def test_expression_suggests_qualified_tables_views_and_schemas(expression): suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}, - {'type': 'view', 'schema': 'sch'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}, {"type": "view", "schema": "sch"}]) def test_truncate_suggests_tables_and_schemas(): - suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "schema"}]) def test_truncate_suggests_qualified_tables(): - suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}]) + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": "sch"}]) def test_distinct_suggests_cols(): - suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') - assert suggestions == [{'type': 'column', 'tables': []}] + suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ") + assert suggestions == [{"type": "column", "tables": []}] def test_col_comma_suggests_cols(): - suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tbl']}, - {'type': 'column', 'tables': [(None, 'tbl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tbl"]}, + {"type": "column", "tables": [(None, "tbl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_table_comma_suggests_tables_and_schemas(): - suggestions = suggest_type('SELECT a, b FROM tbl1, ', - 'SELECT a, b FROM tbl1, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_into_suggests_tables_and_schemas(): - suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_insert_into_lparen_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_insert_into_lparen_partial_text_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_insert_into_lparen_comma_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] def test_partially_typed_col_name_suggests_col_names(): - suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', - 'SELECT * FROM tabl WHERE col_n') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): - suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'table', 'schema': 'tabl'}, - {'type': 'view', 'schema': 'tabl'}, - {'type': 'function', 'schema': 'tabl'}]) + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "table", "schema": "tabl"}, + {"type": "view", "schema": "tabl"}, + {"type": "function", "schema": "tabl"}, + ] + ) def test_dot_suggests_cols_of_an_alias(): - suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 't1'}, - {'type': 'view', 'schema': 't1'}, - {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, - {'type': 'function', 'schema': 't1'}]) + suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "table", "schema": "t1"}, + {"type": "view", "schema": "t1"}, + {"type": "column", "tables": [(None, "tabl1", "t1")]}, + {"type": "function", "schema": "t1"}, + ] + ) def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): - suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.a, t2.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, - {'type': 'table', 'schema': 't2'}, - {'type': 'view', 'schema': 't2'}, - {'type': 'function', 'schema': 't2'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (', - 'SELECT * FROM foo WHERE EXISTS (', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', - 'SELECT 1 AS', -]) + suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl2", "t2")]}, + {"type": "table", "schema": "t2"}, + {"type": "view", "schema": "t2"}, + {"type": "function", "schema": "t2"}, + ] + ) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + "SELECT 1 AS", + ], +) def test_sub_select_suggests_keyword(expression): suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] + assert suggestion == [{"type": "keyword"}] -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (S', - 'SELECT * FROM foo WHERE EXISTS (S', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) def test_sub_select_partial_text_suggests_keyword(expression): suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] + assert suggestion == [{"type": "keyword"}] def test_outer_table_reference_in_exists_subquery_suggests_columns(): - q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." suggestions = suggest_type(q, q) assert suggestions == [ - {'type': 'column', 'tables': [(None, 'foo', 'f')]}, - {'type': 'table', 'schema': 'f'}, - {'type': 'view', 'schema': 'f'}, - {'type': 'function', 'schema': 'f'}] - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (SELECT * FROM ', - 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', -]) + {"type": "column", "tables": [(None, "foo", "f")]}, + {"type": "table", "schema": "f"}, + {"type": "view", "schema": "f"}, + {"type": "function", "schema": "f"}, + ] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (SELECT * FROM ", + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) def test_sub_select_table_name_completion(expression): suggestion = suggest_type(expression, expression) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_sub_select_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', - 'SELECT * FROM (SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['abc']}, - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["abc"]}, + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) @pytest.mark.xfail def test_sub_select_multiple_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', - 'SELECT * FROM (SELECT a, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}]) + suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ") + assert sorted_dicts(suggestions) == sorted_dicts( + [{"type": "column", "tables": [(None, "abc", None)]}, {"type": "function", "schema": []}] + ) def test_sub_select_dot_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', - 'SELECT * FROM (SELECT t.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', 't')]}, - {'type': 'table', 'schema': 't'}, - {'type': 'view', 'schema': 't'}, - {'type': 'function', 'schema': 't'}]) - - -@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) -@pytest.mark.parametrize('tbl_alias', ['', 'foo']) + suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "tabl", "t")]}, + {"type": "table", "schema": "t"}, + {"type": "view", "schema": "t"}, + {"type": "function", "schema": "t"}, + ] + ) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +@pytest.mark.parametrize("tbl_alias", ["", "foo"]) def test_join_suggests_tables_and_schemas(tbl_alias, join_type): - text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) + text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type) suggestion = suggest_type(text, text) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestion) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', -]) +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) def test_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', 'a')]}, - {'type': 'table', 'schema': 'a'}, - {'type': 'view', 'schema': 'a'}, - {'type': 'function', 'schema': 'a'}]) - - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "abc", "a")]}, + {"type": "table", "schema": "a"}, + {"type": "view", "schema": "a"}, + {"type": "function", "schema": "a"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) def test_join_alias_dot_suggests_cols2(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'def', 'd')]}, - {'type': 'table', 'schema': 'd'}, - {'type': 'view', 'schema': 'd'}, - {'type': 'function', 'schema': 'd'}]) - - -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', -]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "def", "d")]}, + {"type": "table", "schema": "d"}, + {"type": "view", "schema": "d"}, + {"type": "function", "schema": "d"}, + ] + ) + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) def test_on_suggests_aliases(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + ], +) def test_on_suggests_tables(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on a.id = ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) def test_on_suggests_aliases_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] + assert suggestions == [{"type": "alias", "aliases": ["a", "b"]}] -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', -]) +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + ], +) def test_on_suggests_tables_right_side(sql): suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] + assert suggestions == [{"type": "alias", "aliases": ["abc", "bcd"]}] -@pytest.mark.parametrize('col_list', ['', 'col1, ']) +@pytest.mark.parametrize("col_list", ["", "col1, "]) def test_join_using_suggests_common_columns(col_list): - text = 'select * from abc inner join def using (' + col_list - assert suggest_type(text, text) == [ - {'type': 'column', - 'tables': [(None, 'abc', None), (None, 'def', None)], - 'drop_unique': True}] - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', -]) + text = "select * from abc inner join def using (" + col_list + assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}] + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.", + ], +) def test_two_join_alias_dot_suggests_cols1(sql): suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, - {'type': 'table', 'schema': 'g'}, - {'type': 'view', 'schema': 'g'}, - {'type': 'function', 'schema': 'g'}]) + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "column", "tables": [(None, "ghi", "g")]}, + {"type": "table", "schema": "g"}, + {"type": "view", "schema": "g"}, + {"type": "function", "schema": "g"}, + ] + ) + def test_2_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select * from a; select from b", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) # Should work even if first statement is invalid - suggestions = suggest_type('select * from; select * from ', - 'select * from; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("select * from; select * from ", "select * from; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_2_statements_1st_current(): - suggestions = suggest_type('select * from ; select * from b', - 'select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select from a; select * from b', - 'select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['a']}, - {'type': 'column', 'tables': [(None, 'a', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select from a; select * from b", "select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["a"]}, + {"type": "column", "tables": [(None, "a", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_3_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ; select * from c', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b; select * from c', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) + suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts( + [ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "keyword"}, + ] + ) def test_create_db_with_template(): - suggestions = suggest_type('create database foo with template ', - 'create database foo with template ') + suggestions = suggest_type("create database foo with template ", "create database foo with template ") - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) -@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t']) +@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"]) def test_specials_included_for_initial_completion(initial_text): suggestions = suggest_type(initial_text, initial_text) - assert sorted_dicts(suggestions) == \ - sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}]) def test_specials_not_included_after_initial_token(): - suggestions = suggest_type('create table foo (dt d', - 'create table foo (dt d') + suggestions = suggest_type("create table foo (dt d", "create table foo (dt d") - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}]) def test_drop_schema_qualified_table_suggests_only_tables(): - text = 'DROP TABLE schema_name.table_name' + text = "DROP TABLE schema_name.table_name" suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'table', 'schema': 'schema_name'}] + assert suggestions == [{"type": "table", "schema": "schema_name"}] -@pytest.mark.parametrize('text', [',', ' ,', 'sel ,']) +@pytest.mark.parametrize("text", [",", " ,", "sel ,"]) def test_handle_pre_completion_comma_gracefully(text): suggestions = suggest_type(text, text) @@ -503,53 +537,59 @@ def test_handle_pre_completion_comma_gracefully(text): def test_cross_join(): - text = 'select * from v1 cross join v2 JOIN v1.id, ' + text = "select * from v1 cross join v2 JOIN v1.id, " suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) -@pytest.mark.parametrize('expression', [ - 'SELECT 1 AS ', - 'SELECT 1 FROM tabl AS ', -]) +@pytest.mark.parametrize( + "expression", + [ + "SELECT 1 AS ", + "SELECT 1 FROM tabl AS ", + ], +) def test_after_as(expression): suggestions = suggest_type(expression, expression) assert set(suggestions) == set() -@pytest.mark.parametrize('expression', [ - '\\. ', - 'select 1; \\. ', - 'select 1;\\. ', - 'select 1 ; \\. ', - 'source ', - 'truncate table test; source ', - 'truncate table test ; source ', - 'truncate table test;source ', -]) +@pytest.mark.parametrize( + "expression", + [ + "\\. ", + "select 1; \\. ", + "select 1;\\. ", + "select 1 ; \\. ", + "source ", + "truncate table test; source ", + "truncate table test ; source ", + "truncate table test;source ", + ], +) def test_source_is_file(expression): suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'file_name'}] + assert suggestions == [{"type": "file_name"}] -@pytest.mark.parametrize("expression", [ - "\\f ", -]) +@pytest.mark.parametrize( + "expression", + [ + "\\f ", + ], +) def test_favorite_name_suggestion(expression): suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'favoritequery'}] + assert suggestions == [{"type": "favoritequery"}] def test_order_by(): - text = 'select * from foo order by ' + text = "select * from foo order by " suggestions = suggest_type(text, text) - assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] + assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}] def test_quoted_where(): text = "'where i=';" suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'keyword'}] + assert suggestions == [{"type": "keyword"}] diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py index 31359cf..6f192d0 100644 --- a/test/test_completion_refresher.py +++ b/test/test_completion_refresher.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch @pytest.fixture def refresher(): from mycli.completion_refresher import CompletionRefresher + return CompletionRefresher() @@ -18,8 +19,7 @@ def test_ctor(refresher): """ assert len(refresher.refreshers) > 0 actual_handlers = list(refresher.refreshers.keys()) - expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions', - 'special_commands', 'show_commands', 'keywords'] + expected_handlers = ["databases", "schemata", "tables", "users", "functions", "special_commands", "show_commands", "keywords"] assert expected_handlers == actual_handlers @@ -32,12 +32,12 @@ def test_refresh_called_once(refresher): callbacks = Mock() sqlexecute = Mock() - with patch.object(refresher, '_bg_refresh') as bg_refresh: + with patch.object(refresher, "_bg_refresh") as bg_refresh: actual = refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual) == 1 assert len(actual[0]) == 4 - assert actual[0][3] == 'Auto-completion refresh started in the background.' + assert actual[0][3] == "Auto-completion refresh started in the background." bg_refresh.assert_called_with(sqlexecute, callbacks, {}) @@ -61,13 +61,13 @@ def test_refresh_called_twice(refresher): time.sleep(1) # Wait for the thread to work. assert len(actual1) == 1 assert len(actual1[0]) == 4 - assert actual1[0][3] == 'Auto-completion refresh started in the background.' + assert actual1[0][3] == "Auto-completion refresh started in the background." actual2 = refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. assert len(actual2) == 1 assert len(actual2[0]) == 4 - assert actual2[0][3] == 'Auto-completion refresh restarted.' + assert actual2[0][3] == "Auto-completion refresh restarted." def test_refresh_with_callbacks(refresher): @@ -80,9 +80,9 @@ def test_refresh_with_callbacks(refresher): sqlexecute_class = Mock() sqlexecute = Mock() - with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class): + with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class): # Set refreshers to 0: we're not testing refresh logic here refresher.refreshers = {} refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. - assert (callbacks[0].call_count == 1) + assert callbacks[0].call_count == 1 diff --git a/test/test_config.py b/test/test_config.py index 7f2b244..859ca02 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -1,4 +1,5 @@ """Unit tests for the mycli.config module.""" + from io import BytesIO, StringIO, TextIOWrapper import os import struct @@ -6,21 +7,26 @@ import sys import tempfile import pytest -from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, - read_and_decrypt_mylogin_cnf, read_config_file, - str_to_bool, strip_matching_quotes) +from mycli.config import ( + get_mylogin_cnf_path, + open_mylogin_cnf, + read_and_decrypt_mylogin_cnf, + read_config_file, + str_to_bool, + strip_matching_quotes, +) -LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'mylogin.cnf')) +LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "mylogin.cnf")) def open_bmylogin_cnf(name): """Open contents of *name* in a BytesIO buffer.""" - with open(name, 'rb') as f: + with open(name, "rb") as f: buf = BytesIO() buf.write(f.read()) return buf + def test_read_mylogin_cnf(): """Tests that a login path file can be read and decrypted.""" mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) @@ -28,7 +34,7 @@ def test_read_mylogin_cnf(): assert isinstance(mylogin_cnf, TextIOWrapper) contents = mylogin_cnf.read() - for word in ('[test]', 'user', 'password', 'host', 'port'): + for word in ("[test]", "user", "password", "host", "port"): assert word in contents @@ -46,7 +52,7 @@ def test_corrupted_login_key(): buf.seek(4) # Write null bytes over half the login key - buf.write(b'\0\0\0\0\0\0\0\0\0\0') + buf.write(b"\0\0\0\0\0\0\0\0\0\0") buf.seek(0) mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) @@ -63,58 +69,58 @@ def test_corrupted_pad(): # Skip option group len_buf = buf.read(4) - cipher_len, = struct.unpack("<i", len_buf) + (cipher_len,) = struct.unpack("<i", len_buf) buf.read(cipher_len) # Corrupt the pad for the user line len_buf = buf.read(4) - cipher_len, = struct.unpack("<i", len_buf) + (cipher_len,) = struct.unpack("<i", len_buf) buf.read(cipher_len - 1) - buf.write(b'\0') + buf.write(b"\0") buf.seek(0) mylogin_cnf = TextIOWrapper(read_and_decrypt_mylogin_cnf(buf)) contents = mylogin_cnf.read() - for word in ('[test]', 'password', 'host', 'port'): + for word in ("[test]", "password", "host", "port"): assert word in contents - assert 'user' not in contents + assert "user" not in contents def test_get_mylogin_cnf_path(): """Tests that the path for .mylogin.cnf is detected.""" original_env = None - if 'MYSQL_TEST_LOGIN_FILE' in os.environ: - original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') - is_windows = sys.platform == 'win32' + if "MYSQL_TEST_LOGIN_FILE" in os.environ: + original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE") + is_windows = sys.platform == "win32" login_cnf_path = get_mylogin_cnf_path() if original_env is not None: - os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env if login_cnf_path is not None: - assert login_cnf_path.endswith('.mylogin.cnf') + assert login_cnf_path.endswith(".mylogin.cnf") if is_windows is True: - assert 'MySQL' in login_cnf_path + assert "MySQL" in login_cnf_path else: - home_dir = os.path.expanduser('~') + home_dir = os.path.expanduser("~") assert login_cnf_path.startswith(home_dir) def test_alternate_get_mylogin_cnf_path(): """Tests that the alternate path for .mylogin.cnf is detected.""" original_env = None - if 'MYSQL_TEST_LOGIN_FILE' in os.environ: - original_env = os.environ.pop('MYSQL_TEST_LOGIN_FILE') + if "MYSQL_TEST_LOGIN_FILE" in os.environ: + original_env = os.environ.pop("MYSQL_TEST_LOGIN_FILE") _, temp_path = tempfile.mkstemp() - os.environ['MYSQL_TEST_LOGIN_FILE'] = temp_path + os.environ["MYSQL_TEST_LOGIN_FILE"] = temp_path login_cnf_path = get_mylogin_cnf_path() if original_env is not None: - os.environ['MYSQL_TEST_LOGIN_FILE'] = original_env + os.environ["MYSQL_TEST_LOGIN_FILE"] = original_env assert temp_path == login_cnf_path @@ -124,17 +130,17 @@ def test_str_to_bool(): assert str_to_bool(False) is False assert str_to_bool(True) is True - assert str_to_bool('False') is False - assert str_to_bool('True') is True - assert str_to_bool('TRUE') is True - assert str_to_bool('1') is True - assert str_to_bool('0') is False - assert str_to_bool('on') is True - assert str_to_bool('off') is False - assert str_to_bool('off') is False + assert str_to_bool("False") is False + assert str_to_bool("True") is True + assert str_to_bool("TRUE") is True + assert str_to_bool("1") is True + assert str_to_bool("0") is False + assert str_to_bool("on") is True + assert str_to_bool("off") is False + assert str_to_bool("off") is False with pytest.raises(ValueError): - str_to_bool('foo') + str_to_bool("foo") with pytest.raises(TypeError): str_to_bool(None) @@ -143,19 +149,19 @@ def test_str_to_bool(): def test_read_config_file_list_values_default(): """Test that reading a config file uses list_values by default.""" - f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n") config = read_config_file(f) - assert config['main']['weather'] == u"cloudy with a chance of meatballs" + assert config["main"]["weather"] == "cloudy with a chance of meatballs" def test_read_config_file_list_values_off(): """Test that you can disable list_values when reading a config file.""" - f = StringIO(u"[main]\nweather='cloudy with a chance of meatballs'\n") + f = StringIO("[main]\nweather='cloudy with a chance of meatballs'\n") config = read_config_file(f, list_values=False) - assert config['main']['weather'] == u"'cloudy with a chance of meatballs'" + assert config["main"]["weather"] == "'cloudy with a chance of meatballs'" def test_strip_quotes_with_matching_quotes(): @@ -177,7 +183,7 @@ def test_strip_quotes_with_unmatching_quotes(): def test_strip_quotes_with_empty_string(): """Test that an empty string is handled during unquoting.""" - assert '' == strip_matching_quotes('') + assert "" == strip_matching_quotes("") def test_strip_quotes_with_none(): diff --git a/test/test_dbspecial.py b/test/test_dbspecial.py index 21e389c..aee6e05 100644 --- a/test/test_dbspecial.py +++ b/test/test_dbspecial.py @@ -4,39 +4,32 @@ from mycli.packages.special.utils import format_uptime def test_u_suggests_databases(): - suggestions = suggest_type('\\u ', '\\u ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'database'}]) + suggestions = suggest_type("\\u ", "\\u ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) def test_describe_table(): - suggestions = suggest_type('\\dt', '\\dt ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("\\dt", "\\dt ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_list_or_show_create_tables(): - suggestions = suggest_type('\\dt+', '\\dt+ ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) + suggestions = suggest_type("\\dt+", "\\dt+ ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) def test_format_uptime(): seconds = 59 - assert '59 sec' == format_uptime(seconds) + assert "59 sec" == format_uptime(seconds) seconds = 120 - assert '2 min 0 sec' == format_uptime(seconds) + assert "2 min 0 sec" == format_uptime(seconds) seconds = 54890 - assert '15 hours 14 min 50 sec' == format_uptime(seconds) + assert "15 hours 14 min 50 sec" == format_uptime(seconds) seconds = 598244 - assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds) + assert "6 days 22 hours 10 min 44 sec" == format_uptime(seconds) seconds = 522600 - assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds) + assert "6 days 1 hour 10 min 0 sec" == format_uptime(seconds) diff --git a/test/test_main.py b/test/test_main.py index 589d6cd..b0f8d4c 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -13,52 +13,62 @@ from textwrap import dedent from collections import namedtuple from tempfile import NamedTemporaryFile -from textwrap import dedent test_dir = os.path.abspath(os.path.dirname(__file__)) project_dir = os.path.dirname(test_dir) -default_config_file = os.path.join(project_dir, 'test', 'myclirc') -login_path_file = os.path.join(test_dir, 'mylogin.cnf') - -os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file -CLI_ARGS = ['--user', USER, '--host', HOST, '--port', PORT, - '--password', PASSWORD, '--myclirc', default_config_file, - '--defaults-file', default_config_file, - 'mycli_test_db'] +default_config_file = os.path.join(project_dir, "test", "myclirc") +login_path_file = os.path.join(test_dir, "mylogin.cnf") + +os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file +CLI_ARGS = [ + "--user", + USER, + "--host", + HOST, + "--port", + PORT, + "--password", + PASSWORD, + "--myclirc", + default_config_file, + "--defaults-file", + default_config_file, + "mycli_test_db", +] @dbtest def test_execute_arg(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql]) + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql]) assert result.exit_code == 0 - assert 'abc' in result.output + assert "abc" in result.output - result = runner.invoke(cli, args=CLI_ARGS + ['--execute', sql]) + result = runner.invoke(cli, args=CLI_ARGS + ["--execute", sql]) assert result.exit_code == 0 - assert 'abc' in result.output + assert "abc" in result.output - expected = 'a\nabc\n' + expected = "a\nabc\n" assert expected in result.output @dbtest def test_execute_arg_with_table(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--table']) - expected = '+-----+\n| a |\n+-----+\n| abc |\n+-----+\n' + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--table"]) + expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n" assert result.exit_code == 0 assert expected in result.output @@ -66,12 +76,12 @@ def test_execute_arg_with_table(executor): @dbtest def test_execute_arg_with_csv(executor): - run(executor, 'create table test (a text)') + run(executor, "create table test (a text)") run(executor, 'insert into test values("abc")') - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-e', sql] + ['--csv']) + result = runner.invoke(cli, args=CLI_ARGS + ["-e", sql] + ["--csv"]) expected = '"a"\n"abc"\n' assert result.exit_code == 0 @@ -80,35 +90,29 @@ def test_execute_arg_with_csv(executor): @dbtest def test_batch_mode(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) + sql = "select count(*) from test;\n" "select * from test limit 1;" runner = CliRunner() result = runner.invoke(cli, args=CLI_ARGS, input=sql) assert result.exit_code == 0 - assert 'count(*)\n3\na\nabc\n' in "".join(result.output) + assert "count(*)\n3\na\nabc\n" in "".join(result.output) @dbtest def test_batch_mode_table(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc'), ('def'), ('ghi')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") - sql = ( - 'select count(*) from test;\n' - 'select * from test limit 1;' - ) + sql = "select count(*) from test;\n" "select * from test limit 1;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['-t'], input=sql) + result = runner.invoke(cli, args=CLI_ARGS + ["-t"], input=sql) - expected = (dedent("""\ + expected = dedent("""\ +----------+ | count(*) | +----------+ @@ -118,7 +122,7 @@ def test_batch_mode_table(executor): | a | +-----+ | abc | - +-----+""")) + +-----+""") assert result.exit_code == 0 assert expected in result.output @@ -126,14 +130,13 @@ def test_batch_mode_table(executor): @dbtest def test_batch_mode_csv(executor): - run(executor, '''create table test(a text, b text)''') - run(executor, - '''insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')''') + run(executor, """create table test(a text, b text)""") + run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""") - sql = 'select * from test;' + sql = "select * from test;" runner = CliRunner() - result = runner.invoke(cli, args=CLI_ARGS + ['--csv'], input=sql) + result = runner.invoke(cli, args=CLI_ARGS + ["--csv"], input=sql) expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' @@ -150,15 +153,15 @@ def test_help_strings_end_with_periods(): """Make sure click options have help text that end with a period.""" for param in cli.params: if isinstance(param, click.core.Option): - assert hasattr(param, 'help') - assert param.help.endswith('.') + assert hasattr(param, "help") + assert param.help.endswith(".") def test_command_descriptions_end_with_periods(): """Make sure that mycli commands' descriptions end with a period.""" MyCli() for _, command in SPECIAL_COMMANDS.items(): - assert command[3].endswith('.') + assert command[3].endswith(".") def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): @@ -166,23 +169,23 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): clickoutput = "" m = MyCli(myclirc=default_config_file) - class TestOutput(): + class TestOutput: def get_size(self): - size = namedtuple('Size', 'rows columns') + size = namedtuple("Size", "rows columns") size.columns, size.rows = terminal_size return size - class TestExecute(): - host = 'test' - user = 'test' - dbname = 'test' - server_info = ServerInfo.from_version_string('unknown') + class TestExecute: + host = "test" + user = "test" + dbname = "test" + server_info = ServerInfo.from_version_string("unknown") port = 0 def server_type(self): - return ['test'] + return ["test"] - class PromptBuffer(): + class PromptBuffer: output = TestOutput() m.prompt_app = PromptBuffer() @@ -199,8 +202,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): global clickoutput clickoutput += s + "\n" - monkeypatch.setattr(click, 'echo_via_pager', echo_via_pager) - monkeypatch.setattr(click, 'secho', secho) + monkeypatch.setattr(click, "echo_via_pager", echo_via_pager) + monkeypatch.setattr(click, "secho", secho) m.output(testdata) if clickoutput.endswith("\n"): clickoutput = clickoutput[:-1] @@ -208,59 +211,29 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): def test_conditional_pager(monkeypatch): - testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split( - " ") + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ") # User didn't set pager, output doesn't fit screen -> pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=True - ) + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True) # User didn't set pager, output fits screen -> no pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False) # User manually configured pager, output doesn't fit screen -> pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True) # User manually configured pager, output fit screen -> pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True) - SPECIAL_COMMANDS['nopager'].handler() - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) - SPECIAL_COMMANDS['pager'].handler('') + SPECIAL_COMMANDS["nopager"].handler() + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False) + SPECIAL_COMMANDS["pager"].handler("") def test_reserved_space_is_integer(monkeypatch): """Make sure that reserved space is returned as an integer.""" + def stub_terminal_size(): return (5, 5) with monkeypatch.context() as m: - m.setattr(shutil, 'get_terminal_size', stub_terminal_size) + m.setattr(shutil, "get_terminal_size", stub_terminal_size) mycli = MyCli() assert isinstance(mycli.get_reserved_space(), int) @@ -268,18 +241,20 @@ def test_reserved_space_is_integer(monkeypatch): def test_list_dsn(): runner = CliRunner() # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as myclirc: - myclirc.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as myclirc: + myclirc.write( + dedent("""\ [alias_dsn] test = mysql://test/test - """)) + """) + ) myclirc.flush() - args = ['--list-dsn', '--myclirc', myclirc.name] + args = ["--list-dsn", "--myclirc", myclirc.name] result = runner.invoke(cli, args=args) assert result.output == "test\n" - result = runner.invoke(cli, args=args + ['--verbose']) + result = runner.invoke(cli, args=args + ["--verbose"]) assert result.output == "test : mysql://test/test\n" - + # delete=False means we should try to clean up try: if os.path.exists(myclirc.name): @@ -287,41 +262,41 @@ def test_list_dsn(): except Exception as e: print(f"An error occurred while attempting to delete the file: {e}") - - def test_prettify_statement(): - statement = 'SELECT 1' + statement = "SELECT 1" m = MyCli() pretty_statement = m.handle_prettify_binding(statement) - assert pretty_statement == 'SELECT\n 1;' + assert pretty_statement == "SELECT\n 1;" def test_unprettify_statement(): - statement = 'SELECT\n 1' + statement = "SELECT\n 1" m = MyCli() unpretty_statement = m.handle_unprettify_binding(statement) - assert unpretty_statement == 'SELECT 1;' + assert unpretty_statement == "SELECT 1;" def test_list_ssh_config(): runner = CliRunner() # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ Host test Hostname test.example.com User joe Port 22222 IdentityFile ~/.ssh/gateway - """)) + """) + ) ssh_config.flush() - args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] + args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name] result = runner.invoke(cli, args=args) assert "test\n" in result.output - result = runner.invoke(cli, args=args + ['--verbose']) + result = runner.invoke(cli, args=args + ["--verbose"]) assert "test : test.example.com\n" in result.output - + # delete=False means we should try to clean up try: if os.path.exists(ssh_config.name): @@ -343,7 +318,7 @@ def test_dsn(monkeypatch): pass class MockMyCli: - config = {'alias_dsn': {}} + config = {"alias_dsn": {}} def __init__(self, **args): self.logger = Logger() @@ -357,97 +332,109 @@ def test_dsn(monkeypatch): pass import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) runner = CliRunner() # When a user supplies a DSN as database argument to mycli, # use these values. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] - ) + result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 1 and \ - MockMyCli.connect_args["database"] == "dsn_database" + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] == "dsn_passwd" + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 1 + and MockMyCli.connect_args["database"] == "dsn_database" + ) MockMyCli.connect_args = None # When a use supplies a DSN as database argument to mycli, # and used command line arguments, use the command line # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "3", - "--database", "arg_database", - ]) + result = runner.invoke( + mycli.main.cli, + args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "3", + "--database", + "arg_database", + ], + ) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 3 and \ - MockMyCli.connect_args["database"] == "arg_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 3 + and MockMyCli.connect_args["database"] == "arg_database" + ) + + MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}} MockMyCli.connect_args = None # When a user uses a DSN from the configuration file (alias_dsn), # use these values. - result = runner.invoke(cli, args=['--dsn', 'test']) + result = runner.invoke(cli, args=["--dsn", "test"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "alias_dsn_user" and \ - MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ - MockMyCli.connect_args["host"] == "alias_dsn_host" and \ - MockMyCli.connect_args["port"] == 4 and \ - MockMyCli.connect_args["database"] == "alias_dsn_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } + assert ( + MockMyCli.connect_args["user"] == "alias_dsn_user" + and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" + and MockMyCli.connect_args["host"] == "alias_dsn_host" + and MockMyCli.connect_args["port"] == 4 + and MockMyCli.connect_args["database"] == "alias_dsn_database" + ) + + MockMyCli.config = {"alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}} MockMyCli.connect_args = None # When a user uses a DSN from the configuration file (alias_dsn) # and used command line arguments, use the command line arguments. - result = runner.invoke(cli, args=[ - '--dsn', 'test', '', - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "5", - "--database", "arg_database", - ]) + result = runner.invoke( + cli, + args=[ + "--dsn", + "test", + "", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "5", + "--database", + "arg_database", + ], + ) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 5 and \ - MockMyCli.connect_args["database"] == "arg_database" + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 5 + and MockMyCli.connect_args["database"] == "arg_database" + ) # Use a DSN without password - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user@dsn_host:6/dsn_database"] - ) + result = runner.invoke(mycli.main.cli, args=["mysql://dsn_user@dsn_host:6/dsn_database"]) assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] is None and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 6 and \ - MockMyCli.connect_args["database"] == "dsn_database" + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] is None + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 6 + and MockMyCli.connect_args["database"] == "dsn_database" + ) def test_ssh_config(monkeypatch): @@ -463,7 +450,7 @@ def test_ssh_config(monkeypatch): pass class MockMyCli: - config = {'alias_dsn': {}} + config = {"alias_dsn": {}} def __init__(self, **args): self.logger = Logger() @@ -477,58 +464,62 @@ def test_ssh_config(monkeypatch): pass import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) runner = CliRunner() # Setup temporary configuration # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ + with NamedTemporaryFile(mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ Host test Hostname test.example.com User joe Port 22222 IdentityFile ~/.ssh/gateway - """)) + """) + ) ssh_config.flush() # When a user supplies a ssh config. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "joe" and \ - MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ - MockMyCli.connect_args["ssh_port"] == 22222 and \ - MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser( - "~") + "/.ssh/gateway" + result = runner.invoke(mycli.main.cli, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "joe" + and MockMyCli.connect_args["ssh_host"] == "test.example.com" + and MockMyCli.connect_args["ssh_port"] == 22222 + and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway" + ) # When a user supplies a ssh config host as argument to mycli, # and used command line arguments, use the command line # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test", - "--ssh-user", "arg_user", - "--ssh-host", "arg_host", - "--ssh-port", "3", - "--ssh-key-filename", "/path/to/key" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "arg_user" and \ - MockMyCli.connect_args["ssh_host"] == "arg_host" and \ - MockMyCli.connect_args["ssh_port"] == 3 and \ - MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" - + result = runner.invoke( + mycli.main.cli, + args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", + "arg_user", + "--ssh-host", + "arg_host", + "--ssh-port", + "3", + "--ssh-key-filename", + "/path/to/key", + ], + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "arg_user" + and MockMyCli.connect_args["ssh_host"] == "arg_host" + and MockMyCli.connect_args["ssh_port"] == 3 + and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + ) + # delete=False means we should try to clean up try: if os.path.exists(ssh_config.name): @@ -542,9 +533,7 @@ def test_init_command_arg(executor): init_command = "set sql_select_limit=1000" sql = 'show variables like "sql_select_limit";' runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ["--init-command", init_command], input=sql - ) + result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql) expected = "sql_select_limit\t1000\n" assert result.exit_code == 0 @@ -553,18 +542,13 @@ def test_init_command_arg(executor): @dbtest def test_init_command_multiple_arg(executor): - init_command = 'set sql_select_limit=2000; set max_join_size=20000' - sql = ( - 'show variables like "sql_select_limit";\n' - 'show variables like "max_join_size"' - ) + init_command = "set sql_select_limit=2000; set max_join_size=20000" + sql = 'show variables like "sql_select_limit";\n' 'show variables like "max_join_size"' runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ['--init-command', init_command], input=sql - ) + result = runner.invoke(cli, args=CLI_ARGS + ["--init-command", init_command], input=sql) - expected_sql_select_limit = 'sql_select_limit\t2000\n' - expected_max_join_size = 'max_join_size\t20000\n' + expected_sql_select_limit = "sql_select_limit\t2000\n" + expected_max_join_size = "max_join_size\t20000\n" assert result.exit_code == 0 assert expected_sql_select_limit in result.output diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py index 0bc3bf8..31ac165 100644 --- a/test/test_naive_completion.py +++ b/test/test_naive_completion.py @@ -6,56 +6,48 @@ from prompt_toolkit.document import Document @pytest.fixture def completer(): import mycli.sqlcompleter as sqlcompleter + return sqlcompleter.SQLCompleter(smart_completion=False) @pytest.fixture def complete_event(): from unittest.mock import Mock + return Mock() def test_empty_string_completion(completer, complete_event): - text = '' + text = "" position = 0 - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list(map(Completion, completer.all_completions)) def test_select_keyword_completion(completer, complete_event): - text = 'SEL' - position = len('SEL') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([Completion(text='SELECT', start_position=-3)]) + text = "SEL" + position = len("SEL") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list([Completion(text="SELECT", start_position=-3)]) def test_function_name_completion(completer, complete_event): - text = 'SELECT MA' - position = len('SELECT MA') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "SELECT MA" + position = len("SELECT MA") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert sorted(x.text for x in result) == ["MASTER", "MAX"] def test_column_name_completion(completer, complete_event): - text = 'SELECT FROM users' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "SELECT FROM users" + position = len("SELECT ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list(map(Completion, completer.all_completions)) def test_special_name_completion(completer, complete_event): - text = '\\' - position = len('\\') - result = set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) + text = "\\" + position = len("\\") + result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) # Special commands will NOT be suggested during naive completion mode. assert result == set() diff --git a/test/test_parseutils.py b/test/test_parseutils.py index 920a08d..0925299 100644 --- a/test/test_parseutils.py +++ b/test/test_parseutils.py @@ -1,67 +1,72 @@ import pytest from mycli.packages.parseutils import ( - extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, - is_dropping_database) + extract_tables, + query_starts_with, + queries_start_with, + is_destructive, + query_has_where_clause, + is_dropping_database, +) def test_empty_string(): - tables = extract_tables('') + tables = extract_tables("") assert tables == [] def test_simple_select_single_table(): - tables = extract_tables('select * from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select * from abc") + assert tables == [(None, "abc", None)] def test_simple_select_single_table_schema_qualified(): - tables = extract_tables('select * from abc.def') - assert tables == [('abc', 'def', None)] + tables = extract_tables("select * from abc.def") + assert tables == [("abc", "def", None)] def test_simple_select_multiple_tables(): - tables = extract_tables('select * from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select * from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_simple_select_multiple_tables_schema_qualified(): - tables = extract_tables('select * from abc.def, ghi.jkl') - assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + tables = extract_tables("select * from abc.def, ghi.jkl") + assert sorted(tables) == [("abc", "def", None), ("ghi", "jkl", None)] def test_simple_select_with_cols_single_table(): - tables = extract_tables('select a,b from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select a,b from abc") + assert tables == [(None, "abc", None)] def test_simple_select_with_cols_single_table_schema_qualified(): - tables = extract_tables('select a,b from abc.def') - assert tables == [('abc', 'def', None)] + tables = extract_tables("select a,b from abc.def") + assert tables == [("abc", "def", None)] def test_simple_select_with_cols_multiple_tables(): - tables = extract_tables('select a,b from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select a,b from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_simple_select_with_cols_multiple_tables_with_schema(): - tables = extract_tables('select a,b from abc.def, def.ghi') - assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] + tables = extract_tables("select a,b from abc.def, def.ghi") + assert sorted(tables) == [("abc", "def", None), ("def", "ghi", None)] def test_select_with_hanging_comma_single_table(): - tables = extract_tables('select a, from abc') - assert tables == [(None, 'abc', None)] + tables = extract_tables("select a, from abc") + assert tables == [(None, "abc", None)] def test_select_with_hanging_comma_multiple_tables(): - tables = extract_tables('select a, from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + tables = extract_tables("select a, from abc, def") + assert sorted(tables) == [(None, "abc", None), (None, "def", None)] def test_select_with_hanging_period_multiple_tables(): - tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') - assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert sorted(tables) == [(None, "tabl1", "t1"), (None, "tabl2", "t2")] def test_simple_insert_single_table(): @@ -69,97 +74,80 @@ def test_simple_insert_single_table(): # sqlparse mistakenly assigns an alias to the table # assert tables == [(None, 'abc', None)] - assert tables == [(None, 'abc', 'abc')] + assert tables == [(None, "abc", "abc")] @pytest.mark.xfail def test_simple_insert_single_table_schema_qualified(): tables = extract_tables('insert into abc.def (id, name) values (1, "def")') - assert tables == [('abc', 'def', None)] + assert tables == [("abc", "def", None)] def test_simple_update_table(): - tables = extract_tables('update abc set id = 1') - assert tables == [(None, 'abc', None)] + tables = extract_tables("update abc set id = 1") + assert tables == [(None, "abc", None)] def test_simple_update_table_with_schema(): - tables = extract_tables('update abc.def set id = 1') - assert tables == [('abc', 'def', None)] + tables = extract_tables("update abc.def set id = 1") + assert tables == [("abc", "def", None)] def test_join_table(): - tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') - assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + tables = extract_tables("SELECT * FROM abc a JOIN def d ON a.id = d.num") + assert sorted(tables) == [(None, "abc", "a"), (None, "def", "d")] def test_join_table_schema_qualified(): - tables = extract_tables( - 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') - assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert tables == [("abc", "def", "x"), ("ghi", "jkl", "y")] def test_join_as_table(): - tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') - assert tables == [(None, 'my_table', 'm')] + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == [(None, "my_table", "m")] def test_query_starts_with(): - query = 'USE test;' - assert query_starts_with(query, ('use', )) is True + query = "USE test;" + assert query_starts_with(query, ("use",)) is True - query = 'DROP DATABASE test;' - assert query_starts_with(query, ('use', )) is False + query = "DROP DATABASE test;" + assert query_starts_with(query, ("use",)) is False def test_query_starts_with_comment(): - query = '# comment\nUSE test;' - assert query_starts_with(query, ('use', )) is True + query = "# comment\nUSE test;" + assert query_starts_with(query, ("use",)) is True def test_queries_start_with(): - sql = ( - '# comment\n' - 'show databases;' - 'use foo;' - ) - assert queries_start_with(sql, ('show', 'select')) is True - assert queries_start_with(sql, ('use', 'drop')) is True - assert queries_start_with(sql, ('delete', 'update')) is False + sql = "# comment\n" "show databases;" "use foo;" + assert queries_start_with(sql, ("show", "select")) is True + assert queries_start_with(sql, ("use", "drop")) is True + assert queries_start_with(sql, ("delete", "update")) is False def test_is_destructive(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'drop database foo;' - ) + sql = "use test;\n" "show databases;\n" "drop database foo;" assert is_destructive(sql) is True def test_is_destructive_update_with_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1 WHERE id = 1;' - ) + sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1 WHERE id = 1;" assert is_destructive(sql) is False def test_is_destructive_update_without_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1;' - ) + sql = "use test;\n" "show databases;\n" "UPDATE test SET x = 1;" assert is_destructive(sql) is True @pytest.mark.parametrize( - ('sql', 'has_where_clause'), + ("sql", "has_where_clause"), [ - ('update test set dummy = 1;', False), - ('update test set dummy = 1 where id = 1);', True), + ("update test set dummy = 1;", False), + ("update test set dummy = 1 where id = 1);", True), ], ) def test_query_has_where_clause(sql, has_where_clause): @@ -167,24 +155,20 @@ def test_query_has_where_clause(sql, has_where_clause): @pytest.mark.parametrize( - ('sql', 'dbname', 'is_dropping'), + ("sql", "dbname", "is_dropping"), [ - ('select bar from foo', 'foo', False), - ('drop database "foo";', '`foo`', True), - ('drop schema foo', 'foo', True), - ('drop schema foo', 'bar', False), - ('drop database bar', 'foo', False), - ('drop database foo', None, False), - ('drop database foo; create database foo', 'foo', False), - ('drop database foo; create database bar', 'foo', True), - ('select bar from foo; drop database bazz', 'foo', False), - ('select bar from foo; drop database bazz', 'bazz', True), - ('-- dropping database \n ' - 'drop -- really dropping \n ' - 'schema abc -- now it is dropped', - 'abc', - True) - ] + ("select bar from foo", "foo", False), + ('drop database "foo";', "`foo`", True), + ("drop schema foo", "foo", True), + ("drop schema foo", "bar", False), + ("drop database bar", "foo", False), + ("drop database foo", None, False), + ("drop database foo; create database foo", "foo", False), + ("drop database foo; create database bar", "foo", True), + ("select bar from foo; drop database bazz", "foo", False), + ("select bar from foo; drop database bazz", "bazz", True), + ("-- dropping database \n " "drop -- really dropping \n " "schema abc -- now it is dropped", "abc", True), + ], ) def test_is_dropping_database(sql, dbname, is_dropping): assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py index 2373fac..625e022 100644 --- a/test/test_prompt_utils.py +++ b/test/test_prompt_utils.py @@ -4,8 +4,8 @@ from mycli.packages.prompt_utils import confirm_destructive_query def test_confirm_destructive_query_notty(): - stdin = click.get_text_stream('stdin') + stdin = click.get_text_stream("stdin") assert stdin.isatty() is False - sql = 'drop database foo;' + sql = "drop database foo;" assert confirm_destructive_query(sql) is None diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 30b15ac..8ad40a4 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -43,49 +43,35 @@ def complete_event(): def test_special_name_completion(completer, complete_event): text = "\\d" position = len("\\d") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert result == [Completion(text="\\dt", start_position=-2)] def test_empty_string_completion(completer, complete_event): text = "" position = 0 - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) - assert ( - list(map(Completion, completer.keywords + completer.special_commands)) == result - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert list(map(Completion, completer.keywords + completer.special_commands)) == result def test_select_keyword_completion(completer, complete_event): text = "SEL" position = len("SEL") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list([Completion(text="SELECT", start_position=-3)]) def test_select_star(completer, complete_event): text = "SELECT * " position = len(text) - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list(map(Completion, completer.keywords)) def test_table_completion(completer, complete_event): text = "SELECT * FROM " position = len(text) - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="users", start_position=0), @@ -99,9 +85,7 @@ def test_table_completion(completer, complete_event): def test_function_name_completion(completer, complete_event): text = "SELECT MA" position = len("SELECT MA") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="MAX", start_position=-2), @@ -127,11 +111,7 @@ def test_suggested_column_names(completer, complete_event): """ text = "SELECT from users" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -157,9 +137,7 @@ def test_suggested_column_names_in_function(completer, complete_event): """ text = "SELECT MAX( from users" position = len("SELECT MAX(") - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == list( [ Completion(text="*", start_position=0), @@ -181,11 +159,7 @@ def test_suggested_column_names_with_table_dot(completer, complete_event): """ text = "SELECT users. from users" position = len("SELECT users.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -207,11 +181,7 @@ def test_suggested_column_names_with_alias(completer, complete_event): """ text = "SELECT u. from users u" position = len("SELECT u.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -234,11 +204,7 @@ def test_suggested_multiple_column_names(completer, complete_event): """ text = "SELECT id, from users u" position = len("SELECT id, ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -264,11 +230,7 @@ def test_suggested_multiple_column_names_with_alias(completer, complete_event): """ text = "SELECT u.id, u. from users u" position = len("SELECT u.id, u.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -291,11 +253,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event): """ text = "SELECT users.id, users. from users u" position = len("SELECT users.id, users.") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -310,11 +268,7 @@ def test_suggested_multiple_column_names_with_dot(completer, complete_event): def test_suggested_aliases_after_on(completer, complete_event): text = "SELECT u.name, o.id FROM users u JOIN orders o ON " position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="u", start_position=0), @@ -326,11 +280,7 @@ def test_suggested_aliases_after_on(completer, complete_event): def test_suggested_aliases_after_on_right_side(completer, complete_event): text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = " position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="u", start_position=0), @@ -342,11 +292,7 @@ def test_suggested_aliases_after_on_right_side(completer, complete_event): def test_suggested_tables_after_on(completer, complete_event): text = "SELECT users.name, orders.id FROM users JOIN orders ON " position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -357,14 +303,8 @@ def test_suggested_tables_after_on(completer, complete_event): def test_suggested_tables_after_on_right_side(completer, complete_event): text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " - position = len( - "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " - ) - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -376,11 +316,7 @@ def test_suggested_tables_after_on_right_side(completer, complete_event): def test_table_names_after_from(completer, complete_event): text = "SELECT * FROM " position = len("SELECT * FROM ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="users", start_position=0), @@ -394,29 +330,21 @@ def test_table_names_after_from(completer, complete_event): def test_auto_escaped_col_names(completer, complete_event): text = "SELECT from `select`" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == [ Completion(text="*", start_position=0), Completion(text="id", start_position=0), Completion(text="`insert`", start_position=0), Completion(text="`ABC`", start_position=0), - ] + list(map(Completion, completer.functions)) + [ - Completion(text="select", start_position=0) - ] + list(map(Completion, completer.keywords)) + ] + list(map(Completion, completer.functions)) + [Completion(text="select", start_position=0)] + list( + map(Completion, completer.keywords) + ) def test_un_escaped_table_names(completer, complete_event): text = "SELECT from réveillé" position = len("SELECT ") - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) assert result == list( [ Completion(text="*", start_position=0), @@ -464,10 +392,6 @@ def dummy_list_path(dir_name): ) def test_file_name_completion(completer, complete_event, text, expected): position = len(text) - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), complete_event - ) - ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) expected = list((Completion(txt, pos) for txt, pos in expected)) assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py index d0ca45f..bea5620 100644 --- a/test/test_special_iocommands.py +++ b/test/test_special_iocommands.py @@ -17,11 +17,11 @@ def test_set_get_pager(): assert mycli.packages.special.is_pager_enabled() mycli.packages.special.set_pager_enabled(False) assert not mycli.packages.special.is_pager_enabled() - mycli.packages.special.set_pager('less') - assert os.environ['PAGER'] == "less" + mycli.packages.special.set_pager("less") + assert os.environ["PAGER"] == "less" mycli.packages.special.set_pager(False) - assert os.environ['PAGER'] == "less" - del os.environ['PAGER'] + assert os.environ["PAGER"] == "less" + del os.environ["PAGER"] mycli.packages.special.set_pager(False) mycli.packages.special.disable_pager() assert not mycli.packages.special.is_pager_enabled() @@ -42,45 +42,44 @@ def test_set_get_expanded_output(): def test_editor_command(): - assert mycli.packages.special.editor_command(r'hello\e') - assert mycli.packages.special.editor_command(r'\ehello') - assert not mycli.packages.special.editor_command(r'hello') + assert mycli.packages.special.editor_command(r"hello\e") + assert mycli.packages.special.editor_command(r"\ehello") + assert not mycli.packages.special.editor_command(r"hello") - assert mycli.packages.special.get_filename(r'\e filename') == "filename" + assert mycli.packages.special.get_filename(r"\e filename") == "filename" - os.environ['EDITOR'] = 'true' - os.environ['VISUAL'] = 'true' + os.environ["EDITOR"] = "true" + os.environ["VISUAL"] = "true" # Set the editor to Notepad on Windows - if os.name != 'nt': - mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" + if os.name != "nt": + mycli.packages.special.open_external_editor(sql=r"select 1") == "select 1" else: - pytest.skip('Skipping on Windows platform.') - + pytest.skip("Skipping on Windows platform.") def test_tee_command(): - mycli.packages.special.write_tee(u"hello world") # write without file set + mycli.packages.special.write_tee("hello world") # write without file set # keep Windows from locking the file with delete=False with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"tee " + f.name) - mycli.packages.special.write_tee(u"hello world") - if os.name=='nt': + mycli.packages.special.execute(None, "tee " + f.name) + mycli.packages.special.write_tee("hello world") + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"tee -o " + f.name) - mycli.packages.special.write_tee(u"hello world") + mycli.packages.special.execute(None, "tee -o " + f.name) + mycli.packages.special.write_tee("hello world") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"notee") - mycli.packages.special.write_tee(u"hello world") + mycli.packages.special.execute(None, "notee") + mycli.packages.special.write_tee("hello world") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" @@ -92,52 +91,49 @@ def test_tee_command(): os.remove(f.name) except Exception as e: print(f"An error occurred while attempting to delete the file: {e}") - def test_tee_command_error(): with pytest.raises(TypeError): - mycli.packages.special.execute(None, 'tee') + mycli.packages.special.execute(None, "tee") with pytest.raises(OSError): with tempfile.NamedTemporaryFile() as f: os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) - mycli.packages.special.execute(None, 'tee {}'.format(f.name)) + mycli.packages.special.execute(None, "tee {}".format(f.name)) @dbtest - @pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") def test_favorite_query(): with db_connection().cursor() as cur: - query = u'select "✔"' - mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) - assert next(mycli.packages.special.execute( - cur, u'\\f check'))[0] == "> " + query + query = 'select "✔"' + mycli.packages.special.execute(cur, "\\fs check {0}".format(query)) + assert next(mycli.packages.special.execute(cur, "\\f check"))[0] == "> " + query def test_once_command(): with pytest.raises(TypeError): - mycli.packages.special.execute(None, u"\\once") + mycli.packages.special.execute(None, "\\once") with pytest.raises(OSError): - mycli.packages.special.execute(None, u"\\once /proc/access-denied") + mycli.packages.special.execute(None, "\\once /proc/access-denied") - mycli.packages.special.write_once(u"hello world") # write without file set + mycli.packages.special.write_once("hello world") # write without file set # keep Windows from locking the file with delete=False with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"\\once " + f.name) - mycli.packages.special.write_once(u"hello world") - if os.name=='nt': + mycli.packages.special.execute(None, "\\once " + f.name) + mycli.packages.special.write_once("hello world") + if os.name == "nt": assert f.read() == b"hello world\r\n" else: assert f.read() == b"hello world\n" - mycli.packages.special.execute(None, u"\\once -o " + f.name) - mycli.packages.special.write_once(u"hello world line 1") - mycli.packages.special.write_once(u"hello world line 2") + mycli.packages.special.execute(None, "\\once -o " + f.name) + mycli.packages.special.write_once("hello world line 1") + mycli.packages.special.write_once("hello world line 2") f.seek(0) - if os.name=='nt': + if os.name == "nt": assert f.read() == b"hello world line 1\r\nhello world line 2\r\n" else: assert f.read() == b"hello world line 1\nhello world line 2\n" @@ -151,52 +147,47 @@ def test_once_command(): def test_pipe_once_command(): with pytest.raises(IOError): - mycli.packages.special.execute(None, u"\\pipe_once") + mycli.packages.special.execute(None, "\\pipe_once") with pytest.raises(OSError): - mycli.packages.special.execute( - None, u"\\pipe_once /proc/access-denied") + mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied") - if os.name == 'nt': - mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') - mycli.packages.special.write_once(u"hello world") + if os.name == "nt": + mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') + mycli.packages.special.write_once("hello world") mycli.packages.special.unset_pipe_once_if_written() else: - mycli.packages.special.execute(None, u"\\pipe_once wc") - mycli.packages.special.write_once(u"hello world") - mycli.packages.special.unset_pipe_once_if_written() - # how to assert on wc output? + with tempfile.NamedTemporaryFile() as f: + mycli.packages.special.execute(None, "\\pipe_once tee " + f.name) + mycli.packages.special.write_pipe_once("hello world") + mycli.packages.special.unset_pipe_once_if_written() + f.seek(0) + assert f.read() == b"hello world\n" def test_parseargfile(): """Test that parseargfile expands the user directory.""" - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'a'} - - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '~\\filename') + expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "a"} + + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename") else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '~/filename') - - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'w'} - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~\\filename') + assert expected == mycli.packages.special.iocommands.parseargfile("~/filename") + + expected = {"file": os.path.join(os.path.expanduser("~"), "filename"), "mode": "w"} + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename") else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~/filename') + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename") def test_parseargfile_no_file(): """Test that parseargfile raises a TypeError if there is no filename.""" with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('') + mycli.packages.special.iocommands.parseargfile("") with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('-o ') + mycli.packages.special.iocommands.parseargfile("-o ") @dbtest @@ -205,11 +196,9 @@ def test_watch_query_iteration(): the desired query and returns the given results.""" expected_value = "1" query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) + expected_title = "> {0!s}".format(query) with db_connection().cursor() as cur: - result = next(mycli.packages.special.iocommands.watch_query( - arg=query, cur=cur - )) + result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur)) assert result[0] == expected_title assert result[2][0] == expected_value @@ -230,14 +219,12 @@ def test_watch_query_full(): wait_interval = 1 expected_value = "1" query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) + expected_title = "> {0!s}".format(query) expected_results = 4 ctrl_c_process = send_ctrl_c(wait_interval) with db_connection().cursor() as cur: results = list( - result for result in mycli.packages.special.iocommands.watch_query( - arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur - ) + result for result in mycli.packages.special.iocommands.watch_query(arg="{0!s} {1!s}".format(watch_seconds, query), cur=cur) ) ctrl_c_process.join(1) assert len(results) == expected_results @@ -247,14 +234,12 @@ def test_watch_query_full(): @dbtest -@patch('click.clear') +@patch("click.clear") def test_watch_query_clear(clear_mock): """Test that the screen is cleared with the -c flag of `watch` command before execute the query.""" with db_connection().cursor() as cur: - watch_gen = mycli.packages.special.iocommands.watch_query( - arg='0.1 -c select 1;', cur=cur - ) + watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur) assert not clear_mock.called next(watch_gen) assert clear_mock.called @@ -271,19 +256,20 @@ def test_watch_query_bad_arguments(): watch_query = mycli.packages.special.iocommands.watch_query with db_connection().cursor() as cur: with pytest.raises(ProgrammingError): - next(watch_query('a select 1;', cur=cur)) + next(watch_query("a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('-a select 1;', cur=cur)) + next(watch_query("-a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('1 -a select 1;', cur=cur)) + next(watch_query("1 -a select 1;", cur=cur)) with pytest.raises(ProgrammingError): - next(watch_query('-c -a select 1;', cur=cur)) + next(watch_query("-c -a select 1;", cur=cur)) @dbtest -@patch('click.clear') +@patch("click.clear") def test_watch_query_interval_clear(clear_mock): """Test `watch` command with interval and clear flag.""" + def test_asserts(gen): clear_mock.reset_mock() start = time() @@ -296,46 +282,32 @@ def test_watch_query_interval_clear(clear_mock): seconds = 1.0 watch_query = mycli.packages.special.iocommands.watch_query with db_connection().cursor() as cur: - test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), - cur=cur)) - test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), - cur=cur)) + test_asserts(watch_query("{0!s} -c select 1;".format(seconds), cur=cur)) + test_asserts(watch_query("-c {0!s} select 1;".format(seconds), cur=cur)) def test_split_sql_by_delimiter(): - for delimiter_str in (';', '$', '😀'): + for delimiter_str in (";", "$", "😀"): mycli.packages.special.set_delimiter(delimiter_str) sql_input = "select 1{} select \ufffc2".format(delimiter_str) - queries = ( - "select 1", - "select \ufffc2" - ) - for query, parsed_query in zip( - queries, mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) + queries = ("select 1", "select \ufffc2") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)): + assert query == parsed_query def test_switch_delimiter_within_query(): - mycli.packages.special.set_delimiter(';') + mycli.packages.special.set_delimiter(";") sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" - queries = ( - "select 1", - "delimiter $$ select 2 $$ select 3 $$", - "select 2", - "select 3" - ) - for query, parsed_query in zip( - queries, - mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) + queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$", "select 2", "select 3") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input)): + assert query == parsed_query def test_set_delimiter(): - - for delim in ('foo', 'bar'): + for delim in ("foo", "bar"): mycli.packages.special.set_delimiter(delim) assert mycli.packages.special.get_current_delimiter() == delim def teardown_function(): - mycli.packages.special.set_delimiter(';') + mycli.packages.special.set_delimiter(";") diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index ca186bc..17e082b 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -7,14 +7,11 @@ from mycli.sqlexecute import ServerInfo, ServerSpecies from .utils import run, dbtest, set_expanded_output, is_expanded_output -def assert_result_equal(result, title=None, rows=None, headers=None, - status=None, auto_status=True, assert_contains=False): +def assert_result_equal(result, title=None, rows=None, headers=None, status=None, auto_status=True, assert_contains=False): """Assert that an sqlexecute.run() result matches the expected values.""" if status is None and auto_status and rows: - status = '{} row{} in set'.format( - len(rows), 's' if len(rows) > 1 else '') - fields = {'title': title, 'rows': rows, 'headers': headers, - 'status': status} + status = "{} row{} in set".format(len(rows), "s" if len(rows) > 1 else "") + fields = {"title": title, "rows": rows, "headers": headers, "status": status} if assert_contains: # Do a loose match on the results using the *in* operator. @@ -28,34 +25,35 @@ def assert_result_equal(result, title=None, rows=None, headers=None, @dbtest def test_conn(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - results = run(executor, '''select * from test''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + results = run(executor, """select * from test""") - assert_result_equal(results, headers=['a'], rows=[('abc',)]) + assert_result_equal(results, headers=["a"], rows=[("abc",)]) @dbtest def test_bools(executor): - run(executor, '''create table test(a boolean)''') - run(executor, '''insert into test values(True)''') - results = run(executor, '''select * from test''') + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, """select * from test""") - assert_result_equal(results, headers=['a'], rows=[(1,)]) + assert_result_equal(results, headers=["a"], rows=[(1,)]) @dbtest def test_binary(executor): - run(executor, '''create table bt(geom linestring NOT NULL)''') - run(executor, "INSERT INTO bt VALUES " - "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") - results = run(executor, '''select * from bt''') - - geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' - b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' - b'\xac\xdeC@') + run(executor, """create table bt(geom linestring NOT NULL)""") + run(executor, "INSERT INTO bt VALUES " "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, """select * from bt""") + + geom = ( + b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n" + b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9" + b"\xac\xdeC@" + ) - assert_result_equal(results, headers=['geom'], rows=[(geom,)]) + assert_result_equal(results, headers=["geom"], rows=[(geom,)]) @dbtest @@ -63,49 +61,48 @@ def test_table_and_columns_query(executor): run(executor, "create table a(x text, y text)") run(executor, "create table b(z text)") - assert set(executor.tables()) == set([('a',), ('b',)]) - assert set(executor.table_columns()) == set( - [('a', 'x'), ('a', 'y'), ('b', 'z')]) + assert set(executor.tables()) == set([("a",), ("b",)]) + assert set(executor.table_columns()) == set([("a", "x"), ("a", "y"), ("b", "z")]) @dbtest def test_database_list(executor): databases = executor.databases() - assert 'mycli_test_db' in databases + assert "mycli_test_db" in databases @dbtest def test_invalid_syntax(executor): with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, 'invalid syntax!') - assert 'You have an error in your SQL syntax;' in str(excinfo.value) + run(executor, "invalid syntax!") + assert "You have an error in your SQL syntax;" in str(excinfo.value) @dbtest def test_invalid_column_name(executor): with pytest.raises(pymysql.err.OperationalError) as excinfo: - run(executor, 'select invalid command') + run(executor, "select invalid command") assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) @dbtest def test_unicode_support_in_output(executor): run(executor, "create table unicodechars(t text)") - run(executor, u"insert into unicodechars (t) values ('é')") + run(executor, "insert into unicodechars (t) values ('é')") # See issue #24, this raises an exception without proper handling - results = run(executor, u"select * from unicodechars") - assert_result_equal(results, headers=['t'], rows=[(u'é',)]) + results = run(executor, "select * from unicodechars") + assert_result_equal(results, headers=["t"], rows=[("é",)]) @dbtest def test_multiple_queries_same_line(executor): results = run(executor, "select 'foo'; select 'bar'") - expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], - 'status': '1 row in set'}, - {'title': None, 'headers': ['bar'], 'rows': [('bar',)], - 'status': '1 row in set'}] + expected = [ + {"title": None, "headers": ["foo"], "rows": [("foo",)], "status": "1 row in set"}, + {"title": None, "headers": ["bar"], "rows": [("bar",)], "status": "1 row in set"}, + ] assert expected == results @@ -113,7 +110,7 @@ def test_multiple_queries_same_line(executor): def test_multiple_queries_same_line_syntaxerror(executor): with pytest.raises(pymysql.ProgrammingError) as excinfo: run(executor, "select 'foo'; invalid syntax") - assert 'You have an error in your SQL syntax;' in str(excinfo.value) + assert "You have an error in your SQL syntax;" in str(excinfo.value) @dbtest @@ -125,15 +122,13 @@ def test_favorite_query(executor): run(executor, "insert into test values('def')") results = run(executor, "\\fs test-a select * from test where a like 'a%'") - assert_result_equal(results, status='Saved.') + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-a") - assert_result_equal(results, - title="> select * from test where a like 'a%'", - headers=['a'], rows=[('abc',)], auto_status=False) + assert_result_equal(results, title="> select * from test where a like 'a%'", headers=["a"], rows=[("abc",)], auto_status=False) results = run(executor, "\\fd test-a") - assert_result_equal(results, status='test-a: Deleted') + assert_result_equal(results, status="test-a: Deleted") @dbtest @@ -144,158 +139,147 @@ def test_favorite_query_multiple_statement(executor): run(executor, "insert into test values('abc')") run(executor, "insert into test values('def')") - results = run(executor, - "\\fs test-ad select * from test where a like 'a%'; " - "select * from test where a like 'd%'") - assert_result_equal(results, status='Saved.') + results = run(executor, "\\fs test-ad select * from test where a like 'a%'; " "select * from test where a like 'd%'") + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-ad") - expected = [{'title': "> select * from test where a like 'a%'", - 'headers': ['a'], 'rows': [('abc',)], 'status': None}, - {'title': "> select * from test where a like 'd%'", - 'headers': ['a'], 'rows': [('def',)], 'status': None}] + expected = [ + {"title": "> select * from test where a like 'a%'", "headers": ["a"], "rows": [("abc",)], "status": None}, + {"title": "> select * from test where a like 'd%'", "headers": ["a"], "rows": [("def",)], "status": None}, + ] assert expected == results results = run(executor, "\\fd test-ad") - assert_result_equal(results, status='test-ad: Deleted') + assert_result_equal(results, status="test-ad: Deleted") @dbtest @pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") def test_favorite_query_expanded_output(executor): set_expanded_output(False) - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") results = run(executor, "\\fs test-ae select * from test") - assert_result_equal(results, status='Saved.') + assert_result_equal(results, status="Saved.") results = run(executor, "\\f test-ae \\G") assert is_expanded_output() is True - assert_result_equal(results, title='> select * from test', - headers=['a'], rows=[('abc',)], auto_status=False) + assert_result_equal(results, title="> select * from test", headers=["a"], rows=[("abc",)], auto_status=False) set_expanded_output(False) results = run(executor, "\\fd test-ae") - assert_result_equal(results, status='test-ae: Deleted') + assert_result_equal(results, status="test-ae: Deleted") @dbtest def test_special_command(executor): - results = run(executor, '\\?') - assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), - headers='Command', assert_contains=True, - auto_status=False) + results = run(executor, "\\?") + assert_result_equal(results, rows=("quit", "\\q", "Quit."), headers="Command", assert_contains=True, auto_status=False) @dbtest def test_cd_command_without_a_folder_name(executor): - results = run(executor, 'system cd') - assert_result_equal(results, status='No folder name was provided.') + results = run(executor, "system cd") + assert_result_equal(results, status="No folder name was provided.") @dbtest def test_system_command_not_found(executor): - results = run(executor, 'system xyz') - if os.name=='nt': - assert_result_equal(results, status='OSError: The system cannot find the file specified', - assert_contains=True) + results = run(executor, "system xyz") + if os.name == "nt": + assert_result_equal(results, status="OSError: The system cannot find the file specified", assert_contains=True) else: - assert_result_equal(results, status='OSError: No such file or directory', - assert_contains=True) + assert_result_equal(results, status="OSError: No such file or directory", assert_contains=True) @dbtest def test_system_command_output(executor): eol = os.linesep test_dir = os.path.abspath(os.path.dirname(__file__)) - test_file_path = os.path.join(test_dir, 'test.txt') - results = run(executor, 'system cat {0}'.format(test_file_path)) - assert_result_equal(results, status=f'mycli rocks!{eol}') + test_file_path = os.path.join(test_dir, "test.txt") + results = run(executor, "system cat {0}".format(test_file_path)) + assert_result_equal(results, status=f"mycli rocks!{eol}") @dbtest def test_cd_command_current_dir(executor): test_path = os.path.abspath(os.path.dirname(__file__)) - run(executor, 'system cd {0}'.format(test_path)) + run(executor, "system cd {0}".format(test_path)) assert os.getcwd() == test_path @dbtest def test_unicode_support(executor): - results = run(executor, u"SELECT '日本語' AS japanese;") - assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) + results = run(executor, "SELECT '日本語' AS japanese;") + assert_result_equal(results, headers=["japanese"], rows=[("日本語",)]) @dbtest def test_timestamp_null(executor): - run(executor, '''create table ts_null(a timestamp null)''') - run(executor, '''insert into ts_null values(null)''') - results = run(executor, '''select * from ts_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) + run(executor, """create table ts_null(a timestamp null)""") + run(executor, """insert into ts_null values(null)""") + results = run(executor, """select * from ts_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_datetime_null(executor): - run(executor, '''create table dt_null(a datetime null)''') - run(executor, '''insert into dt_null values(null)''') - results = run(executor, '''select * from dt_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) + run(executor, """create table dt_null(a datetime null)""") + run(executor, """insert into dt_null values(null)""") + results = run(executor, """select * from dt_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_date_null(executor): - run(executor, '''create table date_null(a date null)''') - run(executor, '''insert into date_null values(null)''') - results = run(executor, '''select * from date_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) + run(executor, """create table date_null(a date null)""") + run(executor, """insert into date_null values(null)""") + results = run(executor, """select * from date_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_time_null(executor): - run(executor, '''create table time_null(a time null)''') - run(executor, '''insert into time_null values(null)''') - results = run(executor, '''select * from time_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) + run(executor, """create table time_null(a time null)""") + run(executor, """insert into time_null values(null)""") + results = run(executor, """select * from time_null""") + assert_result_equal(results, headers=["a"], rows=[(None,)]) @dbtest def test_multiple_results(executor): - query = '''CREATE PROCEDURE dmtest() + query = """CREATE PROCEDURE dmtest() BEGIN SELECT 1; SELECT 2; - END''' + END""" executor.conn.cursor().execute(query) - results = run(executor, 'call dmtest;') + results = run(executor, "call dmtest;") expected = [ - {'title': None, 'rows': [(1,)], 'headers': ['1'], - 'status': '1 row in set'}, - {'title': None, 'rows': [(2,)], 'headers': ['2'], - 'status': '1 row in set'} + {"title": None, "rows": [(1,)], "headers": ["1"], "status": "1 row in set"}, + {"title": None, "rows": [(2,)], "headers": ["2"], "status": "1 row in set"}, ] assert results == expected @pytest.mark.parametrize( - 'version_string, species, parsed_version_string, version', + "version_string, species, parsed_version_string, version", ( - ('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100), - ('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200), - ('5.7.32-35', 'Percona', '5.7.32', 50732), - ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), - ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), - ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), - ('unexpected version string', None, '', 0), - ('', None, '', 0), - (None, None, '', 0), - ) + ("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100), + ("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200), + ("5.7.32-35", "Percona", "5.7.32", 50732), + ("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732), + ("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016), + ("5.1.5a-alpha", "MySQL", "5.1.5", 50105), + ("unexpected version string", None, "", 0), + ("", None, "", 0), + (None, None, "", 0), + ), ) def test_version_parsing(version_string, species, parsed_version_string, version): server_info = ServerInfo.from_version_string(version_string) diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py index bdc1dbf..45e97af 100644 --- a/test/test_tabular_output.py +++ b/test/test_tabular_output.py @@ -2,8 +2,6 @@ from textwrap import dedent -from mycli.packages.tabular_output import sql_format -from cli_helpers.tabular_output import TabularOutputFormatter from .utils import USER, PASSWORD, HOST, PORT, dbtest @@ -23,20 +21,17 @@ def mycli(): @dbtest def test_sql_output(mycli): """Test the sql output adapter.""" - headers = ['letters', 'number', 'optional', 'float', 'binary'] + headers = ["letters", "number", "optional", "float", "binary"] class FakeCursor(object): def __init__(self): - self.data = [ - ('abc', 1, None, 10.0, b'\xAA'), - ('d', 456, '1', 0.5, b'\xAA\xBB') - ] + self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")] self.description = [ (None, FIELD_TYPE.VARCHAR), (None, FIELD_TYPE.LONG), (None, FIELD_TYPE.LONG), (None, FIELD_TYPE.FLOAT), - (None, FIELD_TYPE.BLOB) + (None, FIELD_TYPE.BLOB), ] def __iter__(self): @@ -52,12 +47,11 @@ def test_sql_output(mycli): return self.description # Test sql-update output format - assert list(mycli.change_table_format("sql-update")) == \ - [(None, None, None, 'Changed table format to sql-update')] + assert list(mycli.change_table_format("sql-update")) == [(None, None, None, "Changed table format to sql-update")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) actual = "\n".join(output) - assert actual == dedent('''\ + assert actual == dedent("""\ UPDATE `DUAL` SET `number` = 1 , `optional` = NULL @@ -69,13 +63,12 @@ def test_sql_output(mycli): , `optional` = '1' , `float` = 0.5e0 , `binary` = X'aabb' - WHERE `letters` = 'd';''') + WHERE `letters` = 'd';""") # Test sql-update-2 output format - assert list(mycli.change_table_format("sql-update-2")) == \ - [(None, None, None, 'Changed table format to sql-update-2')] + assert list(mycli.change_table_format("sql-update-2")) == [(None, None, None, "Changed table format to sql-update-2")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ UPDATE `DUAL` SET `optional` = NULL , `float` = 10.0e0 @@ -85,34 +78,31 @@ def test_sql_output(mycli): `optional` = '1' , `float` = 0.5e0 , `binary` = X'aabb' - WHERE `letters` = 'd' AND `number` = 456;''') + WHERE `letters` = 'd' AND `number` = 456;""") # Test sql-insert output format (without table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") # Test sql-insert output format (with table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "SELECT * FROM `table`" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") # Test sql-insert output format (with database + table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] + assert list(mycli.change_table_format("sql-insert")) == [(None, None, None, "Changed table format to sql-insert")] mycli.formatter.query = "SELECT * FROM `database`.`table`" output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ + assert "\n".join(output) == dedent("""\ INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES ('abc', 1, NULL, 10.0e0, X'aa') , ('d', 456, '1', 0.5e0, X'aabb') - ;''') + ;""") diff --git a/test/utils.py b/test/utils.py index ab12248..383f502 100644 --- a/test/utils.py +++ b/test/utils.py @@ -9,20 +9,18 @@ import pytest from mycli.main import special -PASSWORD = os.getenv('PYTEST_PASSWORD') -USER = os.getenv('PYTEST_USER', 'root') -HOST = os.getenv('PYTEST_HOST', 'localhost') -PORT = int(os.getenv('PYTEST_PORT', 3306)) -CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') -SSH_USER = os.getenv('PYTEST_SSH_USER', None) -SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) -SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) +PASSWORD = os.getenv("PYTEST_PASSWORD") +USER = os.getenv("PYTEST_USER", "root") +HOST = os.getenv("PYTEST_HOST", "localhost") +PORT = int(os.getenv("PYTEST_PORT", 3306)) +CHARSET = os.getenv("PYTEST_CHARSET", "utf8") +SSH_USER = os.getenv("PYTEST_SSH_USER", None) +SSH_HOST = os.getenv("PYTEST_SSH_HOST", None) +SSH_PORT = os.getenv("PYTEST_SSH_PORT", 22) def db_connection(dbname=None): - conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, - password=PASSWORD, charset=CHARSET, - local_infile=False) + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARSET, local_infile=False) conn.autocommit = True return conn @@ -30,20 +28,18 @@ def db_connection(dbname=None): try: db_connection() CAN_CONNECT_TO_DB = True -except: +except Exception: CAN_CONNECT_TO_DB = False -dbtest = pytest.mark.skipif( - not CAN_CONNECT_TO_DB, - reason="Need a mysql instance at localhost accessible by user 'root'") +dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Need a mysql instance at localhost accessible by user 'root'") def create_db(dbname): with db_connection().cursor() as cur: try: - cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''') - cur.execute('''CREATE DATABASE mycli_test_db''') - except: + cur.execute("""DROP DATABASE IF EXISTS mycli_test_db""") + cur.execute("""CREATE DATABASE mycli_test_db""") + except Exception: pass @@ -53,8 +49,7 @@ def run(executor, sql, rows_as_list=True): for title, rows, headers, status in executor.run(sql): rows = list(rows) if (rows_as_list and rows) else rows - result.append({'title': title, 'rows': rows, 'headers': headers, - 'status': status}) + result.append({"title": title, "rows": rows, "headers": headers, "status": status}) return result @@ -87,8 +82,6 @@ def send_ctrl_c(wait_seconds): Returns the `multiprocessing.Process` created. """ - ctrl_c_process = multiprocessing.Process( - target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) - ) + ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)) ctrl_c_process.start() return ctrl_c_process @@ -1,15 +1,21 @@ [tox] -envlist = py36, py37, py38 +envlist = py [testenv] -deps = pytest - mock - pexpect - behave - coverage -commands = python setup.py test +skip_install = true +deps = uv passenv = PYTEST_HOST PYTEST_USER PYTEST_PASSWORD PYTEST_PORT PYTEST_CHARSET +commands = uv pip install -e .[dev,ssh] + coverage run -m pytest -v test + coverage report -m + behave test/features + +[testenv:style] +skip_install = true +deps = ruff +commands = ruff check --fix + ruff format |