import os
from typing import Iterator
from unittest.mock import patch, MagicMock, ANY

import pytest
from configobj import ConfigObj
from click.testing import CliRunner
from sshtunnel import SSHTunnelForwarder

from pgcli.main import cli, PGCli
from pgcli.pgexecute import PGExecute


@pytest.fixture
def mock_ssh_tunnel_forwarder() -> Iterator[MagicMock]:
    """Patch ``pgcli.main.sshtunnel.SSHTunnelForwarder`` for one test.

    Yields the mock that stands in for the class. Calling it returns a
    ``MagicMock`` specced on ``SSHTunnelForwarder`` whose
    ``local_bind_ports`` is ``[1111]``.
    """
    # NOTE: MagicMock has no ``autospec`` parameter; extra keyword arguments
    # are applied as attributes, so ``autospec=True`` only sets
    # ``tunnel_instance.autospec``. Kept so the mock is unchanged.
    tunnel_instance = MagicMock(
        SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
    )
    with patch(
        "pgcli.main.sshtunnel.SSHTunnelForwarder",
        return_value=tunnel_instance,
    ) as mock:
        yield mock


@pytest.fixture
def mock_pgexecute() -> Iterator[MagicMock]:
    """Replace ``PGExecute.__init__`` with a mock returning ``None``.

    Yields the mock so tests can inspect the arguments ``PGExecute`` was
    constructed with.
    """
    with patch.object(PGExecute, "__init__", return_value=None) as mock_init:
        yield mock_init


def _assert_pgexecute_uses_tunnel(
    mock_pgexecute: MagicMock, pgcli: PGCli, db_params: dict
) -> None:
    """Assert ``PGExecute`` was created once with the tunnel's local endpoint.

    The expected positional arguments are the database, user and password
    from ``db_params``, then ``"127.0.0.1"``, the first local bind port of
    ``pgcli.ssh_tunnel``, and an empty string.
    """
    mock_pgexecute.assert_called_once()
    call_args, _ = mock_pgexecute.call_args
    assert call_args == (
        db_params["database"],
        db_params["user"],
        db_params["passwd"],
        "127.0.0.1",
        pgcli.ssh_tunnel.local_bind_ports[0],
        "",
    )


def test_ssh_tunnel(
    mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
) -> None:
    """Check tunnel creation for a bare host, a full ssh URL, and a DSN."""
    # Test with just a host
    tunnel_url = "some.host"
    db_params = {
        "database": "dbname",
        "host": "db.host",
        "user": "db_user",
        "passwd": "db_passwd",
    }
    expected_tunnel_params = {
        "local_bind_address": ("127.0.0.1",),
        "remote_bind_address": (db_params["host"], 5432),
        "ssh_address_or_host": (tunnel_url, 22),
        "logger": ANY,
    }

    pgcli = PGCli(ssh_tunnel_url=tunnel_url)
    pgcli.connect(**db_params)

    mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
    mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
    _assert_pgexecute_uses_tunnel(mock_pgexecute, pgcli, db_params)
    mock_ssh_tunnel_forwarder.reset_mock()
    mock_pgexecute.reset_mock()

    # Test with a full url and with a specific db port
    tunnel_user = "tunnel_user"
    tunnel_passwd = "tunnel_pass"
    tunnel_host = "some.other.host"
    tunnel_port = 1022
    tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
    db_params["port"] = 1234

    expected_tunnel_params["remote_bind_address"] = (
        db_params["host"],
        db_params["port"],
    )
    expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
    expected_tunnel_params["ssh_username"] = tunnel_user
    expected_tunnel_params["ssh_password"] = tunnel_passwd

    pgcli = PGCli(ssh_tunnel_url=tunnel_url)
    pgcli.connect(**db_params)

    mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
    mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
    _assert_pgexecute_uses_tunnel(mock_pgexecute, pgcli, db_params)
    mock_ssh_tunnel_forwarder.reset_mock()
    mock_pgexecute.reset_mock()

    # Test with DSN
    dsn = (
        f"user={db_params['user']} password={db_params['passwd']} "
        f"host={db_params['host']} port={db_params['port']}"
    )

    pgcli = PGCli(ssh_tunnel_url=tunnel_url)
    pgcli.connect(dsn=dsn)

    # The DSN handed to PGExecute must carry the tunnel's local host and port
    # in place of the database's own.
    expected_dsn = (
        f"user={db_params['user']} password={db_params['passwd']} "
        f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
    )

    mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
    mock_pgexecute.assert_called_once()
    call_args, _ = mock_pgexecute.call_args
    assert expected_dsn in call_args


def test_cli_with_tunnel() -> None:
    """Check that ``--ssh-tunnel`` reaches ``PGCli`` as ``ssh_tunnel_url``."""
    runner = CliRunner()
    tunnel_url = "mytunnel"
    with patch.object(
        PGCli, "__init__", autospec=True, return_value=None
    ) as mock_pgcli:
        runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
        mock_pgcli.assert_called_once()
        _, call_kwargs = mock_pgcli.call_args
        assert call_kwargs["ssh_tunnel_url"] == tunnel_url


def test_config(
    tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
) -> None:
    """Check that the ``ssh tunnels`` config section picks a tunnel by host regex."""
    pgclirc = str(tmpdir.join("rcfile"))

    tunnel_user = "tunnel_user"
    tunnel_passwd = "tunnel_pass"
    tunnel_host = "tunnel.host"
    tunnel_port = 1022
    tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"

    tunnel2_url = "tunnel2.host"

    # Write an rc file mapping two host patterns to two different tunnels.
    config = ConfigObj()
    config.filename = pgclirc
    config["ssh tunnels"] = {}
    config["ssh tunnels"][r"\.com$"] = tunnel_url
    config["ssh tunnels"][r"^hello-"] = tunnel2_url
    config.write()

    # Unmatched host
    pgcli = PGCli(pgclirc_file=pgclirc)
    pgcli.connect(host="unmatched.host")
    mock_ssh_tunnel_forwarder.assert_not_called()

    # Host matching first tunnel
    pgcli = PGCli(pgclirc_file=pgclirc)
    pgcli.connect(host="matched.host.com")
    mock_ssh_tunnel_forwarder.assert_called_once()
    _, call_kwargs = mock_ssh_tunnel_forwarder.call_args
    assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
    assert call_kwargs["ssh_username"] == tunnel_user
    assert call_kwargs["ssh_password"] == tunnel_passwd
    mock_ssh_tunnel_forwarder.reset_mock()

    # Host matching second tunnel
    pgcli = PGCli(pgclirc_file=pgclirc)
    pgcli.connect(host="hello-i-am-matched")
    mock_ssh_tunnel_forwarder.assert_called_once()
    _, call_kwargs = mock_ssh_tunnel_forwarder.call_args
    assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
    mock_ssh_tunnel_forwarder.reset_mock()

    # Host matching both tunnels (will use the first one matched)
    pgcli = PGCli(pgclirc_file=pgclirc)
    pgcli.connect(host="hello-i-am-matched.com")
    mock_ssh_tunnel_forwarder.assert_called_once()
    _, call_kwargs = mock_ssh_tunnel_forwarder.call_args
    assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
    assert call_kwargs["ssh_username"] == tunnel_user
    assert call_kwargs["ssh_password"] == tunnel_passwd