From e8e0960f454f47142162c94a083fa9efd19d4fd9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 2 Mar 2022 13:22:04 +0100 Subject: Merging upstream version 3.4.0. Signed-off-by: Daniel Baumann --- pgcli/main.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 85 insertions(+), 4 deletions(-) (limited to 'pgcli/main.py') diff --git a/pgcli/main.py b/pgcli/main.py index e4a2ee3..a72f708 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -1,6 +1,5 @@ import platform import warnings -from os.path import expanduser from configobj import ConfigObj, ParseError from pgspecial.namedqueries import NamedQueries @@ -8,6 +7,7 @@ from .config import skip_initial_comment warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2") +import atexit import os import re import sys @@ -21,6 +21,8 @@ import datetime as dt import itertools import platform from time import time, sleep +from typing import Optional +from urllib.parse import urlparse keyring = None # keyring will be loaded later @@ -78,12 +80,21 @@ except ImportError: from getpass import getuser from psycopg2 import OperationalError, InterfaceError +from psycopg2.extensions import make_dsn, parse_dsn import psycopg2 from collections import namedtuple from textwrap import dedent +try: + import sshtunnel + + SSH_TUNNEL_SUPPORT = True +except ImportError: + SSH_TUNNEL_SUPPORT = False + + # Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))") DEFAULT_MAX_FIELD_WIDTH = 500 @@ -168,8 +179,8 @@ class PGCli: prompt_dsn=None, auto_vertical_output=False, warn=None, + ssh_tunnel_url: Optional[str] = None, ): - self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt self.pgexecute = pgexecute @@ -275,6 +286,10 @@ class PGCli: self.prompt_app = None + self.ssh_tunnel_config = c.get("ssh tunnels") + self.ssh_tunnel_url = ssh_tunnel_url + self.ssh_tunnel = None + def quit(self): raise PgCliQuitError @@ -585,6 +600,56 @@ class PGCli: return True return False + if dsn: + parsed_dsn = parse_dsn(dsn) + if "host" in parsed_dsn: + host = parsed_dsn["host"] + if "port" in parsed_dsn: + port = parsed_dsn["port"] + + if self.ssh_tunnel_config and not self.ssh_tunnel_url: + for db_host_regex, tunnel_url in self.ssh_tunnel_config.items(): + if re.search(db_host_regex, host): + self.ssh_tunnel_url = tunnel_url + break + + if self.ssh_tunnel_url: + # We add the protocol as urlparse doesn't find it by itself + if "://" not in self.ssh_tunnel_url: + self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" + + tunnel_info = urlparse(self.ssh_tunnel_url) + params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (host, int(port or 5432)), + "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), + "logger": self.logger, + } + if tunnel_info.username: + params["ssh_username"] = tunnel_info.username + if tunnel_info.password: + params["ssh_password"] = tunnel_info.password + + # Hack: sshtunnel adds a console handler to the logger, so we revert handlers. + # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged. + logger_handlers = self.logger.handlers.copy() + try: + self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params) + self.ssh_tunnel.start() + except Exception as e: + self.logger.handlers = logger_handlers + self.logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + exit(1) + self.logger.handlers = logger_handlers + + atexit.register(self.ssh_tunnel.stop) + host = "127.0.0.1" + port = self.ssh_tunnel.local_bind_ports[0] + + if dsn: + dsn = make_dsn(dsn, host=host, port=port) + # Attempt to connect to the database. # Note that passwd may be empty on the first attempt. If connection # fails because of a missing or incorrect password, but we're allowed to @@ -1222,7 +1287,7 @@ class PGCli: "--list", "list_databases", is_flag=True, - help="list " "available databases, then exit.", + help="list available databases, then exit.", ) @click.option( "--auto-vertical-output", @@ -1235,6 +1300,11 @@ class PGCli: type=click.Choice(["all", "moderate", "off"]), help="Warn before running a destructive query.", ) +@click.option( + "--ssh-tunnel", + default=None, + help="Open an SSH tunnel to the given address and connect to the database from it.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1258,6 +1328,7 @@ def cli( auto_vertical_output, list_dsn, warn, + ssh_tunnel: str, ): if version: print("Version:", __version__) @@ -1294,6 +1365,15 @@ def cli( ) exit(1) + if ssh_tunnel and not SSH_TUNNEL_SUPPORT: + click.secho( + 'Cannot open SSH tunnel, "sshtunnel" package was not found. ' + "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.", + err=True, + fg="red", + ) + exit(1) + pgcli = PGCli( prompt_passwd, never_prompt, @@ -1305,6 +1385,7 @@ def cli( prompt_dsn=prompt_dsn, auto_vertical_output=auto_vertical_output, warn=warn, + ssh_tunnel_url=ssh_tunnel, ) # Choose which ever one has a valid value. @@ -1548,7 +1629,7 @@ def parse_service_info(service): elif os.getenv("PGSYSCONFDIR"): service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf") else: - service_file = expanduser("~/.pg_service.conf") + service_file = os.path.expanduser("~/.pg_service.conf") if not service or not os.path.exists(service_file): # nothing to do return None, service_file -- cgit v1.2.3