From e8e0960f454f47142162c94a083fa9efd19d4fd9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 2 Mar 2022 13:22:04 +0100 Subject: Merging upstream version 3.4.0. Signed-off-by: Daniel Baumann --- .github/workflows/ci.yml | 2 +- AUTHORS | 3 +- RELEASES.md | 17 +++ changelog.rst | 14 ++- pgcli/__init__.py | 2 +- pgcli/main.py | 89 +++++++++++++++- pgcli/packages/parseutils/tables.py | 18 ++-- pgcli/pgcompleter.py | 13 ++- release_procedure.txt | 13 --- setup.py | 5 +- tests/features/steps/basic_commands.py | 4 +- tests/test_ssh_tunnel.py | 188 +++++++++++++++++++++++++++++++++ 12 files changed, 328 insertions(+), 40 deletions(-) create mode 100644 RELEASES.md delete mode 100644 release_procedure.txt create mode 100644 tests/test_ssh_tunnel.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ce54d6f..f0e6fd8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,7 +38,7 @@ jobs: - name: Install requirements run: | pip install -U pip setuptools - pip install --no-cache-dir . + pip install --no-cache-dir ".[sshtunnel]" pip install -r requirements-dev.txt pip install keyrings.alt>=3.1 diff --git a/AUTHORS b/AUTHORS index bcfba6a..a037334 100644 --- a/AUTHORS +++ b/AUTHORS @@ -116,8 +116,9 @@ Contributors: * Kevin Marsh (kevinmarsh) * Eero Ruohola (ruohola) * Miroslav Šedivý (eumiro) - * Eric R Young (ERYoung11) + * Eric R Young (ERYoung11) * Paweł Sacawa (psacawa) + * Bruno Inec (sweenu) Creator: -------- diff --git a/RELEASES.md b/RELEASES.md new file mode 100644 index 0000000..37cf4d2 --- /dev/null +++ b/RELEASES.md @@ -0,0 +1,17 @@ +Releasing pgcli +--------------- + +We have a script called `release.py` to automate the process. + +The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps. + +``` +> python release.py --help +Usage: release.py [options] + +Options: + -h, --help show this help message and exit + -c, --confirm-steps Confirm every step. If the step is not confirmed, it + will be skipped. + -d, --dry-run Print out, but not actually run any steps. +``` diff --git a/changelog.rst b/changelog.rst index d4bbd39..1f0bc59 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,3 +1,16 @@ +TBD +=== + +* [List new changes here]. + +3.4.0 (2022/02/21) +================== + +Features: +--------- + +* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)). + 3.3.1 (2022/01/18) ================== @@ -7,7 +20,6 @@ Bug fixes: * Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307. * Upgrade cli_helpers to 2.2.1 - 3.3.0 (2022/01/11) ================== diff --git a/pgcli/__init__.py b/pgcli/__init__.py index ff04168..903a158 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "3.3.1" +__version__ = "3.4.0" diff --git a/pgcli/main.py b/pgcli/main.py index e4a2ee3..a72f708 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -1,6 +1,5 @@ import platform import warnings -from os.path import expanduser from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries @@ -8,6 +7,7 @@ from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") +import atexit import os import re import sys @@ -21,6 +21,8 @@ import datetime as dt import itertools import platform from time import time, sleep +from typing import Optional +from urllib.parse import urlparse keyring = None # keyring will be loaded later @@ -78,12 +80,21 @@ except ImportError: from getpass import getuser from psycopg2 import OperationalError, InterfaceError +from psycopg2.extensions import make_dsn, parse_dsn import psycopg2 from collections import namedtuple from textwrap import dedent +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + + # Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") DEFAULT_MAX_FIELD_WIDTH = 500 @@ -168,8 +179,8 @@ class PGCli: prompt_dsn=None, auto_vertical_output=False, warn=None, + ssh_tunnel_url: Optional[str] = None, ): - self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt self.pgexecute = pgexecute @@ -275,6 +286,10 @@ class PGCli: self.prompt_app = None + self.ssh_tunnel_config = c.get("ssh tunnels") + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + def quit(self): raise PgCliQuitError @@ -585,6 +600,56 @@ class PGCli: return True return False + if dsn: + parsed_dsn = parse_dsn(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + if self.ssh_tunnel_config and not self.ssh_tunnel_url: + for db_host_regex, tunnel_url in self.ssh_tunnel_config.items(): + if re.search(db_host_regex, host): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_dsn(dsn, host=host, port=port) + # Attempt to connect to the database. # Note that passwd may be empty on the first attempt. If connection # fails because of a missing or incorrect password, but we're allowed to @@ -1222,7 +1287,7 @@ class PGCli: "--list", "list_databases", is_flag=True, - help="list " "available databases, then exit.", + help="list available databases, then exit.", ) @click.option( "--auto-vertical-output", @@ -1235,6 +1300,11 @@ class PGCli: type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1258,6 +1328,7 @@ def cli( auto_vertical_output, list_dsn, warn, + ssh_tunnel: str, ): if version: print("Version:", __version__) @@ -1294,6 +1365,15 @@ def cli( ) exit(1) + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + exit(1) + pgcli = PGCli( prompt_passwd, never_prompt, @@ -1305,6 +1385,7 @@ def cli( prompt_dsn=prompt_dsn, auto_vertical_output=auto_vertical_output, warn=warn, + ssh_tunnel_url=ssh_tunnel, ) # Choose which ever one has a valid value. @@ -1548,7 +1629,7 @@ def parse_service_info(service): elif os.getenv("PGSYSCONFDIR"): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: - service_file = expanduser("~/.pg_service.conf") + service_file = os.path.expanduser("~/.pg_service.conf") if not service or not os.path.exists(service_file): # nothing to do return None, service_file diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index aaa676c..f2e1e42 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -63,17 +63,13 @@ def extract_from_part(parsed, stop_at_punctuation=True): yield item elif item.ttype is Keyword or item.ttype is Keyword.DML: item_val = item.value.upper() - if ( - item_val - in ( - "COPY", - "FROM", - "INTO", - "UPDATE", - "TABLE", - ) - or item_val.endswith("JOIN") - ): + if item_val in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) or item_val.endswith("JOIN"): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py index 227e25c..e66c3dc 100644 --- a/pgcli/pgcompleter.py +++ b/pgcli/pgcompleter.py @@ -491,11 +491,14 @@ class PGCompleter(Completer): def get_column_matches(self, suggestion, word_before_cursor): tables = suggestion.table_refs - do_qualify = suggestion.qualifiable and { - "always": True, - "never": False, - "if_more_than_one_table": len(tables) > 1, - }[self.qualify_columns] + do_qualify = ( + suggestion.qualifiable + and { + "always": True, + "never": False, + "if_more_than_one_table": len(tables) > 1, + }[self.qualify_columns] + ) qualify = lambda col, tbl: ( (tbl + "." + self.case(col)) if do_qualify else self.case(col) ) diff --git a/release_procedure.txt b/release_procedure.txt deleted file mode 100644 index 9f3bff0..0000000 --- a/release_procedure.txt +++ /dev/null @@ -1,13 +0,0 @@ -# vi: ft=vimwiki - -* Bump the version number in pgcli/__init__.py -* Commit with message: 'Releasing version X.X.X.' -* Create a tag: git tag vX.X.X -* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/pgcli/master/screenshots/image01.png -* Create source dist tar ball: python setup.py sdist -* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt]. -* Upload the source dist to PyPI: https://pypi.python.org/pypi/pgcli -* pip install pgcli -* Run SanityChecks. -* Push the version back to github: git push --tags origin master -* Done! diff --git a/setup.py b/setup.py index 0cbd192..ed6fb3c 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,10 @@ setup( description=description, long_description=open("README.rst").read(), install_requires=install_requirements, - extras_require={"keyring": ["keyring >= 12.2.0"]}, + extras_require={ + "keyring": ["keyring >= 12.2.0"], + "sshtunnel": ["sshtunnel >= 0.4.0"], + }, python_requires=">=3.6", entry_points=""" [console_scripts] diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py index 7ca20f0..a7c99ee 100644 --- a/tests/features/steps/basic_commands.py +++ b/tests/features/steps/basic_commands.py @@ -97,9 +97,9 @@ def step_see_error_message(context): @when("we send source command") def step_send_source_command(context): context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") - context.tmpfile_sql_help.write(br"\?") + context.tmpfile_sql_help.write(rb"\?") context.tmpfile_sql_help.flush() - context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}") + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 0000000..ae865f4 --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,188 @@ +import os +from unittest.mock import patch, MagicMock, ANY + +import pytest +from configobj import ConfigObj +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock( + SSHTunnelForwarder, local_bind_ports=[1111], autospec=True + ) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host={db_params['host']} port={db_params['port']}" + ) + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + ) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object( + PGCli, "__init__", autospec=True, return_value=None + ) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url + + +def test_config( + tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + pgclirc = str(tmpdir.join("rcfile")) + + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "tunnel.host" + tunnel_port = 1022 + tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + + tunnel2_url = "tunnel2.host" + + config = ConfigObj() + config.filename = pgclirc + config["ssh tunnels"] = {} + config["ssh tunnels"][r"\.com$"] = tunnel_url + config["ssh tunnels"][r"^hello-"] = tunnel2_url + config.write() + + # Unmatched host + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="unmatched.host") + mock_ssh_tunnel_forwarder.assert_not_called() + + # Host matching first tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="matched.host.com") + mock_ssh_tunnel_forwarder.assert_called_once() + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching second tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22) + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching both tunnels (will use the first one matched) + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched.com") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd -- cgit v1.2.3