summaryrefslogtreecommitdiffstats
path: root/pgcli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/main.py')
-rw-r--r--pgcli/main.py89
1 files changed, 85 insertions, 4 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index e4a2ee3..a72f708 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -1,6 +1,5 @@
import platform
import warnings
-from os.path import expanduser
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
@@ -8,6 +7,7 @@ from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
+import atexit
import os
import re
import sys
@@ -21,6 +21,8 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
+from typing import Optional
+from urllib.parse import urlparse
keyring = None # keyring will be loaded later
@@ -78,12 +80,21 @@ except ImportError:
from getpass import getuser
from psycopg2 import OperationalError, InterfaceError
+from psycopg2.extensions import make_dsn, parse_dsn
import psycopg2
from collections import namedtuple
from textwrap import dedent
+try:
+ import sshtunnel
+
+ SSH_TUNNEL_SUPPORT = True
+except ImportError:
+ SSH_TUNNEL_SUPPORT = False
+
+
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
DEFAULT_MAX_FIELD_WIDTH = 500
@@ -168,8 +179,8 @@ class PGCli:
prompt_dsn=None,
auto_vertical_output=False,
warn=None,
+ ssh_tunnel_url: Optional[str] = None,
):
-
self.force_passwd_prompt = force_passwd_prompt
self.never_passwd_prompt = never_passwd_prompt
self.pgexecute = pgexecute
@@ -275,6 +286,10 @@ class PGCli:
self.prompt_app = None
+ self.ssh_tunnel_config = c.get("ssh tunnels")
+ self.ssh_tunnel_url = ssh_tunnel_url
+ self.ssh_tunnel = None
+
def quit(self):
raise PgCliQuitError
@@ -585,6 +600,56 @@ class PGCli:
return True
return False
+ if dsn:
+ parsed_dsn = parse_dsn(dsn)
+ if "host" in parsed_dsn:
+ host = parsed_dsn["host"]
+ if "port" in parsed_dsn:
+ port = parsed_dsn["port"]
+
+ if self.ssh_tunnel_config and not self.ssh_tunnel_url:
+ for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
+ if re.search(db_host_regex, host):
+ self.ssh_tunnel_url = tunnel_url
+ break
+
+ if self.ssh_tunnel_url:
+ # We add the protocol as urlparse doesn't find it by itself
+ if "://" not in self.ssh_tunnel_url:
+ self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
+
+ tunnel_info = urlparse(self.ssh_tunnel_url)
+ params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (host, int(port or 5432)),
+ "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22),
+ "logger": self.logger,
+ }
+ if tunnel_info.username:
+ params["ssh_username"] = tunnel_info.username
+ if tunnel_info.password:
+ params["ssh_password"] = tunnel_info.password
+
+ # Hack: sshtunnel adds a console handler to the logger, so we revert handlers.
+ # We can remove this when https://github.com/pahaz/sshtunnel/pull/250 is merged.
+ logger_handlers = self.logger.handlers.copy()
+ try:
+ self.ssh_tunnel = sshtunnel.SSHTunnelForwarder(**params)
+ self.ssh_tunnel.start()
+ except Exception as e:
+ self.logger.handlers = logger_handlers
+ self.logger.error("traceback: %r", traceback.format_exc())
+ click.secho(str(e), err=True, fg="red")
+ exit(1)
+ self.logger.handlers = logger_handlers
+
+ atexit.register(self.ssh_tunnel.stop)
+ host = "127.0.0.1"
+ port = self.ssh_tunnel.local_bind_ports[0]
+
+ if dsn:
+ dsn = make_dsn(dsn, host=host, port=port)
+
# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
# fails because of a missing or incorrect password, but we're allowed to
@@ -1222,7 +1287,7 @@ class PGCli:
"--list",
"list_databases",
is_flag=True,
- help="list " "available databases, then exit.",
+ help="list available databases, then exit.",
)
@click.option(
"--auto-vertical-output",
@@ -1235,6 +1300,11 @@ class PGCli:
type=click.Choice(["all", "moderate", "off"]),
help="Warn before running a destructive query.",
)
+@click.option(
+ "--ssh-tunnel",
+ default=None,
+ help="Open an SSH tunnel to the given address and connect to the database from it.",
+)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
@@ -1258,6 +1328,7 @@ def cli(
auto_vertical_output,
list_dsn,
warn,
+ ssh_tunnel: str,
):
if version:
print("Version:", __version__)
@@ -1294,6 +1365,15 @@ def cli(
)
exit(1)
+ if ssh_tunnel and not SSH_TUNNEL_SUPPORT:
+ click.secho(
+ 'Cannot open SSH tunnel, "sshtunnel" package was not found. '
+ "Please install pgcli with `pip install pgcli[sshtunnel]` if you want SSH tunnel support.",
+ err=True,
+ fg="red",
+ )
+ exit(1)
+
pgcli = PGCli(
prompt_passwd,
never_prompt,
@@ -1305,6 +1385,7 @@ def cli(
prompt_dsn=prompt_dsn,
auto_vertical_output=auto_vertical_output,
warn=warn,
+ ssh_tunnel_url=ssh_tunnel,
)
# Choose which ever one has a valid value.
@@ -1548,7 +1629,7 @@ def parse_service_info(service):
elif os.getenv("PGSYSCONFDIR"):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
- service_file = expanduser("~/.pg_service.conf")
+ service_file = os.path.expanduser("~/.pg_service.conf")
if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file