summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/ci.yml2
-rw-r--r--AUTHORS3
-rw-r--r--RELEASES.md17
-rw-r--r--changelog.rst14
-rw-r--r--pgcli/__init__.py2
-rw-r--r--pgcli/main.py89
-rw-r--r--pgcli/packages/parseutils/tables.py18
-rw-r--r--pgcli/pgcompleter.py13
-rw-r--r--release_procedure.txt13
-rw-r--r--setup.py5
-rw-r--r--tests/features/steps/basic_commands.py4
-rw-r--r--tests/test_ssh_tunnel.py188
12 files changed, 328 insertions, 40 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ce54d6f..f0e6fd8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,7 +38,7 @@ jobs:
- name: Install requirements
run: |
pip install -U pip setuptools
- pip install --no-cache-dir .
+ pip install --no-cache-dir ".[sshtunnel]"
pip install -r requirements-dev.txt
pip install keyrings.alt>=3.1
diff --git a/AUTHORS b/AUTHORS
index bcfba6a..a037334 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -116,8 +116,9 @@ Contributors:
* Kevin Marsh (kevinmarsh)
* Eero Ruohola (ruohola)
* Miroslav Šedivý (eumiro)
- * Eric R Young (ERYoung11)
+ * Eric R Young (ERYoung11)
* Paweł Sacawa (psacawa)
+ * Bruno Inec (sweenu)
Creator:
--------
diff --git a/RELEASES.md b/RELEASES.md
new file mode 100644
index 0000000..37cf4d2
--- /dev/null
+++ b/RELEASES.md
@@ -0,0 +1,17 @@
+Releasing pgcli
+---------------
+
+We have a script called `release.py` to automate the process.
+
+The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps.
+
+```
+> python release.py --help
+Usage: release.py [options]
+
+Options:
+ -h, --help show this help message and exit
+ -c, --confirm-steps Confirm every step. If the step is not confirmed, it
+ will be skipped.
+ -d, --dry-run Print out, but not actually run any steps.
+```
diff --git a/changelog.rst b/changelog.rst
index d4bbd39..1f0bc59 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -1,3 +1,16 @@
+TBD
+===
+
+* [List new changes here].
+
+3.4.0 (2022/02/21)
+==================
+
+Features:
+---------
+
+* Add optional support for automatically creating an SSH tunnel to a machine with access to the remote database ([related issue](https://github.com/dbcli/pgcli/issues/459)).
+
3.3.1 (2022/01/18)
==================
@@ -7,7 +20,6 @@ Bug fixes:
* Prompt for password when -W is provided even if there is a password in keychain. Fixes #1307.
* Upgrade cli_helpers to 2.2.1
-
3.3.0 (2022/01/11)
==================
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
index ff04168..903a158 100644
--- a/pgcli/__init__.py
+++ b/pgcli/__init__.py
@@ -1 +1 @@
-__version__ = "3.3.1"
+__version__ = "3.4.0"
diff --git a/pgcli/main.py b/pgcli/main.py
index e4a2ee3..a72f708 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -1,6 +1,5 @@
import platform
import warnings
-from os.path import expanduser
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
@@ -8,6 +7,7 @@ from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
+import atexit
import os
import re
import sys
@@ -21,6 +21,8 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
+from typing import Optional
+from urllib.parse import urlparse
keyring = None # keyring will be loaded later
@@ -78,12 +80,21 @@ except ImportError:
from getpass import getuser
from psycopg2 import OperationalError, InterfaceError
+from psycopg2.extensions import make_dsn, parse_dsn
import psycopg2
from collections import namedtuple
from textwrap import dedent
+try:
+ import sshtunnel
+
+ SSH_TUNNEL_SUPPORT = True
+except ImportError:
+ SSH_TUNNEL_SUPPORT = False
+
+
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
DEFAULT_MAX_FIELD_WIDTH = 500
@@ -168,8 +179,8 @@ class PGCli:
prompt_dsn=None,
auto_vertical_output=False,
warn=None,
+ ssh_tunnel_url: Optional[str] = None,
):
-
self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt
self.pgexecute = pgexecute
@@ -275,6 +286,10 @@ class PGCli:
self.prompt_app = None
+ self.ssh_tunnel_config = c.get("ssh tunnels")
+ self.ssh_tunnel_url = ssh_tunnel_url
+ self.ssh_tunnel = None
+
def quit(self):
raise PgCliQuitError
@@ -585,6 +600,56 @@ class PGCli:
return True
return False
+ if dsn:
+ parsed_dsn = parse_dsn(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
+ if self.ssh_tunnel_url:
+ # We add the protocol as urlparse doesn't find it by itself
+ if "://" not in self.ssh_tunnel_url:
+ self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
+
+ tunnel_info = urlparse(self.ssh_tunnel_url)
+ params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (host, int(port or 5432)),
+ "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
+ "logger": self.logger,
+ }
+ if tunnel_info.username:
+ params["ssh_username"] = tunnel_info.username
+ if tunnel_info.password:
+ params["ssh_password"] = tunnel_info.password
+
+ # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
+ # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
+ logger_handlers = self.logger.handlers.copy()
+ try:
+ self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
+ self.ssh_tunnel.start()
+ except Exception as e:
+ self.logger.handlers = logger_handlers
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+ self.logger.handlers = logger_handlers
+
+ atexit.register(self.ssh_tunnel.stop)
+ host = "127.0.0.1"
+ port = self.ssh_tunnel.local_bind_ports[0]
+
+ if dsn:
+ dsn = make_dsn(dsn, host=host, port=port)
+
# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
# fails because of a missing or incorrect password, but we're allowed to
@@ -1222,7 +1287,7 @@ class PGCli:
"--list",
"list_databases",
is_flag=True,
- help="list " "available databases, then exit.",
+ help="list available databases, then exit.",
)
@click.option(
"--auto-vertical-output",
@@ -1235,6 +1300,11 @@ class PGCli:
type=click.Choice(["all", "moderate", "off"]),
help="Warn before running a destructive query.",
)
+@click.option(
+ "--ssh-tunnel",
+ default=None,
+ help="Open an SSH tunnel to the given address and connect to the database from it.",
+)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
@@ -1258,6 +1328,7 @@ def cli(
auto_vertical_output,
list_dsn,
warn,
+ ssh_tunnel: str,
):
if version:
print("Version:", __version__)
@@ -1294,6 +1365,15 @@ def cli(
)
exit(1)
+ if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
+ click.secho(
+ 'Cannot open SSH tunnel, "sshtunnel" package was not found. '
+ "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
pgcli = PGCli(
prompt_passwd,
never_prompt,
@@ -1305,6 +1385,7 @@ def cli(
prompt_dsn=prompt_dsn,
auto_vertical_output=auto_vertical_output,
warn=warn,
+ ssh_tunnel_url=ssh_tunnel,
)
# Choose which ever one has a valid value.
@@ -1548,7 +1629,7 @@ def parse_service_info(service):
elif os.getenv("PGSYSCONFDIR"):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
- service_file = expanduser("~/.pg_service.conf")
+ service_file = os.path.expanduser("~/.pg_service.conf")
if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
index aaa676c..f2e1e42 100644
--- a/pgcli/packages/parseutils/tables.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -63,17 +63,13 @@ def extract_from_part(parsed, stop_at_punctuation=True):
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
- if (
- item_val
- in (
- "COPY",
- "FROM",
- "INTO",
- "UPDATE",
- "TABLE",
- )
- or item_val.endswith("JOIN")
- ):
+ if item_val in (
+ "COPY",
+ "FROM",
+ "INTO",
+ "UPDATE",
+ "TABLE",
+ ) or item_val.endswith("JOIN"):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 227e25c..e66c3dc 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -491,11 +491,14 @@ class PGCompleter(Completer):
def get_column_matches(self, suggestion, word_before_cursor):
tables = suggestion.table_refs
- do_qualify = suggestion.qualifiable and {
- "always": True,
- "never": False,
- "if_more_than_one_table": len(tables) > 1,
- }[self.qualify_columns]
+ do_qualify = (
+ suggestion.qualifiable
+ and {
+ "always": True,
+ "never": False,
+ "if_more_than_one_table": len(tables) > 1,
+ }[self.qualify_columns]
+ )
qualify = lambda col, tbl: (
(tbl + "." + self.case(col)) if do_qualify else self.case(col)
)
diff --git a/release_procedure.txt b/release_procedure.txt
deleted file mode 100644
index 9f3bff0..0000000
--- a/release_procedure.txt
+++ /dev/null
@@ -1,13 +0,0 @@
-# vi: ft=vimwiki
-
-* Bump the version number in pgcli/__init__.py
-* Commit with message: 'Releasing version X.X.X.'
-* Create a tag: git tag vX.X.X
-* Fix the image url in PyPI to point to github raw content. https://raw.githubusercontent.com/dbcli/pgcli/master/screenshots/image01.png
-* Create source dist tar ball: python setup.py sdist
-* Test this by installing it in a fresh new virtualenv. Run SanityChecks [./sanity_checks.txt].
-* Upload the source dist to PyPI: https://pypi.python.org/pypi/pgcli
-* pip install pgcli
-* Run SanityChecks.
-* Push the version back to github: git push --tags origin master
-* Done!
diff --git a/setup.py b/setup.py
index 0cbd192..ed6fb3c 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,10 @@ setup(
description=description,
long_description=open("README.rst").read(),
install_requires=install_requirements,
- extras_require={"keyring": ["keyring >= 12.2.0"]},
+ extras_require={
+ "keyring": ["keyring >= 12.2.0"],
+ "sshtunnel": ["sshtunnel >= 0.4.0"],
+ },
python_requires=">=3.6",
entry_points="""
[console_scripts]
diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py
index 7ca20f0..a7c99ee 100644
--- a/tests/features/steps/basic_commands.py
+++ b/tests/features/steps/basic_commands.py
@@ -97,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
- context.tmpfile_sql_help.write(br"\?")
+ context.tmpfile_sql_help.write(rb"\?")
context.tmpfile_sql_help.flush()
- context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
+ context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py
new file mode 100644
index 0000000..ae865f4
--- /dev/null
+++ b/tests/test_ssh_tunnel.py
@@ -0,0 +1,188 @@
+import os
+from unittest.mock import patch, MagicMock, ANY
+
+import pytest
+from configobj import ConfigObj
+from click.testing import CliRunner
+from sshtunnel import SSHTunnelForwarder
+
+from pgcli.main import cli, PGCli
+from pgcli.pgexecute import PGExecute
+
+
+@pytest.fixture
+def mock_ssh_tunnel_forwarder() -> MagicMock:
+ mock_ssh_tunnel_forwarder = MagicMock(
+ SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
+ )
+ with patch(
+ "pgcli.main.sshtunnel.SSHTunnelForwarder",
+ return_value=mock_ssh_tunnel_forwarder,
+ ) as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_pgexecute() -> MagicMock:
+ with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute:
+ yield mock_pgexecute
+
+
+def test_ssh_tunnel(
+ mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ # Test with just a host
+ tunnel_url = "some.host"
+ db_params = {
+ "database": "dbname",
+ "host": "db.host",
+ "user": "db_user",
+ "passwd": "db_passwd",
+ }
+ expected_tunnel_params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (db_params["host"], 5432),
+ "ssh_address_or_host": (tunnel_url, 22),
+ "logger": ANY,
+ }
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with a full url and with a specific db port
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "some.other.host"
+ tunnel_port = 1022
+ tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+ db_params["port"] = 1234
+
+ expected_tunnel_params["remote_bind_address"] = (
+ db_params["host"],
+ db_params["port"],
+ )
+ expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
+ expected_tunnel_params["ssh_username"] = tunnel_user
+ expected_tunnel_params["ssh_password"] = tunnel_passwd
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with DSN
+ dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host={db_params['host']} port={db_params['port']}"
+ )
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(dsn=dsn)
+
+ expected_dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
+ )
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert expected_dsn in call_args
+
+
+def test_cli_with_tunnel() -> None:
+ runner = CliRunner()
+ tunnel_url = "mytunnel"
+ with patch.object(
+ PGCli, "__init__", autospec=True, return_value=None
+ ) as mock_pgcli:
+ runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
+ mock_pgcli.assert_called_once()
+ call_args, call_kwargs = mock_pgcli.call_args
+ assert call_kwargs["ssh_tunnel_url"] == tunnel_url
+
+
+def test_config(
+ tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ pgclirc = str(tmpdir.join("rcfile"))
+
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "tunnel.host"
+ tunnel_port = 1022
+ tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+
+ tunnel2_url = "tunnel2.host"
+
+ config = ConfigObj()
+ config.filename = pgclirc
+ config["ssh tunnels"] = {}
+ config["ssh tunnels"][r"\.com$"] = tunnel_url
+ config["ssh tunnels"][r"^hello-"] = tunnel2_url
+ config.write()
+
+ # Unmatched host
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="unmatched.host")
+ mock_ssh_tunnel_forwarder.assert_not_called()
+
+ # Host matching first tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="matched.host.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching second tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching both tunnels (will use the first one matched)
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd