diff options
Diffstat (limited to '')
-rw-r--r-- | tests/test_main.py | 476 |
1 files changed, 476 insertions, 0 deletions
diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..9b3a84b --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,476 @@ +import os +import platform +from unittest import mock + +import pytest + +try: + import setproctitle +except ImportError: + setproctitle = None + +from pgcli.main import ( + obfuscate_process_password, + format_output, + PGCli, + OutputSettings, + COLOR_CODE_REGEX, +) +from pgcli.pgexecute import PGExecute +from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS +from utils import dbtest, run +from collections import namedtuple + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows") +@pytest.mark.skipif(not setproctitle, reason="setproctitle not available") +def test_obfuscate_process_password(): + original_title = setproctitle.getproctitle() + + setproctitle.setproctitle("pgcli user=root password=secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx" + assert title == expected + + setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli postgres://root:xxxx@localhost/db" + assert title == expected + + setproctitle.setproctitle(original_title) + + +def test_format_output(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") + results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + expected = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(results) == expected + + +def test_format_output_truncate_on(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10 + ) + results = format_output( + None, + [("first field value", "second field value")], + ["head1", "head2"], + None, + settings, + ) + expected = [ + "+------------+------------+", + "| head1 | head2 |", + "|------------+------------|", + "| first f... | second ... |", + "+------------+------------+", + ] + assert list(results) == expected + + +def test_format_output_truncate_off(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None + ) + long_field_value = ("first field " * 100).strip() + results = format_output(None, [(long_field_value,)], ["head1"], None, settings) + lines = list(results) + assert lines[3] == f"| {long_field_value} |" + + +@dbtest +def test_format_array_output(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement) + expected = [ + "+--------------+----------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|--------------+----------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | <null> | {<null>} |", + "+--------------+----------------------+--------------+", + "SELECT 2", + ] + assert list(results) == expected + + +@dbtest +def test_format_array_output_expanded(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement, expanded=True) + expected = [ + "-[ RECORD 1 ]-------------------------", + "bigint_array | {1,2,3}", + "nested_numeric_array | {{1,2},{3,4}}", + "配列 | {å,魚,текст}", + "-[ RECORD 2 ]-------------------------", + "bigint_array | {}", + "nested_numeric_array | <null>", + "配列 | {<null>}", + "SELECT 2", + ] + assert "\n".join(results) == "\n".join(expected) + + +def test_format_output_auto_expand(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 + ) + table_results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + table = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(table_results) == table + expanded_results = format_output( + "Title", + [("abc", "def")], + ["head1", "head2"], + "test status", + settings._replace(max_width=1), + ) + expanded = [ + "Title", + "-[ RECORD 1 ]-------------------------", + "head1 | abc", + "head2 | def", + "test status", + ] + assert "\n".join(expanded_results) == "\n".join(expanded) + + +termsize = namedtuple("termsize", ["rows", "columns"]) +test_line = "-" * 10 +test_data = [ + (10, 10, "\n".join([test_line] * 7)), + (10, 10, "\n".join([test_line] * 6)), + (10, 10, "\n".join([test_line] * 5)), + (10, 10, "-" * 11), + (10, 10, "-" * 10), + (10, 10, "-" * 9), +] + +# 4 lines are reserved at the bottom of the terminal for pgcli's prompt +use_pager_when_on = [True, True, False, True, False, False] + +# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL +test_ids = [ + "Output longer than terminal height", + "Output equal to terminal height", + "Output shorter than terminal height", + "Output longer than terminal width", + "Output equal to terminal width", + "Output shorter than terminal width", +] + + +@pytest.fixture +def pset_pager_mocks(): + cli = PGCli() + cli.watch_command = None + with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( + "pgcli.main.click.echo_via_pager" + ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: + + yield cli, mock_echo, mock_echo_via_pager, mock_app + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): + cli.echo_via_pager(text) + + mock_echo.assert_called() + mock_echo_via_pager.assert_not_called() + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): + cli.echo_via_pager(text) + + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + + +pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] + + +@pytest.mark.parametrize( + "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids +) +def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): + cli.echo_via_pager(text) + + if use_pager: + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + else: + mock_echo_via_pager.assert_not_called() + mock_echo.assert_called() + + +@pytest.mark.parametrize( + "text,expected_length", + [ + ( + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + 78, + ), + ("=\u001b[m=", 2), + ("-\u001b]23\u0007-", 2), + ], +) +def test_color_pattern(text, expected_length, pset_pager_mocks): + cli = pset_pager_mocks[0] + assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length + + +@dbtest +def test_i_works(tmpdir, executor): + sqlfile = tmpdir.join("test.sql") + sqlfile.write("SELECT NOW()") + rcfile = str(tmpdir.join("rcfile")) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + statement = r"\i {0}".format(sqlfile) + run(executor, statement, pgspecial=cli.pgspecial) + + +@dbtest +def test_watch_works(executor): + cli = PGCli(pgexecute=executor) + + def run_with_watch( + query, target_call_count=1, expected_output="", expected_timing=None + ): + """ + :param query: Input to the CLI + :param target_call_count: Number of times the user lets the command run before Ctrl-C + :param expected_output: Substring expected to be found for each executed query + :param expected_timing: value `time.sleep` expected to be called with on every invocation + """ + with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch( + "pgcli.main.sleep" + ) as mock_sleep: + mock_sleep.side_effect = [None] * (target_call_count - 1) + [ + KeyboardInterrupt + ] + cli.handle_watch_command(query) + # Validate that sleep was called with the right timing + for i in range(target_call_count - 1): + assert mock_sleep.call_args_list[i][0][0] == expected_timing + # Validate that the output of the query was expected + assert mock_echo.call_count == target_call_count + for i in range(target_call_count): + assert expected_output in mock_echo.call_args_list[i][0][0] + + # With no history, it errors. + with mock.patch("pgcli.main.click.secho") as mock_secho: + cli.handle_watch_command(r"\watch 2") + mock_secho.assert_called() + assert ( + r"\watch cannot be used with an empty query" + in mock_secho.call_args_list[0][0][0] + ) + + # Usage 1: Run a query and then re-run it with \watch across two prompts. + run_with_watch("SELECT 111", expected_output="111") + run_with_watch( + "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10 + ) + + # Usage 2: Run a query and \watch via the same prompt. + run_with_watch( + "SELECT 222; \\watch 4", + target_call_count=3, + expected_output="222", + expected_timing=4, + ) + + # Usage 3: Re-run the last watched command with a new timing + run_with_watch( + "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5 + ) + + +def test_missing_rc_dir(tmpdir): + rcfile = str(tmpdir.join("subdir").join("rcfile")) + + PGCli(pgclirc_file=rcfile) + assert os.path.exists(rcfile) + + +def test_quoted_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") + mock_connect.assert_called_with( + database="testdb[", host="baz.com", user="bar^", passwd="]foo" + ) + + +def test_pg_service_file(tmpdir): + + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: + service_conf.write( + """File begins with a comment + that is not a comment + # or maybe a comment after all + because psql is crazy + + [myservice] + host=a_host + user=a_user + port=5433 + password=much_secure + dbname=a_dbname + + [my_other_service] + host=b_host + user=b_user + port=5435 + dbname=b_dbname + """ + ) + os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath + cli.connect_service("myservice", "another_user") + mock_connect.assert_called_with( + database="a_dbname", + host="a_host", + user="another_user", + port="5433", + passwd="much_secure", + ) + + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + os.environ["PGPASSWORD"] = "very_secure" + cli.connect_service("my_other_service", None) + mock_pgexecute.assert_called_with( + "b_dbname", + "b_user", + "very_secure", + "b_host", + "5435", + "", + application_name="pgcli", + ) + del os.environ["PGPASSWORD"] + del os.environ["PGSERVICEFILE"] + + +def test_ssl_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" + "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + ) + mock_connect.assert_called_with( + database="testdb[", + host="baz.com", + user="bar^", + passwd="]foo", + sslmode="verify-full", + sslcert="my.pem", + sslkey="my-key.pem", + sslrootcert="ca.pem", + ) + + +def test_port_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") + mock_connect.assert_called_with( + database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" + ) + + +def test_multihost_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + ) + mock_connect.assert_called_with( + database="testdb", + host="baz1.com,baz2.com,baz3.com", + user="bar", + passwd="foo", + port="2543,2543,2543", + ) + + +def test_application_name_db_uri(tmpdir): + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar@baz.com/?application_name=cow") + mock_pgexecute.assert_called_with( + "bar", "bar", "", "baz.com", "", "", application_name="cow" + ) |