summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/features/steps/basic_commands.py4
-rw-r--r--tests/test_ssh_tunnel.py188
2 files changed, 190 insertions, 2 deletions
diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py
index 7ca20f0..a7c99ee 100644
--- a/tests/features/steps/basic_commands.py
+++ b/tests/features/steps/basic_commands.py
@@ -97,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
- context.tmpfile_sql_help.write(br"\?")
+ context.tmpfile_sql_help.write(rb"\?")
context.tmpfile_sql_help.flush()
- context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
+ context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py
new file mode 100644
index 0000000..ae865f4
--- /dev/null
+++ b/tests/test_ssh_tunnel.py
@@ -0,0 +1,188 @@
+import os
+from unittest.mock import patch, MagicMock, ANY
+
+import pytest
+from configobj import ConfigObj
+from click.testing import CliRunner
+from sshtunnel import SSHTunnelForwarder
+
+from pgcli.main import cli, PGCli
+from pgcli.pgexecute import PGExecute
+
+
+@pytest.fixture
+def mock_ssh_tunnel_forwarder() -> MagicMock:
+ mock_ssh_tunnel_forwarder = MagicMock(
+ SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
+ )
+ with patch(
+ "pgcli.main.sshtunnel.SSHTunnelForwarder",
+ return_value=mock_ssh_tunnel_forwarder,
+ ) as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_pgexecute() -> MagicMock:
+ with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute:
+ yield mock_pgexecute
+
+
+def test_ssh_tunnel(
+ mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ # Test with just a host
+ tunnel_url = "some.host"
+ db_params = {
+ "database": "dbname",
+ "host": "db.host",
+ "user": "db_user",
+ "passwd": "db_passwd",
+ }
+ expected_tunnel_params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (db_params["host"], 5432),
+ "ssh_address_or_host": (tunnel_url, 22),
+ "logger": ANY,
+ }
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with a full url and with a specific db port
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "some.other.host"
+ tunnel_port = 1022
+ tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+ db_params["port"] = 1234
+
+ expected_tunnel_params["remote_bind_address"] = (
+ db_params["host"],
+ db_params["port"],
+ )
+ expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
+ expected_tunnel_params["ssh_username"] = tunnel_user
+ expected_tunnel_params["ssh_password"] = tunnel_passwd
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with DSN
+ dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host={db_params['host']} port={db_params['port']}"
+ )
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(dsn=dsn)
+
+ expected_dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
+ )
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert expected_dsn in call_args
+
+
+def test_cli_with_tunnel() -> None:
+ runner = CliRunner()
+ tunnel_url = "mytunnel"
+ with patch.object(
+ PGCli, "__init__", autospec=True, return_value=None
+ ) as mock_pgcli:
+ runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
+ mock_pgcli.assert_called_once()
+ call_args, call_kwargs = mock_pgcli.call_args
+ assert call_kwargs["ssh_tunnel_url"] == tunnel_url
+
+
+def test_config(
+ tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ pgclirc = str(tmpdir.join("rcfile"))
+
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "tunnel.host"
+ tunnel_port = 1022
+ tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+
+ tunnel2_url = "tunnel2.host"
+
+ config = ConfigObj()
+ config.filename = pgclirc
+ config["ssh tunnels"] = {}
+ config["ssh tunnels"][r"\.com$"] = tunnel_url
+ config["ssh tunnels"][r"^hello-"] = tunnel2_url
+ config.write()
+
+ # Unmatched host
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="unmatched.host")
+ mock_ssh_tunnel_forwarder.assert_not_called()
+
+ # Host matching first tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="matched.host.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching second tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching both tunnels (will use the first one matched)
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd