From 98038631ba9672eafb28f518ff30db76d5954d81 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 18 May 2024 15:21:43 +0200 Subject: Adding upstream version 4.1.0. Signed-off-by: Daniel Baumann --- .github/workflows/ci.yml | 6 +- AUTHORS | 4 + changelog.rst | 25 +++- pgcli/__init__.py | 2 +- pgcli/main.py | 224 ++++++++++++++++++++++++++++++++-- pgcli/pgclirc | 14 ++- pgcli/pgexecute.py | 17 ++- setup.py | 11 +- tests/conftest.py | 2 + tests/features/steps/crud_database.py | 1 + tests/test_application_name.py | 17 +++ tests/test_main.py | 99 ++++++++++++++- tests/test_pgexecute.py | 66 +++++++++- tests/test_ssh_tunnel.py | 4 +- 14 files changed, 456 insertions(+), 36 deletions(-) create mode 100644 tests/test_application_name.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68a69ac..007178f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -67,10 +67,6 @@ jobs: psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help' - - name: Install beta version of pendulum - run: pip install pendulum==3.0.0b1 - if: matrix.python-version == '3.12' - - name: Install requirements run: | pip install -U pip setuptools @@ -89,7 +85,7 @@ jobs: run: behave tests/features --no-capture - name: Check changelog for ReST compliance - run: rst2html.py --halt=warning changelog.rst >/dev/null + run: docutils --halt=warning changelog.rst >/dev/null - name: Run Black run: black --check . diff --git a/AUTHORS b/AUTHORS index 5eff7db..9f33ff5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -130,6 +130,10 @@ Contributors: * blag * Rob Berry (rob-b) * Sharon Yogev (sharonyogev) + * Hollis Wu (holi0317) + * Antonio Aguilar (crazybolillo) + * Andrew M. MacFie (amacfie) + * saucoide Creator: -------- diff --git a/changelog.rst b/changelog.rst index 7d08839..744e903 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,5 +1,26 @@ +4.1.0 (2024-03-09) ================== -4.0.1 (2023-11-30) + +Features: +--------- +* Support `PGAPPNAME` as an environment variable and `--application-name` as a command line argument. +* Add `verbose_errors` config and `\v` special command which enable the + displaying of all Postgres error fields received. +* Show Postgres notifications. +* Support sqlparse 0.5.x +* Add `--log-file [filename]` cli argument and `\log-file [filename]` special commands to + log to an external file in addition to the normal output + +Bug fixes: +---------- + +* Fix display of "short host" in prompt (with `\h`) for IPv4 addresses ([issue 964](https://github.com/dbcli/pgcli/issues/964)). +* Fix backwards display of NOTICEs from a Function ([issue 1443](https://github.com/dbcli/pgcli/issues/1443)) +* Fix psycopg errors when installing on Windows. ([issue 1413](https://https://github.com/dbcli/pgcli/issues/1413)) +* Use a home-made function to display query duration instead of relying on a third-party library (the general behaviour does not change), which fixes the installation of `pgcli` on 32-bit architectures ([issue 1451](https://github.com/dbcli/pgcli/issues/1451)) + +================== +4.0.1 (2023-10-30) ================== Internal: @@ -7,7 +28,7 @@ Internal: * Allow stable version of pendulum. ================== -4.0.0 (2023-11-27) +4.0.0 (2023-10-27) ================== Features: diff --git a/pgcli/__init__.py b/pgcli/__init__.py index 76ad18b..7039708 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "4.0.1" +__version__ = "4.1.0" diff --git a/pgcli/main.py b/pgcli/main.py index f95c800..056a940 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -11,9 +11,9 @@ import logging import threading import shutil import functools -import pendulum import datetime as dt import itertools +import pathlib import platform from time import time, sleep from typing import Optional @@ -74,8 +74,9 @@ from urllib.parse import urlparse from getpass import getuser -from psycopg import OperationalError, InterfaceError +from psycopg import OperationalError, InterfaceError, Notify from psycopg.conninfo import make_conninfo, conninfo_to_dict +from psycopg.errors import Diagnostic from collections import namedtuple @@ -129,6 +130,15 @@ class PgCliQuitError(Exception): pass +def notify_callback(notify: Notify): + click.secho( + 'Notification received on channel "{}" (PID {}):\n{}'.format( + notify.channel, notify.pid, notify.payload + ), + fg="green", + ) + + class PGCli: default_prompt = "\\u@\\h:\\d> " max_len_prompt = 30 @@ -165,6 +175,7 @@ class PGCli: pgexecute=None, pgclirc_file=None, row_limit=None, + application_name="pgcli", single_connection=False, less_chatty=None, prompt=None, @@ -172,6 +183,7 @@ class PGCli: auto_vertical_output=False, warn=None, ssh_tunnel_url: Optional[str] = None, + log_file: Optional[str] = None, ): self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt @@ -210,6 +222,8 @@ class PGCli: else: self.row_limit = c["main"].as_int("row_limit") + self.application_name = application_name + # if not specified, set to DEFAULT_MAX_FIELD_WIDTH # if specified but empty, set to None to disable truncation # ellipsis will take at least 3 symbols, so this can't be less than 3 if specified and > 0 @@ -237,6 +251,9 @@ class PGCli: ) self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty") + self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool( + "verbose_errors" + ) self.null_string = c["main"].get("null_string", "") self.prompt_format = ( prompt @@ -295,6 +312,11 @@ class PGCli: self.ssh_tunnel_url = ssh_tunnel_url self.ssh_tunnel = None + if log_file: + with open(log_file, "a+"): + pass # ensure writeable + self.log_file = log_file + # formatter setup self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) register_new_formatter(self.formatter) @@ -354,6 +376,12 @@ class PGCli: "\\o [filename]", "Send all query results to file.", ) + self.pgspecial.register( + self.write_to_logfile, + "\\log-file", + "\\log-file [filename]", + "Log all query results to a logfile, in addition to the normal output destination.", + ) self.pgspecial.register( self.info_connection, "\\conninfo", "\\conninfo", "Get connection details" ) @@ -378,6 +406,26 @@ class PGCli: "Echo a string to the query output channel.", ) + self.pgspecial.register( + self.toggle_verbose_errors, + "\\v", + "\\v [on|off]", + "Toggle verbose errors.", + ) + + def toggle_verbose_errors(self, pattern, **_): + flag = pattern.strip() + + if flag == "on": + self.verbose_errors = True + elif flag == "off": + self.verbose_errors = False + else: + self.verbose_errors = not self.verbose_errors + + message = "Verbose errors " + "on." if self.verbose_errors else "off." + return [(None, None, None, message)] + def echo(self, pattern, **_): return [(None, None, None, pattern)] @@ -473,6 +521,26 @@ class PGCli: explain_mode=self.explain_mode, ) + def write_to_logfile(self, pattern, **_): + if not pattern: + self.log_file = None + message = "Logfile capture disabled" + return [(None, None, None, message, "", True, True)] + + log_file = pathlib.Path(pattern).expanduser().absolute() + + try: + with open(log_file, "a+"): + pass # ensure writeable + except OSError as e: + self.log_file = None + message = str(e) + "\nLogfile capture disabled" + return [(None, None, None, message, "", False, True)] + + self.log_file = str(log_file) + message = 'Writing to file "%s"' % self.log_file + return [(None, None, None, message, "", True, True)] + def write_to_file(self, pattern, **_): if not pattern: self.output_file = None @@ -568,7 +636,7 @@ class PGCli: if not database: database = user - kwargs.setdefault("application_name", "pgcli") + kwargs.setdefault("application_name", self.application_name) # If password prompt is not forced but no password is provided, try # getting it from environment variable. @@ -658,7 +726,16 @@ class PGCli: # prompt for a password (no -w flag), prompt for a passwd and try again. try: try: - pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs) + pgexecute = PGExecute( + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, + ) except (OperationalError, InterfaceError) as e: if should_ask_for_password(e): passwd = click.prompt( @@ -668,7 +745,14 @@ class PGCli: type=str, ) pgexecute = PGExecute( - database, user, passwd, host, port, dsn, **kwargs + database, + user, + passwd, + host, + port, + dsn, + notify_callback, + **kwargs, ) else: raise e @@ -775,7 +859,7 @@ class PGCli: else: try: if self.output_file and not text.startswith( - ("\\o ", "\\? ", "\\echo ") + ("\\o ", "\\log-file", "\\? ", "\\echo ") ): try: with open(self.output_file, "a", encoding="utf-8") as f: @@ -787,6 +871,23 @@ class PGCli: else: if output: self.echo_via_pager("\n".join(output)) + + # Log to file in addition to normal output + if ( + self.log_file + and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")) + and not text.strip() == "" + ): + try: + with open(self.log_file, "a", encoding="utf-8") as f: + click.echo( + dt.datetime.now().isoformat(), file=f + ) # timestamp log + click.echo(text, file=f) + click.echo("\n".join(output), file=f) + click.echo("", file=f) # extra newline + except OSError as e: + click.secho(str(e), err=True, fg="red") except KeyboardInterrupt: pass @@ -797,9 +898,9 @@ class PGCli: "Time: %0.03fs (%s), executed in: %0.03fs (%s)" % ( query.total_time, - pendulum.Duration(seconds=query.total_time).in_words(), + duration_in_words(query.total_time), query.execution_time, - pendulum.Duration(seconds=query.execution_time).in_words(), + duration_in_words(query.execution_time), ) ) else: @@ -1053,7 +1154,7 @@ class PGCli: res = self.pgexecute.run( text, self.pgspecial, - exception_formatter, + lambda x: exception_formatter(x, self.verbose_errors), on_error_resume, explain_mode=self.explain_mode, ) @@ -1337,6 +1438,12 @@ class PGCli: type=click.INT, help="Set threshold for row limit prompt. Use 0 to disable prompt.", ) +@click.option( + "--application-name", + default="pgcli", + envvar="PGAPPNAME", + help="Application name for the connection.", +) @click.option( "--less-chatty", "less_chatty", @@ -1371,6 +1478,11 @@ class PGCli: default=None, help="Open an SSH tunnel to the given address and connect to the database from it.", ) +@click.option( + "--log-file", + default=None, + help="Write all queries & output into a file, in addition to the normal output destination.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1387,6 +1499,7 @@ def cli( pgclirc, dsn, row_limit, + application_name, less_chatty, prompt, prompt_dsn, @@ -1395,6 +1508,7 @@ def cli( list_dsn, warn, ssh_tunnel: str, + log_file: str, ): if version: print("Version:", __version__) @@ -1445,6 +1559,7 @@ def cli( never_prompt, pgclirc_file=pgclirc, row_limit=row_limit, + application_name=application_name, single_connection=single_connection, less_chatty=less_chatty, prompt=prompt, @@ -1452,6 +1567,7 @@ def cli( auto_vertical_output=auto_vertical_output, warn=warn, ssh_tunnel_url=ssh_tunnel, + log_file=log_file, ) # Choose which ever one has a valid value. @@ -1583,8 +1699,71 @@ def is_select(status): return status.split(None, 1)[0].lower() == "select" -def exception_formatter(e): - return click.style(str(e), fg="red") +def diagnostic_output(diagnostic: Diagnostic) -> str: + fields = [] + + if diagnostic.severity is not None: + fields.append("Severity: " + diagnostic.severity) + + if diagnostic.severity_nonlocalized is not None: + fields.append("Severity (non-localized): " + diagnostic.severity_nonlocalized) + + if diagnostic.sqlstate is not None: + fields.append("SQLSTATE code: " + diagnostic.sqlstate) + + if diagnostic.message_primary is not None: + fields.append("Message: " + diagnostic.message_primary) + + if diagnostic.message_detail is not None: + fields.append("Detail: " + diagnostic.message_detail) + + if diagnostic.message_hint is not None: + fields.append("Hint: " + diagnostic.message_hint) + + if diagnostic.statement_position is not None: + fields.append("Position: " + diagnostic.statement_position) + + if diagnostic.internal_position is not None: + fields.append("Internal position: " + diagnostic.internal_position) + + if diagnostic.internal_query is not None: + fields.append("Internal query: " + diagnostic.internal_query) + + if diagnostic.context is not None: + fields.append("Where: " + diagnostic.context) + + if diagnostic.schema_name is not None: + fields.append("Schema name: " + diagnostic.schema_name) + + if diagnostic.table_name is not None: + fields.append("Table name: " + diagnostic.table_name) + + if diagnostic.column_name is not None: + fields.append("Column name: " + diagnostic.column_name) + + if diagnostic.datatype_name is not None: + fields.append("Data type name: " + diagnostic.datatype_name) + + if diagnostic.constraint_name is not None: + fields.append("Constraint name: " + diagnostic.constraint_name) + + if diagnostic.source_file is not None: + fields.append("File: " + diagnostic.source_file) + + if diagnostic.source_line is not None: + fields.append("Line: " + diagnostic.source_line) + + if diagnostic.source_function is not None: + fields.append("Routine: " + diagnostic.source_function) + + return "\n".join(fields) + + +def exception_formatter(e, verbose_errors: bool = False): + s = str(e) + if verbose_errors: + s += "\n" + diagnostic_output(e.diag) + return click.style(s, fg="red") def format_output(title, cur, headers, status, settings, explain_mode=False): @@ -1724,5 +1903,28 @@ def parse_service_info(service): return service_conf, service_file +def duration_in_words(duration_in_seconds: float) -> str: + if not duration_in_seconds: + return "0 seconds" + components = [] + hours, remainder = divmod(duration_in_seconds, 3600) + if hours > 1: + components.append(f"{hours} hours") + elif hours == 1: + components.append("1 hour") + minutes, seconds = divmod(remainder, 60) + if minutes > 1: + components.append(f"{minutes} minutes") + elif minutes == 1: + components.append("1 minute") + if seconds >= 2: + components.append(f"{int(seconds)} seconds") + elif seconds >= 1: + components.append("1 second") + elif seconds: + components.append(f"{round(seconds, 3)} second") + return " ".join(components) + + if __name__ == "__main__": cli() diff --git a/pgcli/pgclirc b/pgcli/pgclirc index 51f7eae..dd8b15f 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -33,10 +33,11 @@ multi_line_mode = psql # "unconditional_update" will warn you of update statements that don't have a where clause destructive_warning = drop, shutdown, delete, truncate, alter, update, unconditional_update -# Destructive warning can restart the connection if this is enabled and the -# user declines. This means that any current uncommitted transaction can be -# aborted if the user doesn't want to proceed with a destructive_warning -# statement. +# When `destructive_warning` is on and the user declines to proceed with a +# destructive statement, the current transaction (if any) is left untouched, +# by default. When setting `destructive_warning_restarts_connection` to +# "True", the connection to the server is restarted. In that case, the +# transaction (if any) is rolled back. destructive_warning_restarts_connection = False # When this option is on (and if `destructive_warning` is not empty), @@ -155,6 +156,11 @@ max_field_width = 500 # Skip intro on startup and goodbye on exit less_chatty = False +# Show all Postgres error fields (as listed in +# https://www.postgresql.org/docs/current/protocol-error-fields.html). +# Can be toggled with \v. +verbose_errors = False + # Postgres prompt # \t - Current date and time # \u - Username diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 497d681..e091757 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,3 +1,4 @@ +import ipaddress import logging import traceback from collections import namedtuple @@ -166,6 +167,7 @@ class PGExecute: host=None, port=None, dsn=None, + notify_callback=None, **kwargs, ): self._conn_params = {} @@ -178,6 +180,7 @@ class PGExecute: self.port = None self.server_version = None self.extra_args = None + self.notify_callback = notify_callback self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None @@ -236,6 +239,9 @@ class PGExecute: self.conn = conn self.conn.autocommit = True + if self.notify_callback is not None: + self.conn.add_notify_handler(self.notify_callback) + # When we connect using a DSN, we don't really know what db, # user, etc. we connected to. Let's read it. # Note: moved this after setting autocommit because of #664. @@ -273,6 +279,11 @@ class PGExecute: @property def short_host(self): + try: + ipaddress.ip_address(self.host) + return self.host + except ValueError: + pass if "," in self.host: host, _, _ = self.host.partition(",") else: @@ -431,7 +442,11 @@ class PGExecute: def handle_notices(n): nonlocal title - title = f"{n.message_primary}\n{n.message_detail}\n{title}" + title = f"{title}" + if n.message_primary is not None: + title = f"{title}\n{n.message_primary}" + if n.message_detail is not None: + title = f"{title}\n{n.message_detail}" self.conn.add_notice_handler(handle_notices) diff --git a/setup.py b/setup.py index 9a398a4..f4606c2 100644 --- a/setup.py +++ b/setup.py @@ -12,10 +12,10 @@ install_requirements = [ # We still need to use pt-2 unless pt-3 released on Fedora32 # see: https://github.com/dbcli/pgcli/pull/1197 "prompt_toolkit>=2.0.6,<4.0.0", - "psycopg >= 3.0.14", - "sqlparse >=0.3.0,<0.5", + "psycopg >= 3.0.14; sys_platform != 'win32'", + "psycopg-binary >= 3.0.14; sys_platform == 'win32'", + "sqlparse >=0.3.0,<0.6", "configobj >= 5.0.6", - "pendulum>=2.1.0", "cli_helpers[styles] >= 2.2.1", ] @@ -27,11 +27,6 @@ install_requirements = [ if platform.system() != "Windows" and not platform.system().startswith("CYGWIN"): install_requirements.append("setproctitle >= 1.1.9") -# Windows will require the binary psycopg to run pgcli -if platform.system() == "Windows": - install_requirements.append("psycopg-binary >= 3.0.14") - - setup( name="pgcli", author="Pgcli Core Team", diff --git a/tests/conftest.py b/tests/conftest.py index 33cddf2..e50f1fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ from utils import ( db_connection, drop_tables, ) +import pgcli.main import pgcli.pgexecute @@ -37,6 +38,7 @@ def executor(connection): password=POSTGRES_PASSWORD, port=POSTGRES_PORT, dsn=None, + notify_callback=pgcli.main.notify_callback, ) diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py index 87cdc85..9507d46 100644 --- a/tests/features/steps/crud_database.py +++ b/tests/features/steps/crud_database.py @@ -3,6 +3,7 @@ Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used to call the step in "*.feature" file. """ + import pexpect from behave import when, then diff --git a/tests/test_application_name.py b/tests/test_application_name.py new file mode 100644 index 0000000..5fac5b2 --- /dev/null +++ b/tests/test_application_name.py @@ -0,0 +1,17 @@ +from unittest.mock import patch + +from click.testing import CliRunner + +from pgcli.main import cli +from pgcli.pgexecute import PGExecute + + +def test_application_name_in_env(): + runner = CliRunner() + app_name = "wonderful_app" + with patch.object(PGExecute, "__init__") as mock_pgxecute: + runner.invoke( + cli, ["127.0.0.1:5432/hello", "user"], env={"PGAPPNAME": app_name} + ) + kwargs = mock_pgxecute.call_args.kwargs + assert kwargs.get("application_name") == app_name diff --git a/tests/test_main.py b/tests/test_main.py index cbf20a6..3683d49 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,5 +1,8 @@ import os import platform +import re +import tempfile +import datetime from unittest import mock import pytest @@ -11,7 +14,9 @@ except ImportError: from pgcli.main import ( obfuscate_process_password, + duration_in_words, format_output, + notify_callback, PGCli, OutputSettings, COLOR_CODE_REGEX, @@ -296,6 +301,24 @@ def test_i_works(tmpdir, executor): run(executor, statement, pgspecial=cli.pgspecial) +@dbtest +def test_toggle_verbose_errors(executor): + cli = PGCli(pgexecute=executor) + + cli._evaluate_command("\\v on") + assert cli.verbose_errors + output, _ = cli._evaluate_command("SELECT 1/0") + assert "SQLSTATE" in output[0] + + cli._evaluate_command("\\v off") + assert not cli.verbose_errors + output, _ = cli._evaluate_command("SELECT 1/0") + assert "SQLSTATE" not in output[0] + + cli._evaluate_command("\\v") + assert cli.verbose_errors + + @dbtest def test_echo_works(executor): cli = PGCli(pgexecute=executor) @@ -312,6 +335,34 @@ def test_qecho_works(executor): assert result == ["asdf"] +@dbtest +def test_logfile_works(executor): + with tempfile.TemporaryDirectory() as tmpdir: + log_file = f"{tmpdir}/tempfile.log" + cli = PGCli(pgexecute=executor, log_file=log_file) + statement = r"\qecho hello!" + cli.execute_command(statement) + with open(log_file, "r") as f: + log_contents = f.readlines() + assert datetime.datetime.fromisoformat(log_contents[0].strip()) + assert log_contents[1].strip() == r"\qecho hello!" + assert log_contents[2].strip() == "hello!" + + +@dbtest +def test_logfile_unwriteable_file(executor): + cli = PGCli(pgexecute=executor) + statement = r"\log-file forbidden.log" + with mock.patch("builtins.open") as mock_open: + mock_open.side_effect = PermissionError( + "[Errno 13] Permission denied: 'forbidden.log'" + ) + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == [ + "[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled" + ] + + @dbtest def test_watch_works(executor): cli = PGCli(pgexecute=executor) @@ -431,6 +482,7 @@ def test_pg_service_file(tmpdir): "b_host", "5435", "", + notify_callback, application_name="pgcli", ) del os.environ["PGPASSWORD"] @@ -486,5 +538,50 @@ def test_application_name_db_uri(tmpdir): cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) cli.connect_uri("postgres://bar@baz.com/?application_name=cow") mock_pgexecute.assert_called_with( - "bar", "bar", "", "baz.com", "", "", application_name="cow" + "bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow" ) + + +@pytest.mark.parametrize( + "duration_in_seconds,words", + [ + (0, "0 seconds"), + (0.0009, "0.001 second"), + (0.0005, "0.001 second"), + (0.0004, "0.0 second"), # not perfect, but will do + (0.2, "0.2 second"), + (1, "1 second"), + (1.4, "1 second"), + (2, "2 seconds"), + (3.4, "3 seconds"), + (60, "1 minute"), + (61, "1 minute 1 second"), + (123, "2 minutes 3 seconds"), + (3600, "1 hour"), + (7235, "2 hours 35 seconds"), + (9005, "2 hours 30 minutes 5 seconds"), + (86401, "24 hours 1 second"), + ], +) +def test_duration_in_words(duration_in_seconds, words): + assert duration_in_words(duration_in_seconds) == words + + +@dbtest +def test_notifications(executor): + run(executor, "listen chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing1'") + mock_secho.assert_called() + arg = mock_secho.call_args_list[0].args[0] + assert re.match( + r'Notification received on channel "chan1" \(PID \d+\):\ntesting1', + arg, + ) + + run(executor, "unlisten chan1") + + with mock.patch("pgcli.main.click.secho") as mock_secho: + run(executor, "notify chan1, 'testing2'") + mock_secho.assert_not_called() diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py index 636795b..f1cadfd 100644 --- a/tests/test_pgexecute.py +++ b/tests/test_pgexecute.py @@ -1,3 +1,4 @@ +import re from textwrap import dedent import psycopg @@ -6,7 +7,7 @@ from unittest.mock import patch, MagicMock from pgspecial.main import PGSpecial, NO_QUERY from utils import run, dbtest, requires_json, requires_jsonb -from pgcli.main import PGCli +from pgcli.main import PGCli, exception_formatter as main_exception_formatter from pgcli.packages.parseutils.meta import FunctionMetadata @@ -219,8 +220,33 @@ def test_database_list(executor): @dbtest def test_invalid_syntax(executor, exception_formatter): - result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) + result = run( + executor, + "invalid syntax!", + exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=False), + ) assert 'syntax error at or near "invalid"' in result[0] + assert "SQLSTATE" not in result[0] + + +@dbtest +def test_invalid_syntax_verbose(executor): + result = run( + executor, + "invalid syntax!", + exception_formatter=lambda x: main_exception_formatter(x, verbose_errors=True), + ) + fields = r""" +Severity: ERROR +Severity \(non-localized\): ERROR +SQLSTATE code: 42601 +Message: syntax error at or near "invalid" +Position: 1 +File: scan\.l +Line: \d+ +Routine: scanner_yyerror + """.strip() + assert re.search(fields, result[0]) @dbtest @@ -690,6 +716,38 @@ def test_function_definition(executor): result = executor.function_definition("the_number_three") +@dbtest +def test_function_notice_order(executor): + run( + executor, + """ + CREATE OR REPLACE FUNCTION demo_order() RETURNS VOID AS + $$ + BEGIN + RAISE NOTICE 'first'; + RAISE NOTICE 'second'; + RAISE NOTICE 'third'; + RAISE NOTICE 'fourth'; + RAISE NOTICE 'fifth'; + RAISE NOTICE 'sixth'; + END; + $$ + LANGUAGE plpgsql; + """, + ) + + executor.function_definition("demo_order") + + result = run(executor, "select demo_order()") + assert "first\nsecond\nthird\nfourth\nfifth\nsixth" in result[0] + assert "+------------+" in result[1] + assert "| demo_order |" in result[2] + assert "|------------|" in result[3] + assert "| |" in result[4] + assert "+------------+" in result[5] + assert "SELECT 1" in result[6] + + @dbtest def test_view_definition(executor): run(executor, "create table tbl1 (a text, b numeric)") @@ -721,6 +779,10 @@ def test_short_host(executor): executor, "host", "localhost1.example.org,localhost2.example.org" ): assert executor.short_host == "localhost1" + with patch.object(executor, "host", "ec2-11-222-333-444.compute-1.amazonaws.com"): + assert executor.short_host == "ec2-11-222-333-444" + with patch.object(executor, "host", "1.2.3.4"): + assert executor.short_host == "1.2.3.4" class VirtualCursor: diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py index ae865f4..983212b 100644 --- a/tests/test_ssh_tunnel.py +++ b/tests/test_ssh_tunnel.py @@ -6,7 +6,7 @@ from configobj import ConfigObj from click.testing import CliRunner from sshtunnel import SSHTunnelForwarder -from pgcli.main import cli, PGCli +from pgcli.main import cli, notify_callback, PGCli from pgcli.pgexecute import PGExecute @@ -61,6 +61,7 @@ def test_ssh_tunnel( "127.0.0.1", pgcli.ssh_tunnel.local_bind_ports[0], "", + notify_callback, ) mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock() @@ -96,6 +97,7 @@ def test_ssh_tunnel( "127.0.0.1", pgcli.ssh_tunnel.local_bind_ports[0], "", + notify_callback, ) mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock() -- cgit v1.2.3