diff options
42 files changed, 949 insertions, 178 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d5b33bd..68a69ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,9 @@ name: pgcli on: + push: + branches: + - main pull_request: paths-ignore: - '**.rst' @@ -11,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] services: postgres: @@ -28,10 +31,10 @@ jobs: --health-retries 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -64,6 +67,10 @@ jobs: psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' + - name: Install beta version of pendulum + run: pip install pendulum==3.0.0b1 + if: matrix.python-version == '3.12' + - name: Install requirements run: | pip install -U pip setuptools @@ -72,7 +79,7 @@ jobs: pip install keyrings.alt>=3.1 - name: Run unit tests - run: coverage run --source pgcli -m py.test + run: coverage run --source pgcli -m pytest - name: Run integration tests env: @@ -86,7 +93,7 @@ jobs: - name: Run Black run: black --check . - if: matrix.python-version == '3.7' + if: matrix.python-version == '3.8' - name: Coverage run: | diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..c9232c7 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,41 @@ +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + schedule: + - cron: "29 13 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ python ] + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + queries: +security-and-quality + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{ matrix.language }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67ba03d..8462cc2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ repos: - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black @@ -123,6 +123,13 @@ Contributors: * Daniel Kukula (dkuku) * Kian-Meng Ang (kianmeng) * Liu Zhao (astroshot) + * Rigo Neri (rigoneri) + * Anna Glasgall (annathyst) + * Andy Schoenberger (andyscho) + * Damien Baty (dbaty) + * blag + * Rob Berry (rob-b) + * Sharon Yogev (sharonyogev) Creator: -------- diff --git a/DEVELOP.rst b/DEVELOP.rst index 4cde694..aed2cf8 100644 --- a/DEVELOP.rst +++ b/DEVELOP.rst @@ -165,8 +165,9 @@ in the ``tests`` directory. An example:: First, install the requirements for testing: :: - - $ pip install -r requirements-dev.txt + $ pip install -U pip setuptools + $ pip install --no-cache-dir ".[sshtunnel]" + $ pip install -r requirements-dev.txt Ensure that the database user has permissions to create and drop test databases by checking your ``pg_hba.conf`` file. The default user should be ``postgres`` @@ -157,8 +157,9 @@ get this running in a development setup. https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst -Please feel free to reach out to me if you need help. -My email: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_ +Please feel free to reach out to us if you need help. +* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_ +* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong <http://twitter.com/irinatruong>`_ Detailed Installation Instructions: ----------------------------------- @@ -351,8 +352,7 @@ choice: In [3]: my_result = _ -Pgcli only runs on Python3.7+ since 4.0.0, if you use an old version of Python, -you should use install ``pgcli <= 4.0.0``. +Pgcli dropped support for Python<3.8 as of 4.0.0. If you need it, install ``pgcli <= 4.0.0``. Thanks: ------- @@ -372,8 +372,8 @@ interface to Postgres database. Thanks to all the beta testers and contributors for your time and patience. :) -.. |Build Status| image:: https://github.com/dbcli/pgcli/workflows/pgcli/badge.svg - :target: https://github.com/dbcli/pgcli/actions?query=workflow%3Apgcli +.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main + :target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml .. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg :target: https://codecov.io/gh/dbcli/pgcli diff --git a/changelog.rst b/changelog.rst index a9c8217..7d08839 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,3 +1,52 @@ +================== +4.0.1 (2023-11-30) +================== + +Internal: +--------- +* Allow stable version of pendulum. + +================== +4.0.0 (2023-11-27) +================== + +Features: +--------- + +* Ask for confirmation when quitting cli while a transaction is ongoing. +* New `destructive_statements_require_transaction` config option to refuse to execute a + destructive SQL statement if outside a transaction. This option is off by default. +* Changed the `destructive_warning` config to be a list of commands that are considered + destructive. This would allow you to be warned on `create`, `grant`, or `insert` queries. +* Destructive warnings will now include the alias dsn connection string name if provided (-D option). +* pgcli.magic will now work with connection URLs that use TLS client certificates for authentication +* Have config option to retry queries on operational errors like connections being lost. + Also prevents getting stuck in a retry loop. +* Config option to not restart connection when cancelling a `destructive_warning` query. By default, + it will now not restart. +* Config option to always run with a single connection. +* Add comment explaining default LESS environment variable behavior and change example pager setting. +* Added `\echo` & `\qecho` special commands. ([issue 1335](https://github.com/dbcli/pgcli/issues/1335)). + +Bug fixes: +---------- + +* Fix `\ev` not producing a correctly quoted "schema"."view" +* Fix 'invalid connection option "dsn"' ([issue 1373](https://github.com/dbcli/pgcli/issues/1373)). +* Fix explain mode when used with `expand`, `auto_expand`, or `--explain-vertical-output` ([issue 1393](https://github.com/dbcli/pgcli/issues/1393)). +* Fix sql-insert format emits NULL as 'None' ([issue 1408](https://github.com/dbcli/pgcli/issues/1408)). +* Improve check for prompt-toolkit 3.0.6 ([issue 1416](https://github.com/dbcli/pgcli/issues/1416)). +* Allow specifying an `alias_map_file` in the config that will use + predetermined table aliases instead of generating aliases programmatically on + the fly +* Fixed SQL error when there is a comment on the first line: ([issue 1403](https://github.com/dbcli/pgcli/issues/1403)) +* Fix wrong usage of prompt instead of confirm when confirm execution of destructive query + +Internal: +--------- + +* Drop support for Python 3.7 and add 3.12. + 3.5.0 (2022/09/15): =================== diff --git a/pgcli/__init__.py b/pgcli/__init__.py index dcbfb52..76ad18b 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "3.5.0" +__version__ = "4.0.1" diff --git a/pgcli/auth.py b/pgcli/auth.py index 342c412..2f1e552 100644 --- a/pgcli/auth.py +++ b/pgcli/auth.py @@ -26,7 +26,9 @@ def keyring_initialize(keyring_enabled, *, logger): try: keyring = importlib.import_module("keyring") - except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 + except ( + ModuleNotFoundError + ) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3 logger.warning("import keyring failed: %r.", e) diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py index 1039d51..c887cb6 100644 --- a/pgcli/completion_refresher.py +++ b/pgcli/completion_refresher.py @@ -6,7 +6,6 @@ from .pgcompleter import PGCompleter class CompletionRefresher: - refreshers = OrderedDict() def __init__(self): @@ -39,7 +38,7 @@ class CompletionRefresher: args=(executor, special, callbacks, history, settings), name="completion_refresh", ) - self._completer_thread.setDaemon(True) + self._completer_thread.daemon = True self._completer_thread.start() return [ (None, None, None, "Auto-completion refresh started in the background.") diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py index b14cf44..ce45b4f 100644 --- a/pgcli/explain_output_formatter.py +++ b/pgcli/explain_output_formatter.py @@ -10,7 +10,8 @@ class ExplainOutputFormatter: self.max_width = max_width def format_output(self, cur, headers, **output_kwargs): - (data,) = cur.fetchone() + # explain query results should always contain 1 row each + [(data,)] = list(cur) explain_list = json.loads(data) visualizer = Visualizer(self.max_width) for explain in explain_list: diff --git a/pgcli/magic.py b/pgcli/magic.py index 6e58f28..09902a2 100644 --- a/pgcli/magic.py +++ b/pgcli/magic.py @@ -43,7 +43,7 @@ def pgcli_line_magic(line): u = conn.session.engine.url _logger.debug("New pgcli: %r", str(u)) - pgcli.connect(u.database, u.host, u.username, u.port, u.password) + pgcli.connect_uri(str(u._replace(drivername="postgres"))) conn._pgcli = pgcli # For convenience, print the connection alias diff --git a/pgcli/main.py b/pgcli/main.py index 0fa264f..f95c800 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -63,15 +63,14 @@ from .config import ( ) from .key_bindings import pgcli_bindings from .packages.formatter.sqlformatter import register_new_formatter -from .packages.prompt_utils import confirm_destructive_query +from .packages.prompt_utils import confirm, confirm_destructive_query +from .packages.parseutils import is_destructive +from .packages.parseutils import parse_destructive_warning from .__init__ import __version__ click.disable_unicode_literals_warning = True -try: - from urlparse import urlparse, unquote, parse_qs -except ImportError: - from urllib.parse import urlparse, unquote, parse_qs +from urllib.parse import urlparse from getpass import getuser @@ -201,6 +200,9 @@ class PGCli: self.multiline_mode = c["main"].get("multi_line_mode", "psql") self.vi_mode = c["main"].as_bool("vi") self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand") + self.auto_retry_closed_connection = c["main"].as_bool( + "auto_retry_closed_connection" + ) self.expanded_output = c["main"].as_bool("expand") self.pgspecial.timing_enabled = c["main"].as_bool("timing") if row_limit is not None: @@ -224,11 +226,16 @@ class PGCli: self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") - self.destructive_warning = warn or c["main"]["destructive_warning"] - # also handle boolean format of destructive warning - self.destructive_warning = {"true": "all", "false": "off"}.get( - self.destructive_warning.lower(), self.destructive_warning + self.destructive_warning = parse_destructive_warning( + warn or c["main"].as_list("destructive_warning") + ) + self.destructive_warning_restarts_connection = c["main"].as_bool( + "destructive_warning_restarts_connection" + ) + self.destructive_statements_require_transaction = c["main"].as_bool( + "destructive_statements_require_transaction" ) + self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") self.null_string = c["main"].get("null_string", "<null>") self.prompt_format = ( @@ -258,6 +265,9 @@ class PGCli: # Initialize completer smart_completion = c["main"].as_bool("smart_completion") keyword_casing = c["main"]["keyword_casing"] + single_connection = single_connection or c["main"].as_bool( + "always_use_single_connection" + ) self.settings = { "casing_file": get_casing_file(c), "generate_casing_file": c["main"].as_bool("generate_casing_file"), @@ -269,6 +279,7 @@ class PGCli: "single_connection": single_connection, "less_chatty": less_chatty, "keyword_casing": keyword_casing, + "alias_map_file": c["main"]["alias_map_file"] or None, } completer = PGCompleter( @@ -292,7 +303,6 @@ class PGCli: raise PgCliQuitError def register_special_commands(self): - self.pgspecial.register( self.change_db, "\\c", @@ -354,6 +364,23 @@ class PGCli: "Change the table format used to output results", ) + self.pgspecial.register( + self.echo, + "\\echo", + "\\echo [string]", + "Echo a string to stdout", + ) + + self.pgspecial.register( + self.echo, + "\\qecho", + "\\qecho [string]", + "Echo a string to the query output channel.", + ) + + def echo(self, pattern, **_): + return [(None, None, None, pattern)] + def change_table_format(self, pattern, **_): try: if pattern not in TabularOutputFormatter().supported_formats: @@ -423,12 +450,20 @@ class PGCli: except OSError as e: return [(None, None, None, str(e), "", False, True)] - if ( - self.destructive_warning != "off" - and confirm_destructive_query(query, self.destructive_warning) is False - ): - message = "Wise choice. Command execution stopped." - return [(None, None, None, message)] + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(query, self.destructive_warning) + ): + message = "Destructive statements must be run within a transaction. Command execution stopped." + return [(None, None, None, message)] + destroy = confirm_destructive_query( + query, self.destructive_warning, self.dsn_alias + ) + if destroy is False: + message = "Wise choice. Command execution stopped." + return [(None, None, None, message)] on_error_resume = self.on_error == "RESUME" return self.pgexecute.run( @@ -456,7 +491,6 @@ class PGCli: return [(None, None, None, message, "", True, True)] def initialize_logging(self): - log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" @@ -687,34 +721,52 @@ class PGCli: editor_command = special.editor_command(text) return text - def execute_command(self, text): + def execute_command(self, text, handle_closed_connection=True): logger = self.logger query = MetaQuery(query=text, successful=False) try: - if self.destructive_warning != "off": - destroy = confirm = confirm_destructive_query( - text, self.destructive_warning + if self.destructive_warning: + if ( + self.destructive_statements_require_transaction + and not self.pgexecute.valid_transaction() + and is_destructive(text, self.destructive_warning) + ): + click.secho( + "Destructive statements must be run within a transaction." + ) + raise KeyboardInterrupt + destroy = confirm_destructive_query( + text, self.destructive_warning, self.dsn_alias ) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt elif destroy: click.secho("Your call!") + output, query = self._evaluate_command(text) except KeyboardInterrupt: - # Restart connection to the database - self.pgexecute.connect() - logger.debug("cancelled query, sql: %r", text) - click.secho("cancelled query", err=True, fg="red") + if self.destructive_warning_restarts_connection: + # Restart connection to the database + self.pgexecute.connect() + logger.debug("cancelled query and restarted connection, sql: %r", text) + click.secho( + "cancelled query and restarted connection", err=True, fg="red" + ) + else: + logger.debug("cancelled query, sql: %r", text) + click.secho("cancelled query", err=True, fg="red") except NotImplementedError: click.secho("Not Yet Implemented.", fg="yellow") except OperationalError as e: logger.error("sql: %r, error: %r", text, e) logger.error("traceback: %r", traceback.format_exc()) - self._handle_server_closed_connection(text) - except (PgCliQuitError, EOFError) as e: + click.secho(str(e), err=True, fg="red") + if handle_closed_connection: + self._handle_server_closed_connection(text) + except (PgCliQuitError, EOFError): raise except Exception as e: logger.error("sql: %r, error: %r", text, e) @@ -722,7 +774,9 @@ class PGCli: click.secho(str(e), err=True, fg="red") else: try: - if self.output_file and not text.startswith(("\\o ", "\\? ")): + if self.output_file and not text.startswith( + ("\\o ", "\\? ", "\\echo ") + ): try: with open(self.output_file, "a", encoding="utf-8") as f: click.echo(text, file=f) @@ -766,6 +820,34 @@ class PGCli: logger.debug("Search path: %r", self.completer.search_path) return query + def _check_ongoing_transaction_and_allow_quitting(self): + """Return whether we can really quit, possibly by asking the + user to confirm so if there is an ongoing transaction. + """ + if not self.pgexecute.valid_transaction(): + return True + while 1: + try: + choice = click.prompt( + "A transaction is ongoing. Choose `c` to COMMIT, `r` to ROLLBACK, `a` to abort exit.", + default="a", + ) + except click.Abort: + # Print newline if user aborts with `^C`, otherwise + # pgcli's prompt will be printed on the same line + # (just after the confirmation prompt). + click.echo(None, err=False) + choice = "a" + choice = choice.lower() + if choice == "a": + return False # do not quit + if choice == "c": + query = self.execute_command("commit") + return query.successful # quit only if query is successful + if choice == "r": + query = self.execute_command("rollback") + return query.successful # quit only if query is successful + def run_cli(self): logger = self.logger @@ -788,6 +870,10 @@ class PGCli: text = self.prompt_app.prompt() except KeyboardInterrupt: continue + except EOFError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise try: text = self.handle_editor_command(text) @@ -797,7 +883,12 @@ class PGCli: click.secho(str(e), err=True, fg="red") continue - self.handle_watch_command(text) + try: + self.handle_watch_command(text) + except PgCliQuitError: + if not self._check_ongoing_transaction_and_allow_quitting(): + continue + raise self.now = dt.datetime.today() @@ -1036,10 +1127,17 @@ class PGCli: click.secho("Reconnecting...", fg="green") self.pgexecute.connect() click.secho("Reconnected!", fg="green") - self.execute_command(text) except OperationalError as e: click.secho("Reconnect Failed", fg="red") click.secho(str(e), err=True, fg="red") + else: + retry = self.auto_retry_closed_connection or confirm( + "Run the query from before reconnecting?" + ) + if retry: + click.secho("Running query...", fg="green") + # Don't get stuck in a retry loop + self.execute_command(text, handle_closed_connection=False) def refresh_completions(self, history=None, persist_priorities="all"): """Refresh outdated completions @@ -1266,7 +1364,6 @@ class PGCli: @click.option( "--warn", default=None, - type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) @click.option( @@ -1575,7 +1672,8 @@ def format_output(title, cur, headers, status, settings, explain_mode=False): first_line = next(formatted) formatted = itertools.chain([first_line], formatted) if ( - not expanded + not explain_mode + and not expanded and max_width and len(strip_ansi(first_line)) > max_width and headers diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py index 5bf25fe..5224eff 100644 --- a/pgcli/packages/formatter/sqlformatter.py +++ b/pgcli/packages/formatter/sqlformatter.py @@ -14,10 +14,13 @@ preprocessors = () def escape_for_sql_statement(value): + if value is None: + return "NULL" + if isinstance(value, bytes): return f"X'{value.hex()}'" - else: - return "'{}'".format(value) + + return "'{}'".format(value) def adapter(data, headers, table_format=None, **kwargs): @@ -29,7 +32,7 @@ def adapter(data, headers, table_format=None, **kwargs): else: table_name = table[1] else: - table_name = '"DUAL"' + table_name = "DUAL" if table_format == "sql-insert": h = '", "'.join(headers) yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h) diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index 1acc008..023e13b 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,6 +1,17 @@ import sqlparse +BASE_KEYWORDS = [ + "drop", + "shutdown", + "delete", + "truncate", + "alter", + "unconditional_update", +] +ALL_KEYWORDS = BASE_KEYWORDS + ["update"] + + def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] @@ -13,22 +24,35 @@ def query_is_unconditional_update(formatted_sql): return bool(tokens) and tokens[0] == "update" and "where" not in tokens -def query_is_simple_update(formatted_sql): - """Check if the query starts with UPDATE.""" - tokens = formatted_sql.split() - return bool(tokens) and tokens[0] == "update" - - -def is_destructive(queries, warning_level="all"): +def is_destructive(queries, keywords): """Returns if any of the queries in *queries* is destructive.""" - keywords = ("drop", "shutdown", "delete", "truncate", "alter") for query in sqlparse.split(queries): if query: formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() - if query_starts_with(formatted_sql, keywords): - return True - if query_is_unconditional_update(formatted_sql): + if "unconditional_update" in keywords and query_is_unconditional_update( + formatted_sql + ): return True - if warning_level == "all" and query_is_simple_update(formatted_sql): + if query_starts_with(formatted_sql, keywords): return True return False + + +def parse_destructive_warning(warning_level): + """Converts a deprecated destructive warning option to a list of command keywords.""" + if not warning_level: + return [] + + if not isinstance(warning_level, list): + if "," in warning_level: + return warning_level.split(",") + warning_level = [warning_level] + + return { + "true": ALL_KEYWORDS, + "false": [], + "all": ALL_KEYWORDS, + "moderate": BASE_KEYWORDS, + "off": [], + "": [], + }.get(warning_level[0], warning_level) diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py index e8589de..997b86e 100644 --- a/pgcli/packages/prompt_utils.py +++ b/pgcli/packages/prompt_utils.py @@ -3,7 +3,7 @@ import click from .parseutils import is_destructive -def confirm_destructive_query(queries, warning_level): +def confirm_destructive_query(queries, keywords, alias): """Check if the query is destructive and prompts the user to confirm. Returns: @@ -12,11 +12,13 @@ def confirm_destructive_query(queries, warning_level): * False if the query is destructive and the user doesn't want to proceed. """ - prompt_text = ( - "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" - ) - if is_destructive(queries, warning_level) and sys.stdin.isatty(): - return prompt(prompt_text, type=bool) + info = "You're about to run a destructive command" + if alias: + info += f" in {click.style(alias, fg='red')}" + + prompt_text = f"{info}.\nDo you want to proceed?" + if is_destructive(queries, keywords) and sys.stdin.isatty(): + return confirm(prompt_text) def confirm(*args, **kwargs): diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py index be4933a..b78edd6 100644 --- a/pgcli/packages/sqlcompletion.py +++ b/pgcli/packages/sqlcompletion.py @@ -290,7 +290,6 @@ def suggest_special(text): def suggest_based_on_last_token(token, stmt): - if isinstance(token, str): token_v = token.lower() elif isinstance(token, Comparison): @@ -399,7 +398,6 @@ def suggest_based_on_last_token(token, stmt): elif (token_v.endswith("join") and token.is_keyword) or ( token_v in ("copy", "from", "update", "into", "describe", "truncate") ): - schema = stmt.get_identifier_schema() tables = extract_tables(stmt.text_before_cursor) is_join = token_v.endswith("join") and token.is_keyword @@ -436,7 +434,6 @@ def suggest_based_on_last_token(token, stmt): try: prev = stmt.get_previous_token(token).value.lower() if prev in ("drop", "alter", "create", "create or replace"): - # Suggest functions from either the currently-selected schema or the # public schema if no schema has been specified suggest = [] diff --git a/pgcli/pgclirc b/pgcli/pgclirc index dcff63d..51f7eae 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -9,6 +9,10 @@ smart_completion = True # visible.) wider_completion_menu = False +# Do not create new connections for refreshing completions; Equivalent to +# always running with the --single-connection flag. +always_use_single_connection = False + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple @@ -22,14 +26,22 @@ multi_line = False # a command. multi_line_mode = psql -# Destructive warning mode will alert you before executing a sql statement +# Destructive warning will alert you before executing a sql statement # that may cause harm to the database such as "drop table", "drop database", # "shutdown", "delete", or "update". -# Possible values: -# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE -# "moderate" - skip warning on UPDATE statements, except for unconditional updates -# "off" - skip all warnings -destructive_warning = all +# You can pass a list of destructive commands or leave it empty if you want to skip all warnings. +# "unconditional_update" will warn you of update statements that don't have a where clause +destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update + +# Destructive warning can restart the connection if this is enabled and the +# user declines. This means that any current uncommitted transaction can be +# aborted if the user doesn't want to proceed with a destructive_warning +# statement. +destructive_warning_restarts_connection = False + +# When this option is on (and if `destructive_warning` is not empty), +# destructive statements are not executed when outside of a transaction. +destructive_statements_require_transaction = False # Enables expand mode, which is similar to `\x` in psql. expand = False @@ -37,9 +49,21 @@ expand = False # Enables auto expand mode, which is similar to `\x auto` in psql. auto_expand = False +# Auto-retry queries on connection failures and other operational errors. If +# False, will prompt to rerun the failed query instead of auto-retrying. +auto_retry_closed_connection = True + # If set to True, table suggestions will include a table alias generate_aliases = False +# Path to a json file that specifies specific table aliases to use when generate_aliases is set to True +# the format for this file should be: +# { +# "some_table_name": "desired_alias", +# "some_other_table_name": "another_alias" +# } +alias_map_file = + # log_file location. # In Unix/Linux: ~/.config/pgcli/log # In Windows: %USERPROFILE%\AppData\Local\dbcli\pgcli\log @@ -83,9 +107,10 @@ qualify_columns = if_more_than_one_table # When no schema is entered, only suggest objects in search_path search_path_filter = False -# Default pager. -# By default 'PAGER' environment variable is used -# pager = less -SRXF +# Default pager. See https://www.pgcli.com/pager for more information on settings. +# By default 'PAGER' environment variable is used. If the pager is less, and the 'LESS' +# environment variable is not set, then LESS='-SRXF' will be automatically set. +# pager = less # Timing of sql statements and table rendering. timing = True @@ -140,7 +165,7 @@ less_chatty = False # \i - Postgres PID # \# - "@" sign if logged in as superuser, '>' in other case # \n - Newline -# \dsn_alias - name of dsn alias if -D option is used (empty otherwise) +# \dsn_alias - name of dsn connection string alias if -D option is used (empty otherwise) # \x1b[...m - insert ANSI escape sequence # eg: prompt = '\x1b[35m\u@\x1b[32m\h:\x1b[36m\d>' prompt = '\u@\h:\d> ' @@ -198,7 +223,8 @@ output.null = "#808080" # Named queries are queries you can execute by name. [named queries] -# DSN to call by -D option +# Here's where you can provide a list of connection string aliases. +# You can use it by passing the -D option. `pgcli -D example_dsn` [alias_dsn] # example_dsn = postgresql://[user[:password]@][netloc][:port][/dbname] diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index e66c3dc..17fc540 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -1,3 +1,4 @@ +import json import logging import re from itertools import count, repeat, chain @@ -61,18 +62,38 @@ arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$") normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"' -def generate_alias(tbl): +def generate_alias(tbl, alias_map=None): """Generate a table alias, consisting of all upper-case letters in the table name, or, if there are no upper-case letters, the first letter + all letters preceded by _ param tbl - unescaped name of the table to alias """ + if alias_map and tbl in alias_map: + return alias_map[tbl] return "".join( [l for l in tbl if l.isupper()] or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"] ) +class InvalidMapFile(ValueError): + pass + + +def load_alias_map_file(path): + try: + with open(path) as fo: + alias_map = json.load(fo) + except FileNotFoundError as err: + raise InvalidMapFile( + f"Cannot read alias_map_file - {err.filename} does not exist" + ) + except json.JSONDecodeError: + raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json") + else: + return alias_map + + class PGCompleter(Completer): # keywords_tree: A dict mapping keywords to well known following keywords. # e.g. 'CREATE': ['TABLE', 'USER', ...], @@ -100,6 +121,11 @@ class PGCompleter(Completer): self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2) self.search_path_filter = settings.get("search_path_filter") self.generate_aliases = settings.get("generate_aliases") + alias_map_file = settings.get("alias_map_file") + if alias_map_file is not None: + self.alias_map = load_alias_map_file(alias_map_file) + else: + self.alias_map = None self.casing_file = settings.get("casing_file") self.insert_col_skip_patterns = [ re.compile(pattern) @@ -157,7 +183,6 @@ class PGCompleter(Completer): self.all_completions.update(additional_keywords) def extend_schemata(self, schemata): - # schemata is a list of schema names schemata = self.escaped_names(schemata) metadata = self.dbmetadata["tables"] @@ -226,7 +251,6 @@ class PGCompleter(Completer): self.all_completions.add(colname) def extend_functions(self, func_data): - # func_data is a list of function metadata namedtuples # dbmetadata['schema_name']['functions']['function_name'] should return @@ -260,7 +284,6 @@ class PGCompleter(Completer): } def extend_foreignkeys(self, fk_data): - # fk_data is a list of ForeignKey namedtuples, with fields # parentschema, childschema, parenttable, childtable, # parentcolumns, childcolumns @@ -283,7 +306,6 @@ class PGCompleter(Completer): parcolmeta.foreignkeys.append(fk) def extend_datatypes(self, type_data): - # dbmetadata['datatypes'][schema_name][type_name] should store type # metadata, such as composite type field names. Currently, we're not # storing any metadata beyond typename, so just store None @@ -697,7 +719,6 @@ class PGCompleter(Completer): return self.find_matches(word_before_cursor, conds, meta="join") def get_function_matches(self, suggestion, word_before_cursor, alias=False): - if suggestion.usage == "from": # Only suggest functions allowed in FROM clause diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 8f2968d..497d681 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,7 +1,7 @@ import logging import traceback from collections import namedtuple - +import re import pgspecial as special import psycopg import psycopg.sql @@ -17,6 +17,27 @@ ViewDef = namedtuple( ) +# we added this funcion to strip beginning comments +# because sqlparse didn't handle tem well. It won't be needed if sqlparse +# does parsing of this situation better + + +def remove_beginning_comments(command): + # Regular expression pattern to match comments + pattern = r"^(/\*.*?\*/|--.*?)(?:\n|$)" + + # Find and remove all comments from the beginning + cleaned_command = command + comments = [] + match = re.match(pattern, cleaned_command, re.DOTALL) + while match: + comments.append(match.group()) + cleaned_command = cleaned_command[len(match.group()) :].lstrip() + match = re.match(pattern, cleaned_command, re.DOTALL) + + return [cleaned_command, comments] + + def register_typecasters(connection): """Casts date and timestamp values to string, resolves issues with out-of-range dates (e.g. BC) which psycopg can't handle""" @@ -76,7 +97,6 @@ class ProtocolSafeCursor(psycopg.Cursor): class PGExecute: - # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog search_path_query = """ @@ -180,7 +200,6 @@ class PGExecute: dsn=None, **kwargs, ): - conn_params = self._conn_params.copy() new_params = { @@ -203,7 +222,11 @@ class PGExecute: conn_params.update({k: v for k, v in new_params.items() if v}) - conn_info = make_conninfo(**conn_params) + if "dsn" in conn_params: + other_params = {k: v for k, v in conn_params.items() if k != "dsn"} + conn_info = make_conninfo(conn_params["dsn"], **other_params) + else: + conn_info = make_conninfo(**conn_params) conn = psycopg.connect(conn_info) conn.cursor_factory = ProtocolSafeCursor @@ -309,21 +332,20 @@ class PGExecute: # sql parse doesn't split on a comment first + special # so we're going to do it - sqltemp = [] + removed_comments = [] sqlarr = [] + cleaned_command = "" - if statement.startswith("--"): - sqltemp = statement.split("\n") - sqlarr.append(sqltemp[0]) - for i in sqlparse.split(sqltemp[1]): - sqlarr.append(i) - elif statement.startswith("/*"): - sqltemp = statement.split("*/") - sqltemp[0] = sqltemp[0] + "*/" - for i in sqlparse.split(sqltemp[1]): - sqlarr.append(i) - else: - sqlarr = sqlparse.split(statement) + # could skip if statement doesn't match ^-- or ^/* + cleaned_command, removed_comments = remove_beginning_comments(statement) + + sqlarr = sqlparse.split(cleaned_command) + + # now re-add the beginning comments if there are any, so that they show up in + # log files etc when running these commands + + if len(removed_comments) > 0: + sqlarr = removed_comments + sqlarr # run each sql query for sql in sqlarr: @@ -470,7 +492,7 @@ class PGExecute: return ( psycopg.sql.SQL(template) .format( - name=psycopg.sql.Identifier(f"{result.nspname}.{result.relname}"), + name=psycopg.sql.Identifier(result.nspname, result.relname), stmt=psycopg.sql.SQL(result.viewdef), ) .as_string(self.conn) diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py index 7b5883e..4a12ff4 100644 --- a/pgcli/pgtoolbar.py +++ b/pgcli/pgtoolbar.py @@ -1,18 +1,14 @@ -from pkg_resources import packaging - -import prompt_toolkit from prompt_toolkit.key_binding.vi_state import InputMode from prompt_toolkit.application import get_app -parse_version = packaging.version.parse - vi_modes = { InputMode.INSERT: "I", InputMode.NAVIGATION: "N", InputMode.REPLACE: "R", InputMode.INSERT_MULTIPLE: "M", } -if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"): +# REPLACE_SINGLE is available in prompt_toolkit >= 3.0.6 +if "REPLACE_SINGLE" in {e.name for e in InputMode}: vi_modes[InputMode.REPLACE_SINGLE] = "R" diff --git a/pgcli/pyev.py b/pgcli/pyev.py index 202947f..2886c9c 100644 --- a/pgcli/pyev.py +++ b/pgcli/pyev.py @@ -146,7 +146,7 @@ class Visualizer: elif self.explain.get("Max Rows") < plan["Actual Rows"]: self.explain["Max Rows"] = plan["Actual Rows"] - if not self.explain.get("MaxCost"): + if not self.explain.get("Max Cost"): self.explain["Max Cost"] = plan["Actual Cost"] elif self.explain.get("Max Cost") < plan["Actual Cost"]: self.explain["Max Cost"] = plan["Actual Cost"] @@ -171,7 +171,7 @@ class Visualizer: return self.warning_format("%.2f ms" % value) elif value < 60000: return self.critical_format( - "%.2f s" % (value / 2000.0), + "%.2f s" % (value / 1000.0), ) else: return self.critical_format( diff --git a/pyproject.toml b/pyproject.toml index c9bf518..8477d72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.black] line-length = 88 -target-version = ['py36'] +target-version = ['py38'] include = '\.pyi?$' exclude = ''' /( @@ -19,4 +19,3 @@ exclude = ''' | tests/data )/ ''' - @@ -57,7 +57,7 @@ def version(version_file): def commit_for_release(version_file, ver): run_step("git", "reset") - run_step("git", "add", version_file) + run_step("git", "add", "-u") run_step("git", "commit", "--message", "Releasing version {}".format(ver)) diff --git a/requirements-dev.txt b/requirements-dev.txt index 9bf1117..15505a7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ pytest>=2.7.0 tox>=1.9.2 behave>=1.2.4 -black>=22.3.0 +black>=23.3.0 pexpect==3.3; platform_system != "Windows" pre-commit>=1.16.0 coverage>=5.0.4 @@ -10,4 +10,4 @@ docutils>=0.13.1 autopep8>=1.3.3 twine>=1.11.0 wheel>=0.33.6 -sshtunnel>=0.4.0
\ No newline at end of file +sshtunnel>=0.4.0 @@ -51,7 +51,7 @@ setup( "keyring": ["keyring >= 12.2.0"], "sshtunnel": ["sshtunnel >= 0.4.0"], }, - python_requires=">=3.7", + python_requires=">=3.8", entry_points=""" [console_scripts] pgcli=pgcli.main:cli @@ -62,10 +62,11 @@ setup( "Operating System :: Unix", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: SQL", "Topic :: Database", "Topic :: Database :: Front-Ends", diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature index cd15306..ee497b9 100644 --- a/tests/features/basic_commands.feature +++ b/tests/features/basic_commands.feature @@ -23,6 +23,30 @@ Feature: run the cli, When we send "ctrl + d" then dbcli exits + Scenario: confirm exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "c" + then dbcli exits + + Scenario: cancel exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "a" + then we see dbcli prompt + when we rollback transaction + when we send "ctrl + d" + then dbcli exits + + Scenario: interrupt current query via "ctrl + c" + When we send sleep query + and we send "ctrl + c" + then we see cancelled query warning + when we check for any non-idle sleep queries + then we don't see any non-idle sleep queries + Scenario: list databases When we list databases then we see list of databases diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature index ed13bbe..87da4e3 100644 --- a/tests/features/crud_database.feature +++ b/tests/features/crud_database.feature @@ -5,7 +5,7 @@ Feature: manipulate databases: When we create database then we see database created when we drop database - then we confirm the destructive warning + then we respond to the destructive warning: y then we see database dropped when we connect to dbserver then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature index 1f9db4a..8a43c5c 100644 --- a/tests/features/crud_table.feature +++ b/tests/features/crud_table.feature @@ -8,15 +8,38 @@ Feature: manipulate tables: then we see table created when we insert into table then we see record inserted + when we select from table + then we see data selected: initial when we update table then we see record updated when we select from table - then we see data selected + then we see data selected: updated when we delete from table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see record deleted when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped when we connect to dbserver then we see database connected + + Scenario: transaction handling, with cancelling on a destructive warning. + When we connect to test database + then we see database connected + when we create table + then we see table created + when we begin transaction + then we see transaction began + when we insert into table + then we see record inserted + when we delete from table + then we respond to the destructive warning: n + when we select from table + then we see data selected: initial + when we rollback transaction + then we see transaction rolled back + when we select from table + then we see select output without data + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/environment.py b/tests/features/environment.py index 6cc8e14..50ac5fa 100644 --- a/tests/features/environment.py +++ b/tests/features/environment.py @@ -164,10 +164,24 @@ def before_step(context, _): context.atprompt = False +def is_known_problem(scenario): + """TODO: why is this not working in 3.12?""" + if sys.version_info >= (3, 12): + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + ) + return False + + def before_scenario(context, scenario): if scenario.name == "list databases": # not using the cli for that return + if is_known_problem(scenario): + scenario.skip() currentdb = None if "pgbouncer" in scenario.feature.tags: if context.pgbouncer_available: diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature index 4f381f8..e486048 100644 --- a/tests/features/expanded.feature +++ b/tests/features/expanded.feature @@ -7,7 +7,7 @@ Feature: expanded mode: and we select from table then we see expanded data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped Scenario: expanded off @@ -16,7 +16,7 @@ Feature: expanded mode: and we select from table then we see nonexpanded data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped Scenario: expanded auto @@ -25,5 +25,5 @@ Feature: expanded mode: and we select from table then we see auto data selected when we drop table - then we confirm the destructive warning + then we respond to the destructive warning: y then we see table dropped diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 7c87814..687bdc0 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -64,13 +64,83 @@ def step_ctrl_d(context): """ Send Ctrl + D to hopefully exit. """ + step_try_to_ctrl_d(context) + context.cli.expect(pexpect.EOF, timeout=5) + context.exit_sent = True + + +@when('we try to send "ctrl + d"') +def step_try_to_ctrl_d(context): + """ + Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is + ongoing). + """ # turn off pager before exiting context.cli.sendcontrol("c") context.cli.sendline(r"\pset pager off") wrappers.wait_prompt(context) context.cli.sendcontrol("d") - context.cli.expect(pexpect.EOF, timeout=5) - context.exit_sent = True + + +@when('we send "ctrl + c"') +def step_ctrl_c(context): + """Send Ctrl + c to hopefully interrupt.""" + context.cli.sendcontrol("c") + + +@then("we see cancelled query warning") +def step_see_cancelled_query_warning(context): + """ + Make sure we receive the warning that the current query was cancelled. + """ + wrappers.expect_exact(context, "cancelled query", timeout=2) + + +@then("we see ongoing transaction message") +def step_see_ongoing_transaction_error(context): + """ + Make sure we receive the warning that a transaction is ongoing. + """ + context.cli.expect("A transaction is ongoing.", timeout=2) + + +@when("we send sleep query") +def step_send_sleep_15_seconds(context): + """ + Send query to sleep for 15 seconds. + """ + context.cli.sendline("select pg_sleep(15)") + + +@when("we check for any non-idle sleep queries") +def step_check_for_active_sleep_queries(context): + """ + Send query to check for any non-idle pg_sleep queries. + """ + context.cli.sendline( + "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';" + ) + + +@then("we don't see any non-idle sleep queries") +def step_no_active_sleep_queries(context): + """Confirm that any pg_sleep queries are either idle or not active.""" + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +-------+\r + | state |\r + |-------|\r + +-------+\r + SELECT 0\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) @when(r'we send "\?" command') @@ -131,18 +201,31 @@ def step_see_found(context): ) -@then("we confirm the destructive warning") -def step_confirm_destructive_command(context): - """Confirm destructive command.""" +@then("we respond to the destructive warning: {response}") +def step_resppond_to_destructive_command(context, response): + """Respond to destructive command.""" wrappers.expect_exact( context, - "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", timeout=2, ) - context.cli.sendline("y") + context.cli.sendline(response.strip()) @then("we send password") def step_send_password(context): wrappers.expect_exact(context, "Password for", timeout=5) context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") + + +@when('we send "{text}"') +def step_send_text(context, text): + context.cli.sendline(text) + # Try to detect whether we are exiting. If so, set `exit_sent` + # so that `after_scenario` correctly cleans up. + try: + context.cli.expect(pexpect.EOF, timeout=0.2) + except pexpect.TIMEOUT: + pass + else: + context.exit_sent = True diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py index 0375883..27d543e 100644 --- a/tests/features/steps/crud_table.py +++ b/tests/features/steps/crud_table.py @@ -9,6 +9,10 @@ from textwrap import dedent import wrappers +INITIAL_DATA = "xxx" +UPDATED_DATA = "yyy" + + @when("we create table") def step_create_table(context): """ @@ -22,7 +26,7 @@ def step_insert_into_table(context): """ Send insert into table. """ - context.cli.sendline("""insert into a(x) values('xxx');""") + context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""") @when("we update table") @@ -30,7 +34,9 @@ def step_update_table(context): """ Send insert into table. """ - context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") + context.cli.sendline( + f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';""" + ) @when("we select from table") @@ -46,7 +52,7 @@ def step_delete_from_table(context): """ Send deete from table. """ - context.cli.sendline("""delete from a where x = 'yyy';""") + context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""") @when("we drop table") @@ -57,6 +63,30 @@ def step_drop_table(context): context.cli.sendline("drop table a;") +@when("we alter the table") +def step_alter_table(context): + """ + Alter the table by adding a column. + """ + context.cli.sendline("""alter table a add column y varchar;""") + + +@when("we begin transaction") +def step_begin_transaction(context): + """ + Begin transaction + """ + context.cli.sendline("begin;") + + +@when("we rollback transaction") +def step_rollback_transaction(context): + """ + Rollback transaction + """ + context.cli.sendline("rollback;") + + @then("we see table created") def step_see_table_created(context): """ @@ -81,19 +111,20 @@ def step_see_record_updated(context): wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) -@then("we see data selected") -def step_see_data_selected(context): +@then("we see data selected: {data}") +def step_see_data_selected(context, data): """ - Wait to see select output. + Wait to see select output with initial or updated data. """ + x = UPDATED_DATA if data == "updated" else INITIAL_DATA wrappers.expect_pager( context, dedent( - """\ + f"""\ +-----+\r | x |\r |-----|\r - | yyy |\r + | {x} |\r +-----+\r SELECT 1\r """ @@ -102,6 +133,26 @@ def step_see_data_selected(context): ) +@then("we see select output without data") +def step_see_no_data_selected(context): + """ + Wait to see select output without data. + """ + wrappers.expect_pager( + context, + dedent( + """\ + +---+\r + | x |\r + |---|\r + +---+\r + SELECT 0\r + """ + ), + timeout=1, + ) + + @then("we see record deleted") def step_see_data_deleted(context): """ @@ -116,3 +167,19 @@ def step_see_table_dropped(context): Wait to see drop output. """ wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) + + +@then("we see transaction began") +def step_see_transaction_began(context): + """ + Wait to see transaction began. + """ + wrappers.expect_pager(context, "BEGIN\r\n", timeout=2) + + +@then("we see transaction rolled back") +def step_see_transaction_rolled_back(context): + """ + Wait to see transaction rollback. + """ + wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py index ac84c41..302cab9 100644 --- a/tests/features/steps/expanded.py +++ b/tests/features/steps/expanded.py @@ -16,7 +16,7 @@ def step_prepare_data(context): context.cli.sendline("drop table if exists a;") wrappers.expect_exact( context, - "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", timeout=2, ) context.cli.sendline("y") diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py index 6180517..3ebcc92 100644 --- a/tests/features/steps/wrappers.py +++ b/tests/features/steps/wrappers.py @@ -3,10 +3,7 @@ import pexpect from pgcli.main import COLOR_CODE_REGEX import textwrap -try: - from StringIO import StringIO -except ImportError: - from io import StringIO +from io import StringIO def expect_exact(context, expected, timeout): diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py index b8cd9c2..016ed95 100644 --- a/tests/formatter/test_sqlformatter.py +++ b/tests/formatter/test_sqlformatter.py @@ -34,7 +34,7 @@ def test_output_sql_insert(): "Jackson", "jackson_test@gmail.com", "132454789", - "", + None, "2022-09-09 19:44:32.712343+08", "2022-09-09 19:44:32.712343+08", ] @@ -58,7 +58,7 @@ def test_output_sql_insert(): output_list = [l for l in output] expected = [ 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', - " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', '', " + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, " + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", ";", ] diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py index 5a375d7..349cbd0 100644 --- a/tests/parseutils/test_parseutils.py +++ b/tests/parseutils/test_parseutils.py @@ -1,5 +1,10 @@ import pytest -from pgcli.packages.parseutils import is_destructive +from pgcli.packages.parseutils import ( + is_destructive, + parse_destructive_warning, + BASE_KEYWORDS, + ALL_KEYWORDS, +) from pgcli.packages.parseutils.tables import extract_tables from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote @@ -263,18 +268,43 @@ def test_is_open_quote__open(sql): @pytest.mark.parametrize( - ("sql", "warning_level", "expected"), + ("sql", "keywords", "expected"), + [ + ("update abc set x = 1", ALL_KEYWORDS, True), + ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True), + ("update abc set x = 1", BASE_KEYWORDS, True), + ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False), + ("select x, y, z from abc", ALL_KEYWORDS, False), + ("drop abc", ALL_KEYWORDS, True), + ("alter abc", ALL_KEYWORDS, True), + ("delete abc", ALL_KEYWORDS, True), + ("truncate abc", ALL_KEYWORDS, True), + ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ], +) +def test_is_destructive(sql, keywords, expected): + assert is_destructive(sql, keywords) == expected + + +@pytest.mark.parametrize( + ("warning_level", "expected"), [ - ("update abc set x = 1", "all", True), - ("update abc set x = 1 where y = 2", "all", True), - ("update abc set x = 1", "moderate", True), - ("update abc set x = 1 where y = 2", "moderate", False), - ("select x, y, z from abc", "all", False), - ("drop abc", "all", True), - ("alter abc", "all", True), - ("delete abc", "all", True), - ("truncate abc", "all", True), + ("true", ALL_KEYWORDS), + ("false", []), + ("all", ALL_KEYWORDS), + ("moderate", BASE_KEYWORDS), + ("off", []), + ("", []), + (None, []), + (ALL_KEYWORDS, ALL_KEYWORDS), + (BASE_KEYWORDS, BASE_KEYWORDS), + ("insert", ["insert"]), + ("drop,alter,delete", ["drop", "alter", "delete"]), + (["drop", "alter", "delete"], ["drop", "alter", "delete"]), ], ) -def test_is_destructive(sql, warning_level, expected): - assert is_destructive(sql, warning_level=warning_level) == expected +def test_parse_destructive_warning(warning_level, expected): + assert parse_destructive_warning(warning_level) == expected diff --git a/tests/test_main.py b/tests/test_main.py index 9b3a84b..cbf20a6 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -216,7 +216,6 @@ def pset_pager_mocks(): with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( "pgcli.main.click.echo_via_pager" ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: - yield cli, mock_echo, mock_echo_via_pager, mock_app @@ -298,6 +297,22 @@ def test_i_works(tmpdir, executor): @dbtest +def test_echo_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\echo asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_qecho_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\qecho asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest def test_watch_works(executor): cli = PGCli(pgexecute=executor) @@ -371,7 +386,6 @@ def test_quoted_db_uri(tmpdir): def test_pg_service_file(tmpdir): - with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py new file mode 100644 index 0000000..909fa0b --- /dev/null +++ b/tests/test_pgcompleter.py @@ -0,0 +1,76 @@ +import pytest +from pgcli import pgcompleter + + +def test_load_alias_map_file_missing_file(): + with pytest.raises( + pgcompleter.InvalidMapFile, + match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$", + ): + pgcompleter.load_alias_map_file("/path/to/non-existent/file.json") + + +def test_load_alias_map_file_invalid_json(tmp_path): + fpath = tmp_path / "foo.json" + fpath.write_text("this is not valid json") + with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"): + pgcompleter.load_alias_map_file(str(fpath)) + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("SomE_Table", "SET"), + ("SOmeTabLe", "SOTL"), + ("someTable", "T"), + ], +) +def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("some_tab_le", "stl"), + ("s_ome_table", "sot"), + ("sometable", "s"), + ], +) +def test_generate_alias_uses_first_char_and_every_preceded_by_underscore( + table_name, alias +): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_can_use_alias_map(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_prefers_alias_over_upper_case_name( + table_name, alias_map, alias +): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("Some_tablE", "SE"), + ("SomeTab_le", "ST"), + ], +) +def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index d6d2f93..636795b 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -304,9 +304,7 @@ def test_execute_from_commented_file_that_executes_another_file( @dbtest def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): - # https://github.com/dbcli/pgcli/issues/1362 - - # just some base caes that should work also + # just some base cases that should work also statement = "--comment\nselect now();" result = run(executor, statement, pgspecial=pgspecial) assert result != None @@ -317,12 +315,14 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): assert result != None assert result[1].find("now") >= 0 - statement = "/*comment\ncomment line2*/\nselect now();" + # https://github.com/dbcli/pgcli/issues/1362 + statement = "--comment\n\\h" result = run(executor, statement, pgspecial=pgspecial) assert result != None - assert result[1].find("now") >= 0 + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 - statement = "--comment\n\\h" + statement = "--comment1\n--comment2\n\\h" result = run(executor, statement, pgspecial=pgspecial) assert result != None assert result[1].find("ALTER") >= 0 @@ -334,6 +334,24 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): assert result[1].find("ALTER") >= 0 assert result[1].find("ABORT") >= 0 + statement = """/*comment1 + comment2*/ + \h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + /*comment 3 + comment4*/ + \\h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + statement = " /*comment*/\n\h;" result = run(executor, statement, pgspecial=pgspecial) assert result != None @@ -352,6 +370,126 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): assert result[1].find("ALTER") >= 0 assert result[1].find("ABORT") >= 0 + statement = """\\h /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + print(result) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: we probably don't want to do this but sqlparse is not parsing things well + # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/ + # style comments after command + + statement = """/*comment1*/ + \h + /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: same for this one + statement = """/*comment1 + comment3 + comment2*/ + \\h + /*comment4 + comment5 + comment6*/""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + +@dbtest +def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1403 + + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # this simulates the original error (1403) without having to add/drop tables + # since it was just an error on reading input files and not the actual + # command itself + + # test that the statement works + statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a \n in the middle + statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a newline in the middle + statement = """VALUES (1, 'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # now add a single comment line + statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + VALUES (1,'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # two comment lines + statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + --comment2 + VALUES (1,'one'), (2, 'two'), (3, 'three'); + """ + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + # + comments after the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three'); +--comment4 +--comment5""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + @dbtest def test_multiple_queries_same_line(executor): @@ -558,6 +696,7 @@ def test_view_definition(executor): run(executor, "create view vw1 AS SELECT * FROM tbl1") run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") result = executor.view_definition("vw1") + assert 'VIEW "public"."vw1" AS' in result assert "FROM tbl1" in result # import pytest; pytest.set_trace() result = executor.view_definition("mvw1") diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py index a8a3a1e..91abe37 100644 --- a/tests/test_prompt_utils.py +++ b/tests/test_prompt_utils.py @@ -7,4 +7,11 @@ def test_confirm_destructive_query_notty(): stdin = click.get_text_stream("stdin") if not stdin.isatty(): sql = "drop database foo;" - assert confirm_destructive_query(sql, "all") is None + assert confirm_destructive_query(sql, [], None) is None + + +def test_confirm_destructive_query_with_alias(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, ["drop"], "test") is None @@ -1,10 +1,11 @@ [tox] -envlist = py37, py38, py39, py310 +envlist = py38, py39, py310, py311, py312 [testenv] deps = pytest>=2.7.0,<=3.0.7 mock>=1.0.1 behave>=1.2.4 pexpect==3.3 + sshtunnel>=0.4.0 commands = py.test behave tests/features passenv = PGHOST |