diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 03:06:41 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 03:06:41 +0000 |
commit | 708c091a8b4db6a55be1c96ae33ee0da632b269f (patch) | |
tree | aac9e87c59cb8bc7e3cd429e9200c3ca017cb591 /tests | |
parent | Initial commit. (diff) | |
download | pgcli-708c091a8b4db6a55be1c96ae33ee0da632b269f.tar.xz pgcli-708c091a8b4db6a55be1c96ae33ee0da632b269f.zip |
Adding upstream version 4.0.1.upstream/4.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
53 files changed, 7550 insertions, 0 deletions
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..33cddf2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,52 @@ +import os +import pytest +from utils import ( + POSTGRES_HOST, + POSTGRES_PORT, + POSTGRES_USER, + POSTGRES_PASSWORD, + create_db, + db_connection, + drop_tables, +) +import pgcli.pgexecute + + +@pytest.fixture(scope="function") +def connection(): + create_db("_test_db") + connection = db_connection("_test_db") + yield connection + + drop_tables(connection) + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return pgcli.pgexecute.PGExecute( + database="_test_db", + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dsn=None, + ) + + +@pytest.fixture +def exception_formatter(): + return lambda e: str(e) + + +@pytest.fixture(scope="session", autouse=True) +def temp_config(tmpdir_factory): + # this function runs on start of test session. + # use temporary directory for config home so user config will not be used + os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data")) diff --git a/tests/features/__init__.py b/tests/features/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/features/__init__.py diff --git a/tests/features/auto_vertical.feature b/tests/features/auto_vertical.feature new file mode 100644 index 0000000..aa95718 --- /dev/null +++ b/tests/features/auto_vertical.feature @@ -0,0 +1,12 @@ +Feature: auto_vertical mode: + on, off + + Scenario: auto_vertical on with small query + When we run dbcli with --auto-vertical-output + and we execute a small query + then we see small results in horizontal format + + Scenario: auto_vertical on with large query + When we run dbcli with --auto-vertical-output + and we execute a large query + then we see large results in vertical format diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature new file mode 100644 index 0000000..ee497b9 --- /dev/null +++ b/tests/features/basic_commands.feature @@ -0,0 +1,81 @@ +Feature: run the cli, + call the help command, + exit the cli + + Scenario: run "\?" command + When we send "\?" command + then we see help output + + Scenario: run source command + When we send source command + then we see help output + + Scenario: run partial select command + When we send partial select command + then we see error message + then we see dbcli prompt + + Scenario: check our application_name + When we run query to check application_name + then we see found + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits + + Scenario: confirm exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "c" + then dbcli exits + + Scenario: cancel exit when a transaction is ongoing + When we begin transaction + and we try to send "ctrl + d" + then we see ongoing transaction message + when we send "a" + then we see dbcli prompt + when we rollback transaction + when we send "ctrl + d" + then dbcli exits + + Scenario: interrupt current query via "ctrl + c" + When we send sleep query + and we send "ctrl + c" + then we see cancelled query warning + when we check for any non-idle sleep queries + then we don't see any non-idle sleep queries + + Scenario: list databases + When we list databases + then we see list of databases + + Scenario: run the cli with --username + When we launch dbcli using --username + and we send "\?" command + then we see help output + + Scenario: run the cli with --user + When we launch dbcli using --user + and we send "\?" command + then we see help output + + Scenario: run the cli with --port + When we launch dbcli using --port + and we send "\?" command + then we see help output + + Scenario: run the cli with --password + When we launch dbcli using --password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output + + Scenario: run the cli with dsn and password + When we launch dbcli using dsn_password + then we send password + and we see dbcli prompt + when we send "\?" command + then we see help output diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature new file mode 100644 index 0000000..87da4e3 --- /dev/null +++ b/tests/features/crud_database.feature @@ -0,0 +1,17 @@ +Feature: manipulate databases: + create, drop, connect, disconnect + + Scenario: create and drop temporary database + When we create database + then we see database created + when we drop database + then we respond to the destructive warning: y + then we see database dropped + when we connect to dbserver + then we see database connected + + Scenario: connect and disconnect from test database + When we connect to test database + then we see database connected + when we connect to dbserver + then we see database connected diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature new file mode 100644 index 0000000..8a43c5c --- /dev/null +++ b/tests/features/crud_table.feature @@ -0,0 +1,45 @@ +Feature: manipulate tables: + create, insert, update, select, delete from, drop + + Scenario: create, insert, select from, update, drop table + When we connect to test database + then we see database connected + when we create table + then we see table created + when we insert into table + then we see record inserted + when we select from table + then we see data selected: initial + when we update table + then we see record updated + when we select from table + then we see data selected: updated + when we delete from table + then we respond to the destructive warning: y + then we see record deleted + when we drop table + then we respond to the destructive warning: y + then we see table dropped + when we connect to dbserver + then we see database connected + + Scenario: transaction handling, with cancelling on a destructive warning. + When we connect to test database + then we see database connected + when we create table + then we see table created + when we begin transaction + then we see transaction began + when we insert into table + then we see record inserted + when we delete from table + then we respond to the destructive warning: n + when we select from table + then we see data selected: initial + when we rollback transaction + then we see transaction rolled back + when we select from table + then we see select output without data + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py new file mode 100644 index 0000000..595c6c2 --- /dev/null +++ b/tests/features/db_utils.py @@ -0,0 +1,87 @@ +from psycopg import connect + + +def create_db( + hostname="localhost", username=None, password=None, dbname=None, port=None +): + """Create test database. + + :param hostname: string + :param username: string + :param password: string + :param dbname: string + :param port: int + :return: + + """ + cn = create_cn(hostname, password, username, "postgres", port) + + cn.autocommit = True + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + cr.execute(f"create database {dbname}") + + cn.close() + + cn = create_cn(hostname, password, username, dbname, port) + return cn + + +def create_cn(hostname, password, username, dbname, port): + """ + Open connection to database. + :param hostname: + :param password: + :param username: + :param dbname: string + :return: psycopg2.connection + """ + cn = connect( + host=hostname, user=username, dbname=dbname, password=password, port=port + ) + + print(f"Created connection: {cn.info.get_parameters()}.") + return cn + + +def pgbouncer_available(hostname="localhost", password=None, username="postgres"): + cn = None + try: + cn = create_cn(hostname, password, username, "pgbouncer", 6432) + return True + except: + print("Pgbouncer is not available.") + finally: + if cn: + cn.close() + return False + + +def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None): + """ + Drop database. + :param hostname: string + :param username: string + :param password: string + :param dbname: string + """ + cn = create_cn(hostname, password, username, "postgres", port) + + # Needed for DB drop. + cn.autocommit = True + + with cn.cursor() as cr: + cr.execute(f"drop database if exists {dbname}") + + close_cn(cn) + + +def close_cn(cn=None): + """ + Close connection. + :param connection: psycopg2.connection + """ + if cn: + cn_params = cn.info.get_parameters() + cn.close() + print(f"Closed connection: {cn_params}.") diff --git a/tests/features/environment.py b/tests/features/environment.py new file mode 100644 index 0000000..50ac5fa --- /dev/null +++ b/tests/features/environment.py @@ -0,0 +1,227 @@ +import copy +import os +import sys +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect +import tempfile +import shutil +import signal + + +from steps import wrappers + + +def before_all(context): + """Set env parameters.""" + env_old = copy.deepcopy(dict(os.environ)) + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["PAGER"] = "cat" + os.environ["EDITOR"] = "ex" + os.environ["VISUAL"] = "ex" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") + + print("package root:", context.package_root) + print("fixture dir:", fixture_dir) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join( + context.package_root, ".coveragerc" + ) + + context.exit_sent = False + + vi = "_".join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests") + db_name_full = f"{db_name}_{vi}" + + # Store get params from config. + context.conf = { + "host": context.config.userdata.get( + "pg_test_host", os.getenv("PGHOST", "localhost") + ), + "user": context.config.userdata.get( + "pg_test_user", os.getenv("PGUSER", "postgres") + ), + "pass": context.config.userdata.get( + "pg_test_pass", os.getenv("PGPASSWORD", None) + ), + "port": context.config.userdata.get( + "pg_test_port", os.getenv("PGPORT", "5432") + ), + "cli_command": ( + context.config.userdata.get("pg_cli_command", None) + or '{python} -c "{startup}"'.format( + python=sys.executable, + startup="; ".join( + [ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", + ] + ), + ) + ), + "dbname": db_name_full, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", + } + os.environ["PAGER"] = "{0} {1} {2}".format( + sys.executable, + os.path.join(context.package_root, "tests/features/wrappager.py"), + context.conf["pager_boundary"], + ) + + # Store old env vars. + context.pgenv = { + "PGDATABASE": os.environ.get("PGDATABASE", None), + "PGUSER": os.environ.get("PGUSER", None), + "PGHOST": os.environ.get("PGHOST", None), + "PGPASSWORD": os.environ.get("PGPASSWORD", None), + "PGPORT": os.environ.get("PGPORT", None), + "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None), + "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None), + } + + # Set new env vars. + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGUSER"] = context.conf["user"] + os.environ["PGHOST"] = context.conf["host"] + os.environ["PGPORT"] = context.conf["port"] + os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf") + + if context.conf["pass"]: + os.environ["PGPASSWORD"] = context.conf["pass"] + else: + if "PGPASSWORD" in os.environ: + del os.environ["PGPASSWORD"] + os.environ["BEHAVE_WARN"] = "moderate" + + context.cn = dbutils.create_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + context.pgbouncer_available = dbutils.pgbouncer_available( + hostname=context.conf["host"], + password=context.conf["pass"], + username=context.conf["user"], + ) + context.fixture_data = fixutils.read_fixture_files() + + # use temporary directory as config home + context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_") + os.environ["XDG_CONFIG_HOME"] = context.env_config_home + show_env_changes(env_old, dict(os.environ)) + + +def show_env_changes(env_old, env_new): + """Print out all test-specific env values.""" + print("--- os.environ changed values: ---") + all_keys = env_old.keys() | env_new.keys() + for k in sorted(all_keys): + old_value = env_old.get(k, "") + new_value = env_new.get(k, "") + if new_value and old_value != new_value: + print(f'{k}="{new_value}"') + print("-" * 20) + + +def after_all(context): + """ + Unset env parameters. + """ + dbutils.close_cn(context.cn) + dbutils.drop_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + + # Remove temp config directory + shutil.rmtree(context.env_config_home) + + # Restore env vars. + for k, v in context.pgenv.items(): + if k in os.environ and v is None: + del os.environ[k] + elif v: + os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def is_known_problem(scenario): + """TODO: why is this not working in 3.12?""" + if sys.version_info >= (3, 12): + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + ) + return False + + +def before_scenario(context, scenario): + if scenario.name == "list databases": + # not using the cli for that + return + if is_known_problem(scenario): + scenario.skip() + currentdb = None + if "pgbouncer" in scenario.feature.tags: + if context.pgbouncer_available: + os.environ["PGDATABASE"] = "pgbouncer" + os.environ["PGPORT"] = "6432" + currentdb = "pgbouncer" + else: + scenario.skip() + else: + # set env vars back to normal test database + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGPORT"] = context.conf["port"] + wrappers.run_cli(context, currentdb=currentdb) + wrappers.wait_prompt(context) + + +def after_scenario(context, scenario): + """Cleans up after each scenario completes.""" + if hasattr(context, "cli") and context.cli and not context.exit_sent: + # Quit nicely. + if not getattr(context, "atprompt", False): + dbname = context.currentdb + context.cli.expect_exact(f"{dbname}>", timeout=5) + try: + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") + except Exception as x: + print("Failed cleanup after scenario:") + print(x) + try: + context.cli.expect_exact(pexpect.EOF, timeout=5) + except pexpect.TIMEOUT: + print(f"--- after_scenario {scenario.name}: kill cli") + context.cli.kill(signal.SIGKILL) + if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: + context.tmpfile_sql_help.close() + context.tmpfile_sql_help = None + + +# # TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import pdb; pdb.set_trace() diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature new file mode 100644 index 0000000..e486048 --- /dev/null +++ b/tests/features/expanded.feature @@ -0,0 +1,29 @@ +Feature: expanded mode: + on, off, auto + + Scenario: expanded on + When we prepare the test data + and we set expanded on + and we select from table + then we see expanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded off + When we prepare the test data + and we set expanded off + and we select from table + then we see nonexpanded data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped + + Scenario: expanded auto + When we prepare the test data + and we set expanded auto + and we select from table + then we see auto data selected + when we drop table + then we respond to the destructive warning: y + then we see table dropped diff --git a/tests/features/fixture_data/help.txt b/tests/features/fixture_data/help.txt new file mode 100644 index 0000000..bebb976 --- /dev/null +++ b/tests/features/fixture_data/help.txt @@ -0,0 +1,25 @@ ++--------------------------+------------------------------------------------+ +| Command | Description | +|--------------------------+------------------------------------------------| +| \# | Refresh auto-completions. | +| \? | Show Help. | +| \T [format] | Change the table format used to output results | +| \c[onnect] database_name | Change to a new database. | +| \d [pattern] | List or describe tables, views and sequences. | +| \dT[S+] [pattern] | List data types | +| \df[+] [pattern] | List functions. | +| \di[+] [pattern] | List indexes. | +| \dn[+] [pattern] | List schemas. | +| \ds[+] [pattern] | List sequences. | +| \dt[+] [pattern] | List tables. | +| \du[+] [pattern] | List roles. | +| \dv[+] [pattern] | List views. | +| \e [file] | Edit the query with external editor. | +| \l | List databases. | +| \n[+] [name] | List or execute named queries. | +| \nd [name [query]] | Delete a named query. | +| \ns name query | Save a named query. | +| \refresh | Refresh auto-completions. | +| \timing | Toggle timing of commands. | +| \x | Toggle expanded output. | ++--------------------------+------------------------------------------------+ diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt new file mode 100644 index 0000000..e076661 --- /dev/null +++ b/tests/features/fixture_data/help_commands.txt @@ -0,0 +1,64 @@ +Command +Description +\# +Refresh auto-completions. +\? +Show Commands. +\T [format] +Change the table format used to output results +\c[onnect] database_name +Change to a new database. +\copy [tablename] to/from [filename] +Copy data between a file and a table. +\d[+] [pattern] +List or describe tables, views and sequences. +\dT[S+] [pattern] +List data types +\db[+] [pattern] +List tablespaces. +\df[+] [pattern] +List functions. +\di[+] [pattern] +List indexes. +\dm[+] [pattern] +List materialized views. +\dn[+] [pattern] +List schemas. +\ds[+] [pattern] +List sequences. +\dt[+] [pattern] +List tables. +\du[+] [pattern] +List roles. +\dv[+] [pattern] +List views. +\dx[+] [pattern] +List extensions. +\e [file] +Edit the query with external editor. +\h +Show SQL syntax and help. +\i filename +Execute commands from file. +\l +List databases. +\n[+] [name] [param1 param2 ...] +List or execute named queries. +\nd [name] +Delete a named query. +\ns name query +Save a named query. +\o [filename] +Send all query results to file. +\pager [command] +Set PAGER. Print the query results via PAGER. +\pset [key] [value] +A limited version of traditional \pset +\refresh +Refresh auto-completions. +\sf[+] FUNCNAME +Show a function's definition. +\timing +Toggle timing of commands. +\x +Toggle expanded output. diff --git a/tests/features/fixture_data/mock_pg_service.conf b/tests/features/fixture_data/mock_pg_service.conf new file mode 100644 index 0000000..15f9811 --- /dev/null +++ b/tests/features/fixture_data/mock_pg_service.conf @@ -0,0 +1,4 @@ +[mock_postgres] +dbname=postgres +host=localhost +user=postgres diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py new file mode 100644 index 0000000..70b603d --- /dev/null +++ b/tests/features/fixture_utils.py @@ -0,0 +1,28 @@ +import os +import codecs + + +def read_fixture_lines(filename): + """ + Read lines of text from file. + :param filename: string name + :return: list of strings + """ + lines = [] + for line in codecs.open(filename, "rb", encoding="utf-8"): + lines.append(line.strip()) + return lines + + +def read_fixture_files(): + """Read all files inside fixture_data directory.""" + current_dir = os.path.dirname(__file__) + fixture_dir = os.path.join(current_dir, "fixture_data/") + print(f"reading fixture data: {fixture_dir}") + fixture_dict = {} + for filename in os.listdir(fixture_dir): + if filename not in [".", ".."]: + fullname = os.path.join(fixture_dir, filename) + fixture_dict[filename] = read_fixture_lines(fullname) + + return fixture_dict diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature new file mode 100644 index 0000000..dad7d10 --- /dev/null +++ b/tests/features/iocommands.feature @@ -0,0 +1,17 @@ +Feature: I/O commands + + Scenario: edit sql in file with external editor + When we start external editor providing a file name + and we type sql in the editor + and we exit the editor + then we see dbcli prompt + and we see the sql in prompt + + Scenario: tee output from query + When we tee output + and we wait for prompt + and we query "select 123456" + and we wait for prompt + and we stop teeing output + and we wait for prompt + then we see 123456 in tee output diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature new file mode 100644 index 0000000..74201b9 --- /dev/null +++ b/tests/features/named_queries.feature @@ -0,0 +1,10 @@ +Feature: named queries: + save, use and delete named queries + + Scenario: save, use and delete named queries + When we connect to test database + then we see database connected + when we save a named query + then we see the named query saved + when we delete a named query + then we see the named query deleted diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature new file mode 100644 index 0000000..14cc5ad --- /dev/null +++ b/tests/features/pgbouncer.feature @@ -0,0 +1,12 @@ +@pgbouncer +Feature: run pgbouncer, + call the help command, + exit the cli + + Scenario: run "show help" command + When we send "show help" command + then we see the pgbouncer help output + + Scenario: run the cli and exit + When we send "ctrl + d" + then dbcli exits diff --git a/tests/features/specials.feature b/tests/features/specials.feature new file mode 100644 index 0000000..63c5cdc --- /dev/null +++ b/tests/features/specials.feature @@ -0,0 +1,6 @@ +Feature: Special commands + + Scenario: run refresh command + When we refresh completions + and we wait for prompt + then we see completions refresh started diff --git a/tests/features/steps/__init__.py b/tests/features/steps/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/features/steps/__init__.py diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py new file mode 100644 index 0000000..d7cdccd --- /dev/null +++ b/tests/features/steps/auto_vertical.py @@ -0,0 +1,99 @@ +from textwrap import dedent +from behave import then, when +import wrappers + + +@when("we run dbcli with {arg}") +def step_run_cli_with_arg(context, arg): + wrappers.run_cli(context, run_args=arg.split("=")) + + +@when("we execute a small query") +def step_execute_small_query(context): + context.cli.sendline("select 1") + + +@when("we execute a large query") +def step_execute_large_query(context): + context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)]))) + + +@then("we see small results in horizontal format") +def step_see_small_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + +----------+\r + | ?column? |\r + |----------|\r + | 1 |\r + +----------+\r + SELECT 1\r + """ + ), + timeout=5, + ) + + +@then("we see large results in vertical format") +def step_see_large_results(context): + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + ?column? | 1\r + ?column? | 2\r + ?column? | 3\r + ?column? | 4\r + ?column? | 5\r + ?column? | 6\r + ?column? | 7\r + ?column? | 8\r + ?column? | 9\r + ?column? | 10\r + ?column? | 11\r + ?column? | 12\r + ?column? | 13\r + ?column? | 14\r + ?column? | 15\r + ?column? | 16\r + ?column? | 17\r + ?column? | 18\r + ?column? | 19\r + ?column? | 20\r + ?column? | 21\r + ?column? | 22\r + ?column? | 23\r + ?column? | 24\r + ?column? | 25\r + ?column? | 26\r + ?column? | 27\r + ?column? | 28\r + ?column? | 29\r + ?column? | 30\r + ?column? | 31\r + ?column? | 32\r + ?column? | 33\r + ?column? | 34\r + ?column? | 35\r + ?column? | 36\r + ?column? | 37\r + ?column? | 38\r + ?column? | 39\r + ?column? | 40\r + ?column? | 41\r + ?column? | 42\r + ?column? | 43\r + ?column? | 44\r + ?column? | 45\r + ?column? | 46\r + ?column? | 47\r + ?column? | 48\r + ?column? | 49\r + SELECT 1\r + """ + ), + timeout=5, + ) diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py new file mode 100644 index 0000000..687bdc0 --- /dev/null +++ b/tests/features/steps/basic_commands.py @@ -0,0 +1,231 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +import pexpect +import subprocess +import tempfile + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we list databases") +def step_list_databases(context): + cmd = ["pgcli", "--list"] + context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root) + + +@then("we see list of databases") +def step_see_list_databases(context): + assert b"List of databases" in context.cmd_output + assert b"postgres" in context.cmd_output + context.cmd_output = None + + +@when("we run dbcli") +def step_run_cli(context): + wrappers.run_cli(context) + + +@when("we launch dbcli using {arg}") +def step_run_cli_using_arg(context, arg): + prompt_check = False + currentdb = None + if arg == "--username": + arg = "--username={}".format(context.conf["user"]) + if arg == "--user": + arg = "--user={}".format(context.conf["user"]) + if arg == "--port": + arg = "--port={}".format(context.conf["port"]) + if arg == "--password": + arg = "--password" + prompt_check = False + # This uses the mock_pg_service.conf file in fixtures folder. + if arg == "dsn_password": + arg = "service=mock_postgres --password" + prompt_check = False + currentdb = "postgres" + wrappers.run_cli( + context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb + ) + + +@when("we wait for prompt") +def step_wait_prompt(context): + wrappers.wait_prompt(context) + + +@when('we send "ctrl + d"') +def step_ctrl_d(context): + """ + Send Ctrl + D to hopefully exit. + """ + step_try_to_ctrl_d(context) + context.cli.expect(pexpect.EOF, timeout=5) + context.exit_sent = True + + +@when('we try to send "ctrl + d"') +def step_try_to_ctrl_d(context): + """ + Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is + ongoing). + """ + # turn off pager before exiting + context.cli.sendcontrol("c") + context.cli.sendline(r"\pset pager off") + wrappers.wait_prompt(context) + context.cli.sendcontrol("d") + + +@when('we send "ctrl + c"') +def step_ctrl_c(context): + """Send Ctrl + c to hopefully interrupt.""" + context.cli.sendcontrol("c") + + +@then("we see cancelled query warning") +def step_see_cancelled_query_warning(context): + """ + Make sure we receive the warning that the current query was cancelled. + """ + wrappers.expect_exact(context, "cancelled query", timeout=2) + + +@then("we see ongoing transaction message") +def step_see_ongoing_transaction_error(context): + """ + Make sure we receive the warning that a transaction is ongoing. + """ + context.cli.expect("A transaction is ongoing.", timeout=2) + + +@when("we send sleep query") +def step_send_sleep_15_seconds(context): + """ + Send query to sleep for 15 seconds. + """ + context.cli.sendline("select pg_sleep(15)") + + +@when("we check for any non-idle sleep queries") +def step_check_for_active_sleep_queries(context): + """ + Send query to check for any non-idle pg_sleep queries. + """ + context.cli.sendline( + "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';" + ) + + +@then("we don't see any non-idle sleep queries") +def step_no_active_sleep_queries(context): + """Confirm that any pg_sleep queries are either idle or not active.""" + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +-------+\r + | state |\r + |-------|\r + +-------+\r + SELECT 0\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@when(r'we send "\?" command') +def step_send_help(context): + r""" + Send \? to see help. + """ + context.cli.sendline(r"\?") + + +@when("we send partial select command") +def step_send_partial_select_command(context): + """ + Send `SELECT a` to see completion. + """ + context.cli.sendline("SELECT a") + + +@then("we see error message") +def step_see_error_message(context): + wrappers.expect_exact(context, 'column "a" does not exist', timeout=2) + + +@when("we send source command") +def step_send_source_command(context): + context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_") + context.tmpfile_sql_help.write(rb"\?") + context.tmpfile_sql_help.flush() + context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + + +@when("we run query to check application_name") +def step_check_application_name(context): + context.cli.sendline( + "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;" + ) + + +@then("we see found") +def step_see_found(context): + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + + "\r" + + dedent( + """ + +----------+\r + | ?column? |\r + |----------|\r + | found |\r + +----------+\r + SELECT 1\r + """ + ) + + context.conf["pager_boundary"], + timeout=5, + ) + + +@then("we respond to the destructive warning: {response}") +def step_resppond_to_destructive_command(context, response): + """Respond to destructive command.""" + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline(response.strip()) + + +@then("we send password") +def step_send_password(context): + wrappers.expect_exact(context, "Password for", timeout=5) + context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER") + + +@when('we send "{text}"') +def step_send_text(context, text): + context.cli.sendline(text) + # Try to detect whether we are exiting. If so, set `exit_sent` + # so that `after_scenario` correctly cleans up. + try: + context.cli.expect(pexpect.EOF, timeout=0.2) + except pexpect.TIMEOUT: + pass + else: + context.exit_sent = True diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py new file mode 100644 index 0000000..87cdc85 --- /dev/null +++ b/tests/features/steps/crud_database.py @@ -0,0 +1,93 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" +import pexpect + +from behave import when, then +import wrappers + + +@when("we create database") +def step_db_create(context): + """ + Send create database. + """ + context.cli.sendline("create database {};".format(context.conf["dbname_tmp"])) + + context.response = {"database_name": context.conf["dbname_tmp"]} + + +@when("we drop database") +def step_db_drop(context): + """ + Send drop database. + """ + context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"])) + + +@when("we connect to test database") +def step_db_connect_test(context): + """ + Send connect to database. + """ + db_name = context.conf["dbname"] + context.cli.sendline(f"\\connect {db_name}") + + +@when("we connect to dbserver") +def step_db_connect_dbserver(context): + """ + Send connect to database. + """ + context.cli.sendline("\\connect postgres") + context.currentdb = "postgres" + + +@then("dbcli exits") +def step_wait_exit(context): + """ + Make sure the cli exits. + """ + wrappers.expect_exact(context, pexpect.EOF, timeout=5) + + +@then("we see dbcli prompt") +def step_see_prompt(context): + """ + Wait to see the prompt. + """ + db_name = getattr(context, "currentdb", context.conf["dbname"]) + wrappers.expect_exact(context, f"{db_name}>", timeout=5) + context.atprompt = True + + +@then("we see help output") +def step_see_help(context): + for expected_line in context.fixture_data["help_commands.txt"]: + wrappers.expect_exact(context, expected_line, timeout=2) + + +@then("we see database created") +def step_see_db_created(context): + """ + Wait to see create database output. + """ + wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5) + + +@then("we see database dropped") +def step_see_db_dropped(context): + """ + Wait to see drop database output. + """ + wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2) + + +@then("we see database connected") +def step_see_db_connected(context): + """ + Wait to see drop database output. + """ + wrappers.expect_exact(context, "You are now connected to database", timeout=2) diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py new file mode 100644 index 0000000..27d543e --- /dev/null +++ b/tests/features/steps/crud_table.py @@ -0,0 +1,185 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +INITIAL_DATA = "xxx" +UPDATED_DATA = "yyy" + + +@when("we create table") +def step_create_table(context): + """ + Send create table. + """ + context.cli.sendline("create table a(x text);") + + +@when("we insert into table") +def step_insert_into_table(context): + """ + Send insert into table. + """ + context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""") + + +@when("we update table") +def step_update_table(context): + """ + Send insert into table. + """ + context.cli.sendline( + f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';""" + ) + + +@when("we select from table") +def step_select_from_table(context): + """ + Send select from table. + """ + context.cli.sendline("select * from a;") + + +@when("we delete from table") +def step_delete_from_table(context): + """ + Send deete from table. + """ + context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""") + + +@when("we drop table") +def step_drop_table(context): + """ + Send drop table. + """ + context.cli.sendline("drop table a;") + + +@when("we alter the table") +def step_alter_table(context): + """ + Alter the table by adding a column. + """ + context.cli.sendline("""alter table a add column y varchar;""") + + +@when("we begin transaction") +def step_begin_transaction(context): + """ + Begin transaction + """ + context.cli.sendline("begin;") + + +@when("we rollback transaction") +def step_rollback_transaction(context): + """ + Rollback transaction + """ + context.cli.sendline("rollback;") + + +@then("we see table created") +def step_see_table_created(context): + """ + Wait to see create table output. + """ + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + + +@then("we see record inserted") +def step_see_record_inserted(context): + """ + Wait to see insert output. + """ + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@then("we see record updated") +def step_see_record_updated(context): + """ + Wait to see update output. + """ + wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2) + + +@then("we see data selected: {data}") +def step_see_data_selected(context, data): + """ + Wait to see select output with initial or updated data. + """ + x = UPDATED_DATA if data == "updated" else INITIAL_DATA + wrappers.expect_pager( + context, + dedent( + f"""\ + +-----+\r + | x |\r + |-----|\r + | {x} |\r + +-----+\r + SELECT 1\r + """ + ), + timeout=1, + ) + + +@then("we see select output without data") +def step_see_no_data_selected(context): + """ + Wait to see select output without data. + """ + wrappers.expect_pager( + context, + dedent( + """\ + +---+\r + | x |\r + |---|\r + +---+\r + SELECT 0\r + """ + ), + timeout=1, + ) + + +@then("we see record deleted") +def step_see_data_deleted(context): + """ + Wait to see delete output. + """ + wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2) + + +@then("we see table dropped") +def step_see_table_dropped(context): + """ + Wait to see drop output. + """ + wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2) + + +@then("we see transaction began") +def step_see_transaction_began(context): + """ + Wait to see transaction began. + """ + wrappers.expect_pager(context, "BEGIN\r\n", timeout=2) + + +@then("we see transaction rolled back") +def step_see_transaction_rolled_back(context): + """ + Wait to see transaction rollback. + """ + wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2) diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py new file mode 100644 index 0000000..302cab9 --- /dev/null +++ b/tests/features/steps/expanded.py @@ -0,0 +1,70 @@ +"""Steps for behavioral style tests are defined in this module. + +Each step is defined by the string decorating it. This string is used +to call the step in "*.feature" file. + +""" + +from behave import when, then +from textwrap import dedent +import wrappers + + +@when("we prepare the test data") +def step_prepare_data(context): + """Create table, insert a record.""" + context.cli.sendline("drop table if exists a;") + wrappers.expect_exact( + context, + "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:", + timeout=2, + ) + context.cli.sendline("y") + + wrappers.wait_prompt(context) + context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));") + wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2) + context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""") + wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2) + + +@when("we set expanded {mode}") +def step_set_expanded(context, mode): + """Set expanded to mode.""" + context.cli.sendline("\\" + f"x {mode}") + wrappers.expect_exact(context, "Expanded display is", timeout=2) + wrappers.wait_prompt(context) + + +@then("we see {which} data selected") +def step_see_data(context, which): + """Select data from expanded test table.""" + if which == "expanded": + wrappers.expect_pager( + context, + dedent( + """\ + -[ RECORD 1 ]-------------------------\r + x | 1\r + y | 1.0\r + z | 1.0000\r + SELECT 1\r + """ + ), + timeout=1, + ) + else: + wrappers.expect_pager( + context, + dedent( + """\ + +---+-----+--------+\r + | x | y | z |\r + |---+-----+--------|\r + | 1 | 1.0 | 1.0000 |\r + +---+-----+--------+\r + SELECT 1\r + """ + ), + timeout=1, + ) diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py new file mode 100644 index 0000000..a614490 --- /dev/null +++ b/tests/features/steps/iocommands.py @@ -0,0 +1,80 @@ +import os +import os.path + +from behave import when, then +import wrappers + + +@when("we start external editor providing a file name") +def step_edit_file(context): + """Edit file with external editor.""" + context.editor_file_name = os.path.join( + context.package_root, "test_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name))) + wrappers.expect_exact( + context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2 + ) + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we type sql in the editor") +def step_edit_type_sql(context): + context.cli.sendline("i") + context.cli.sendline("select * from abc") + context.cli.sendline(".") + wrappers.expect_exact(context, ":", timeout=2) + + +@when("we exit the editor") +def step_edit_quit(context): + context.cli.sendline("x") + wrappers.expect_exact(context, "written", timeout=2) + + +@then("we see the sql in prompt") +def step_edit_done_sql(context): + for match in "select * from abc".split(" "): + wrappers.expect_exact(context, match, timeout=1) + # Cleanup the command line. + context.cli.sendcontrol("c") + # Cleanup the edited file. + if context.editor_file_name and os.path.exists(context.editor_file_name): + os.remove(context.editor_file_name) + context.atprompt = True + + +@when("we tee output") +def step_tee_ouptut(context): + context.tee_file_name = os.path.join( + context.package_root, "tee_file_{0}.sql".format(context.conf["vi"]) + ) + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name))) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Writing to file", timeout=5) + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) + wrappers.expect_exact(context, "Time", timeout=5) + + +@when('we query "select 123456"') +def step_query_select_123456(context): + context.cli.sendline("select 123456") + + +@when("we stop teeing output") +def step_notee_output(context): + context.cli.sendline(r"\o") + wrappers.expect_exact(context, "Time", timeout=5) + + +@then("we see 123456 in tee output") +def step_see_123456_in_ouput(context): + with open(context.tee_file_name) as f: + assert "123456" in f.read() + if os.path.exists(context.tee_file_name): + os.remove(context.tee_file_name) + context.atprompt = True diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py new file mode 100644 index 0000000..3f52859 --- /dev/null +++ b/tests/features/steps/named_queries.py @@ -0,0 +1,57 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we save a named query") +def step_save_named_query(context): + """ + Send \ns command + """ + context.cli.sendline("\\ns foo SELECT 12345") + + +@when("we use a named query") +def step_use_named_query(context): + """ + Send \n command + """ + context.cli.sendline("\\n foo") + + +@when("we delete a named query") +def step_delete_named_query(context): + """ + Send \nd command + """ + context.cli.sendline("\\nd foo") + + +@then("we see the named query saved") +def step_see_named_query_saved(context): + """ + Wait to see query saved. + """ + wrappers.expect_exact(context, "Saved.", timeout=2) + + +@then("we see the named query executed") +def step_see_named_query_executed(context): + """ + Wait to see select output. + """ + wrappers.expect_exact(context, "12345", timeout=1) + wrappers.expect_exact(context, "SELECT 1", timeout=1) + + +@then("we see the named query deleted") +def step_see_named_query_deleted(context): + """ + Wait to see query deleted. + """ + wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1) diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py new file mode 100644 index 0000000..f156982 --- /dev/null +++ b/tests/features/steps/pgbouncer.py @@ -0,0 +1,22 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when('we send "show help" command') +def step_send_help_command(context): + context.cli.sendline("show help") + + +@then("we see the pgbouncer help output") +def see_pgbouncer_help(context): + wrappers.expect_exact( + context, + "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", + timeout=3, + ) diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py new file mode 100644 index 0000000..a85f371 --- /dev/null +++ b/tests/features/steps/specials.py @@ -0,0 +1,31 @@ +""" +Steps for behavioral style tests are defined in this module. +Each step is defined by the string decorating it. +This string is used to call the step in "*.feature" file. +""" + +from behave import when, then +import wrappers + + +@when("we refresh completions") +def step_refresh_completions(context): + """ + Send refresh command. + """ + context.cli.sendline("\\refresh") + + +@then("we see completions refresh started") +def step_see_refresh_started(context): + """ + Wait to see refresh output. + """ + wrappers.expect_pager( + context, + [ + "Auto-completion refresh started in the background.\r\n", + "Auto-completion refresh restarted.\r\n", + ], + timeout=2, + ) diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py new file mode 100644 index 0000000..3ebcc92 --- /dev/null +++ b/tests/features/steps/wrappers.py @@ -0,0 +1,71 @@ +import re +import pexpect +from pgcli.main import COLOR_CODE_REGEX +import textwrap + +from io import StringIO + + +def expect_exact(context, expected, timeout): + timedout = False + try: + context.cli.expect_exact(expected, timeout=timeout) + except pexpect.TIMEOUT: + timedout = True + if timedout: + # Strip color codes out of the output. + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) + raise Exception( + textwrap.dedent( + """\ + Expected: + --- + {0!r} + --- + Actual: + --- + {1!r} + --- + Full log: + --- + {2!r} + --- + """ + ).format(expected, actual, context.logfile.getvalue()) + ) + + +def expect_pager(context, expected, timeout): + formatted = expected if isinstance(expected, list) else [expected] + formatted = [ + f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" + for t in formatted + ] + + expect_exact( + context, + formatted, + timeout=timeout, + ) + + +def run_cli(context, run_args=None, prompt_check=True, currentdb=None): + """Run the process using pexpect.""" + run_args = run_args or [] + cli_cmd = context.conf.get("cli_command") + cmd_parts = [cli_cmd] + run_args + cmd = " ".join(cmd_parts) + context.cli = pexpect.spawnu(cmd, cwd=context.package_root) + context.logfile = StringIO() + context.cli.logfile = context.logfile + context.exit_sent = False + context.currentdb = currentdb or context.conf["dbname"] + context.cli.sendline(r"\pset pager always") + if prompt_check: + wait_prompt(context) + + +def wait_prompt(context): + """Make sure prompt is displayed.""" + prompt_str = "{0}>".format(context.currentdb) + expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3) diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py new file mode 100644 index 0000000..51d4909 --- /dev/null +++ b/tests/features/wrappager.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +import sys + + +def wrappager(boundary): + print(boundary) + while 1: + buf = sys.stdin.read(2048) + if not buf: + break + sys.stdout.write(buf) + print(boundary) + + +if __name__ == "__main__": + wrappager(sys.argv[1]) diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/tests/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py new file mode 100644 index 0000000..016ed95 --- /dev/null +++ b/tests/formatter/test_sqlformatter.py @@ -0,0 +1,111 @@ +# coding=utf-8 + +from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement + +from cli_helpers.tabular_output import TabularOutputFormatter +from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter + + +def test_escape_for_sql_statement_bytes(): + bts = b"837124ab3e8dc0f" + escaped_bytes = escape_for_sql_statement(bts) + assert escaped_bytes == "X'383337313234616233653864633066'" + + +def test_escape_for_sql_statement_number(): + num = 2981 + escaped_bytes = escape_for_sql_statement(num) + assert escaped_bytes == "'2981'" + + +def test_escape_for_sql_statement_str(): + example_str = "example str" + escaped_bytes = escape_for_sql_statement(example_str) + assert escaped_bytes == "'example str'" + + +def test_output_sql_insert(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + None, + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-insert" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "<null>", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + expected = [ + 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, " + + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + ";", + ] + assert expected == output_list + + +def test_output_sql_update(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-update" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "<null>", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + print(output_list) + expected = [ + 'UPDATE "user" SET', + " \"name\" = 'Jackson'", + ", \"email\" = 'jackson_test@gmail.com'", + ", \"phone\" = '132454789'", + ", \"description\" = ''", + ", \"created_at\" = '2022-09-09 19:44:32.712343+08'", + ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'", + "WHERE \"id\" = '1';", + ] + assert expected == output_list diff --git a/tests/metadata.py b/tests/metadata.py new file mode 100644 index 0000000..4ebcccd --- /dev/null +++ b/tests/metadata.py @@ -0,0 +1,255 @@ +from functools import partial +from itertools import product +from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from unittest.mock import Mock +import pytest + +parametrize = pytest.mark.parametrize + +qual = ["if_more_than_one_table", "always"] +no_qual = ["if_more_than_one_table", "never"] + + +def escape(name): + if not name.islower() or name in ("select", "localtimestamp"): + return '"' + name + '"' + return name + + +def completion(display_meta, text, pos=0): + return Completion(text, start_position=pos, display_meta=display_meta) + + +def function(text, pos=0, display=None): + return Completion( + text, display=display or text, start_position=pos, display_meta="function" + ) + + +def get_result(completer, text, position=None): + position = len(text) if position is None else position + return completer.get_completions( + Document(text=text, cursor_position=position), Mock() + ) + + +def result_set(completer, text, position=None): + return set(get_result(completer, text, position)) + + +# The code below is quivalent to +# def schema(text, pos=0): +# return completion('schema', text, pos) +# and so on +schema = partial(completion, "schema") +table = partial(completion, "table") +view = partial(completion, "view") +column = partial(completion, "column") +keyword = partial(completion, "keyword") +datatype = partial(completion, "datatype") +alias = partial(completion, "table alias") +name_join = partial(completion, "name join") +fk_join = partial(completion, "fk join") +join = partial(completion, "join") + + +def wildcard_expansion(cols, pos=-1): + return Completion(cols, start_position=pos, display_meta="columns", display="*") + + +class MetaData: + def __init__(self, metadata): + self.metadata = metadata + + def builtin_functions(self, pos=0): + return [function(f, pos) for f in self.completer.functions] + + def builtin_datatypes(self, pos=0): + return [datatype(dt, pos) for dt in self.completer.datatypes] + + def keywords(self, pos=0): + return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()] + + def specials(self, pos=0): + return [ + Completion(text=k, start_position=pos, display_meta=v.description) + for k, v in self.completer.pgspecial.commands.items() + ] + + def columns(self, tbl, parent="public", typ="tables", pos=0): + if typ == "functions": + fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0] + cols = fun[1] + else: + cols = self.metadata[typ][parent][tbl] + return [column(escape(col), pos) for col in cols] + + def datatypes(self, parent="public", pos=0): + return [ + datatype(escape(x), pos) + for x in self.metadata.get("datatypes", {}).get(parent, []) + ] + + def tables(self, parent="public", pos=0): + return [ + table(escape(x), pos) + for x in self.metadata.get("tables", {}).get(parent, []) + ] + + def views(self, parent="public", pos=0): + return [ + view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, []) + ] + + def functions(self, parent="public", pos=0): + return [ + function( + escape(x[0]) + + "(" + + ", ".join( + arg_name + " := " + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + pos, + escape(x[0]) + + "(" + + ", ".join( + arg_name + for (arg_name, arg_mode) in zip(x[1], x[3]) + if arg_mode in ("b", "i") + ) + + ")", + ) + for x in self.metadata.get("functions", {}).get(parent, []) + ] + + def schemas(self, pos=0): + schemas = {sch for schs in self.metadata.values() for sch in schs} + return [schema(escape(s), pos=pos) for s in schemas] + + def functions_and_keywords(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.builtin_functions(pos) + + self.keywords(pos) + ) + + # Note that the filtering parameters here only apply to the columns + def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0): + return self.functions_and_keywords(pos=pos) + self.columns( + tbl, parent, typ, pos + ) + + def from_clause_items(self, parent="public", pos=0): + return ( + self.functions(parent, pos) + + self.views(parent, pos) + + self.tables(parent, pos) + ) + + def schemas_and_from_clause_items(self, parent="public", pos=0): + return self.from_clause_items(parent, pos) + self.schemas(pos) + + def types(self, parent="public", pos=0): + return self.datatypes(parent, pos) + self.tables(parent, pos) + + @property + def completer(self): + return self.get_completer() + + def get_completers(self, casing): + """ + Returns a function taking three bools `casing`, `filtr`, `aliasing` and + the list `qualify`, all defaulting to None. + Returns a list of completers. + These parameters specify the allowed values for the corresponding + completer parameters, `None` meaning any, i.e. (None, None, None, None) + results in all 24 possible completers, whereas e.g. + (True, False, True, ['never']) results in the one completer with + casing, without `search_path` filtering of objects, with table + aliasing, and without column qualification. + """ + + def _cfg(_casing, filtr, aliasing, qualify): + cfg = {"settings": {}} + if _casing: + cfg["casing"] = casing + cfg["settings"]["search_path_filter"] = filtr + cfg["settings"]["generate_aliases"] = aliasing + cfg["settings"]["qualify_columns"] = qualify + return cfg + + def _cfgs(casing, filtr, aliasing, qualify): + casings = [True, False] if casing is None else [casing] + filtrs = [True, False] if filtr is None else [filtr] + aliases = [True, False] if aliasing is None else [aliasing] + qualifys = qualify or ["always", "if_more_than_one_table", "never"] + return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)] + + def completers(casing=None, filtr=None, aliasing=None, qualify=None): + get_comp = self.get_completer + return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)] + + return completers + + def _make_col(self, sch, tbl, col): + defaults = self.metadata.get("defaults", {}).get(sch, {}) + return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col))) + + def get_completer(self, settings=None, casing=None): + metadata = self.metadata + from pgcli.pgcompleter import PGCompleter + from pgspecial import PGSpecial + + comp = PGCompleter( + smart_completion=True, settings=settings, pgspecial=PGSpecial() + ) + + schemata, tables, tbl_cols, views, view_cols = [], [], [], [], [] + + for sch, tbls in metadata["tables"].items(): + schemata.append(sch) + + for tbl, cols in tbls.items(): + tables.append((sch, tbl)) + # Let all columns be text columns + tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + for sch, tbls in metadata.get("views", {}).items(): + for tbl, cols in tbls.items(): + views.append((sch, tbl)) + # Let all columns be text columns + view_cols.extend([self._make_col(sch, tbl, col) for col in cols]) + + functions = [ + FunctionMetadata(sch, *func_meta, arg_defaults=None) + for sch, funcs in metadata["functions"].items() + for func_meta in funcs + ] + + datatypes = [ + (sch, typ) + for sch, datatypes in metadata["datatypes"].items() + for typ in datatypes + ] + + foreignkeys = [ + ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks + ] + + comp.extend_schemata(schemata) + comp.extend_relations(tables, kind="tables") + comp.extend_relations(views, kind="views") + comp.extend_columns(tbl_cols, kind="tables") + comp.extend_columns(view_cols, kind="views") + comp.extend_functions(functions) + comp.extend_datatypes(datatypes) + comp.extend_foreignkeys(foreignkeys) + comp.set_search_path(["public"]) + comp.extend_casing(casing or []) + + return comp diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py new file mode 100644 index 0000000..3e89cca --- /dev/null +++ b/tests/parseutils/test_ctes.py @@ -0,0 +1,137 @@ +import pytest +from sqlparse import parse +from pgcli.packages.parseutils.ctes import ( + token_start_pos, + extract_ctes, + extract_column_names as _extract_column_names, +) + + +def extract_column_names(sql): + p = parse(sql)[0] + return _extract_column_names(p) + + +def test_token_str_pos(): + sql = "SELECT * FROM xxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ") + + sql = "SELECT * FROM \nxxx" + p = parse(sql)[0] + idx = p.token_index(p.tokens[-1]) + assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n") + + +def test_single_column_name_extraction(): + sql = "SELECT abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_aliased_single_column_name_extraction(): + sql = "SELECT abc def FROM xxx" + assert extract_column_names(sql) == ("def",) + + +def test_aliased_expression_name_extraction(): + sql = "SELECT 99 abc FROM xxx" + assert extract_column_names(sql) == ("abc",) + + +def test_multiple_column_name_extraction(): + sql = "SELECT abc, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_missing_column_name_handled_gracefully(): + sql = "SELECT abc, 99 FROM xxx" + assert extract_column_names(sql) == ("abc",) + + sql = "SELECT abc, 99, def FROM xxx" + assert extract_column_names(sql) == ("abc", "def") + + +def test_aliased_multiple_column_name_extraction(): + sql = "SELECT abc def, ghi jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +def test_table_qualified_column_name_extraction(): + sql = "SELECT abc.def, ghi.jkl FROM xxx" + assert extract_column_names(sql) == ("def", "jkl") + + +@pytest.mark.parametrize( + "sql", + [ + "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y", + "DELETE FROM foo WHERE x > y RETURNING x, y", + "UPDATE foo SET x = 9 RETURNING x, y", + ], +) +def test_extract_column_names_from_returning_clause(sql): + assert extract_column_names(sql) == ("x", "y") + + +def test_simple_cte_extraction(): + sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a" + start_pos = len("WITH a AS ") + stop_pos = len("WITH a AS (SELECT abc FROM xxx)") + ctes, remainder = extract_ctes(sql) + + assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_cte_extraction_around_comments(): + sql = """--blah blah blah + WITH a AS (SELECT abc def FROM x) + SELECT * FROM a""" + start_pos = len( + """--blah blah blah + WITH a AS """ + ) + stop_pos = len( + """--blah blah blah + WITH a AS (SELECT abc def FROM x)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),) + assert remainder.strip() == "SELECT * FROM a" + + +def test_multiple_cte_extraction(): + sql = """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y) + SELECT * FROM a, b""" + + start1 = len( + """WITH + x AS """ + ) + + stop1 = len( + """WITH + x AS (SELECT abc, def FROM x)""" + ) + + start2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS """ + ) + + stop2 = len( + """WITH + x AS (SELECT abc, def FROM x), + y AS (SELECT ghi, jkl FROM y)""" + ) + + ctes, remainder = extract_ctes(sql) + assert tuple(ctes) == ( + ("x", ("abc", "def"), start1, stop1), + ("y", ("ghi", "jkl"), start2, stop2), + ) diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py new file mode 100644 index 0000000..0350e2a --- /dev/null +++ b/tests/parseutils/test_function_metadata.py @@ -0,0 +1,19 @@ +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def test_function_metadata_eq(): + f1 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f2 = FunctionMetadata( + "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + f3 = FunctionMetadata( + "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None + ) + assert f1 == f2 + assert f1 != f3 + assert not (f1 != f2) + assert not (f1 == f3) + assert hash(f1) == hash(f2) + assert hash(f1) != hash(f3) diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py new file mode 100644 index 0000000..349cbd0 --- /dev/null +++ b/tests/parseutils/test_parseutils.py @@ -0,0 +1,310 @@ +import pytest +from pgcli.packages.parseutils import ( + is_destructive, + parse_destructive_warning, + BASE_KEYWORDS, + ALL_KEYWORDS, +) +from pgcli.packages.parseutils.tables import extract_tables +from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote + + +def test_empty_string(): + tables = extract_tables("") + assert tables == () + + +def test_simple_select_single_table(): + tables = extract_tables("select * from abc") + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize( + "sql", ['select * from "abc"."def"', 'select * from abc."def"'] +) +def test_simple_select_single_table_schema_qualified_quoted_table(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", '"def"', False),) + + +@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def']) +def test_simple_select_single_table_schema_qualified(sql): + tables = extract_tables(sql) + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_single_table_double_quoted(): + tables = extract_tables('select * from "Abc"') + assert tables == ((None, "Abc", None, False),) + + +def test_simple_select_multiple_tables(): + tables = extract_tables("select * from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_multiple_tables_double_quoted(): + tables = extract_tables('select * from "Abc", "Def"') + assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)} + + +def test_simple_select_single_table_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a') + assert tables == ((None, "Abc", "a", False),) + + +def test_simple_select_multiple_tables_deouble_quoted_aliased(): + tables = extract_tables('select * from "Abc" a, "Def" d') + assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)} + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables("select * from abc.def, ghi.jkl") + assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)} + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables("select a,b from abc") + assert tables == ((None, "abc", None, False),) + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables("select a,b from abc.def") + assert tables == (("abc", "def", None, False),) + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables("select a,b from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_simple_select_with_cols_multiple_qualified_tables(): + tables = extract_tables("select a,b from abc.def, def.ghi") + assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)} + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables("select a, from abc") + assert tables == ((None, "abc", None, False),) + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables("select a, from abc, def") + assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)} + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2") + assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)} + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # AND mistakenly identifies the field list as + # assert tables == ((None, 'abc', 'abc', False),) + + assert tables == ((None, "abc", "abc", False),) + + +@pytest.mark.xfail +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == (("abc", "def", None, False),) + + +def test_simple_update_table_no_schema(): + tables = extract_tables("update abc set id = 1") + assert tables == ((None, "abc", None, False),) + + +def test_simple_update_table_with_schema(): + tables = extract_tables("update abc.def set id = 1") + assert tables == (("abc", "def", None, False),) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +def test_join_table(join_type): + sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num" + tables = extract_tables(sql) + assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)} + + +def test_join_table_schema_qualified(): + tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num") + assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)} + + +def test_incomplete_join_clause(): + sql = """select a.x, b.y + from abc a join bcd b + on a.id = """ + tables = extract_tables(sql) + assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False)) + + +def test_join_as_table(): + tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == ((None, "my_table", "m", False),) + + +def test_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + tables = extract_tables(sql) + assert tables == ( + (None, "t1", None, False), + (None, "t2", None, False), + (None, "t3", None, False), + ) + + +def test_subselect_tables(): + sql = "SELECT * FROM (SELECT FROM abc" + tables = extract_tables(sql) + assert tables == ((None, "abc", None, False),) + + +@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"]) +def test_extract_no_tables(text): + tables = extract_tables(text) + assert tables == tuple() + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list})") + assert tables == ((None, "foo", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_schema_qualified_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})") + assert tables == (("foo", "bar", None, True),) + + +@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"]) +def test_simple_aliased_function_as_table(arg_list): + tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar") + assert tables == ((None, "foo", "bar", True),) + + +def test_simple_table_and_function(): + tables = extract_tables("SELECT * FROM foo JOIN bar()") + assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)} + + +def test_complex_table_and_function(): + tables = extract_tables( + """SELECT * FROM foo.bar baz + JOIN bar.qux(x, y, z) quux""" + ) + assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)} + + +def test_find_prev_keyword_using(): + q = "select * from tbl1 inner join tbl2 using (col1, " + kw, q2 = find_prev_keyword(q) + assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using (" + + +@pytest.mark.parametrize( + "sql", + [ + "select * from foo where bar", + "select * from foo where bar = 1 and baz or ", + "select * from foo where bar = 1 and baz between qux and ", + ], +) +def test_find_prev_keyword_where(sql): + kw, stripped = find_prev_keyword(sql) + assert kw.value == "where" and stripped == "select * from foo where" + + +@pytest.mark.parametrize( + "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "] +) +def test_find_prev_keyword_open_parens(sql): + kw, _ = find_prev_keyword(sql) + assert kw.value == "(" + + +@pytest.mark.parametrize( + "sql", + [ + "", + "$$ foo $$", + "$$ 'foo' $$", + '$$ "foo" $$', + "$$ $a$ $$", + "$a$ $$ $a$", + "foo bar $$ baz $$", + ], +) +def test_is_open_quote__closed(sql): + assert not is_open_quote(sql) + + +@pytest.mark.parametrize( + "sql", + [ + "$$", + ";;;$$", + "foo $$ bar $$; foo $$", + "$$ foo $a$", + "foo 'bar baz", + "$a$ foo ", + '$$ "foo" ', + "$$ $a$ ", + "foo bar $$ baz", + ], +) +def test_is_open_quote__open(sql): + assert is_open_quote(sql) + + +@pytest.mark.parametrize( + ("sql", "keywords", "expected"), + [ + ("update abc set x = 1", ALL_KEYWORDS, True), + ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True), + ("update abc set x = 1", BASE_KEYWORDS, True), + ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False), + ("select x, y, z from abc", ALL_KEYWORDS, False), + ("drop abc", ALL_KEYWORDS, True), + ("alter abc", ALL_KEYWORDS, True), + ("delete abc", ALL_KEYWORDS, True), + ("truncate abc", ALL_KEYWORDS, True), + ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ("insert into abc values (1, 2, 3)", ["insert"], True), + ], +) +def test_is_destructive(sql, keywords, expected): + assert is_destructive(sql, keywords) == expected + + +@pytest.mark.parametrize( + ("warning_level", "expected"), + [ + ("true", ALL_KEYWORDS), + ("false", []), + ("all", ALL_KEYWORDS), + ("moderate", BASE_KEYWORDS), + ("off", []), + ("", []), + (None, []), + (ALL_KEYWORDS, ALL_KEYWORDS), + (BASE_KEYWORDS, BASE_KEYWORDS), + ("insert", ["insert"]), + ("drop,alter,delete", ["drop", "alter", "delete"]), + (["drop", "alter", "delete"], ["drop", "alter", "delete"]), + ], +) +def test_parse_destructive_warning(warning_level, expected): + assert parse_destructive_warning(warning_level) == expected diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 0000000..f787740 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts=--capture=sys --showlocals
\ No newline at end of file diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..a517a89 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,40 @@ +import pytest +from unittest import mock +from pgcli import auth + + +@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)]) +def test_keyring_initialize(enabled, call_count): + logger = mock.MagicMock() + + with mock.patch("importlib.import_module", return_value=True) as import_method: + auth.keyring_initialize(enabled, logger=logger) + assert import_method.call_count == call_count + + +def test_keyring_get_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"): + assert auth.keyring_get_password("test") == "abc123" + + +def test_keyring_get_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.get_password", side_effect=Exception("Boom!") + ): + assert auth.keyring_get_password("test") == "" + + +def test_keyring_set_password_ok(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch("pgcli.auth.keyring.set_password"): + auth.keyring_set_password("test", "abc123") + + +def test_keyring_set_password_exception(): + with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()): + with mock.patch( + "pgcli.auth.keyring.set_password", side_effect=Exception("Boom!") + ): + auth.keyring_set_password("test", "abc123") diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py new file mode 100644 index 0000000..a5529d6 --- /dev/null +++ b/tests/test_completion_refresher.py @@ -0,0 +1,95 @@ +import time +import pytest +from unittest.mock import Mock, patch + + +@pytest.fixture +def refresher(): + from pgcli.completion_refresher import CompletionRefresher + + return CompletionRefresher() + + +def test_ctor(refresher): + """ + Refresher object should contain a few handlers + :param refresher: + :return: + """ + assert len(refresher.refreshers) > 0 + actual_handlers = list(refresher.refreshers.keys()) + expected_handlers = [ + "schemata", + "tables", + "views", + "types", + "databases", + "casing", + "functions", + ] + assert expected_handlers == actual_handlers + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + with patch.object(refresher, "_bg_refresh") as bg_refresh: + actual = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual) == 1 + assert len(actual[0]) == 4 + assert actual[0][3] == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None) + + +def test_refresh_called_twice(refresher): + """ + If refresh is called a second time, it should be restarted + :param refresher: + :return: + """ + callbacks = Mock() + + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + special = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual1) == 1 + assert len(actual1[0]) == 4 + assert actual1[0][3] == "Auto-completion refresh started in the background." + + actual2 = refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert len(actual2) == 1 + assert len(actual2[0]) == 4 + assert actual2[0][3] == "Auto-completion refresh restarted." + + +def test_refresh_with_callbacks(refresher): + """ + Callbacks must be called + :param refresher: + """ + callbacks = [Mock()] + pgexecute = Mock(**{"is_virtual_database.return_value": False}) + pgexecute.extra_args = {} + special = Mock() + + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(pgexecute, special, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..08fe74e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,43 @@ +import io +import os +import stat + +import pytest + +from pgcli.config import ensure_dir_exists, skip_initial_comment + + +def test_ensure_file_parent(tmpdir): + subdir = tmpdir.join("subdir") + rcfile = subdir.join("rcfile") + ensure_dir_exists(str(rcfile)) + + +def test_ensure_existing_dir(tmpdir): + rcfile = str(tmpdir.mkdir("subdir").join("rcfile")) + + # should just not raise + ensure_dir_exists(rcfile) + + +def test_ensure_other_create_error(tmpdir): + subdir = tmpdir.join('subdir"') + rcfile = subdir.join("rcfile") + + # trigger an oserror that isn't "directory already exists" + os.chmod(str(tmpdir), stat.S_IREAD) + + with pytest.raises(OSError): + ensure_dir_exists(str(rcfile)) + + +@pytest.mark.parametrize( + "text, skipped_lines", + ( + ("abc\n", 1), + ("#[section]\ndef\n[section]", 2), + ("[section]", 0), + ), +) +def test_skip_initial_comment(text, skipped_lines): + assert skip_initial_comment(io.StringIO(text)) == skipped_lines diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py new file mode 100644 index 0000000..8f8f2cd --- /dev/null +++ b/tests/test_fuzzy_completion.py @@ -0,0 +1,87 @@ +import pytest + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter() + + +def test_ranking_ignores_identifier_quotes(completer): + """When calculating result rank, identifier quotes should be ignored. + + The result ranking algorithm ignores identifier quotes. Without this + correction, the match "user", which Postgres requires to be quoted + since it is also a reserved word, would incorrectly fall below the + match user_action because the literal quotation marks in "user" + alter the position of the match. + + This test checks that the fuzzy ranking algorithm correctly ignores + quotation marks when computing match ranks. + + """ + + text = "user" + collection = ["user_action", '"user"'] + matches = completer.find_matches(text, collection) + assert len(matches) == 2 + + +def test_ranking_based_on_shortest_match(completer): + """Fuzzy result rank should be based on shortest match. + + Result ranking in fuzzy searching is partially based on the length + of matches: shorter matches are considered more relevant than + longer ones. When searching for the text 'user', the length + component of the match 'user_group' could be either 4 ('user') or + 7 ('user_gr'). + + This test checks that the fuzzy ranking algorithm uses the shorter + match when calculating result rank. + + """ + + text = "user" + collection = ["api_user", "user_group"] + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +@pytest.mark.parametrize( + "collection", + [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]], +) +def test_should_break_ties_using_lexical_order(completer, collection): + """Fuzzy result rank should use lexical order to break ties. + + When fuzzy matching, if multiple matches have the same match length and + start position, present them in lexical (rather than arbitrary) order. For + example, if we have tables 'user', 'user_action', and 'user_group', a + search for the text 'user' should present these tables in this order. + + The input collections to this test are out of order; each run checks that + the search text 'user' results in the input tables being reordered + lexically. + + """ + + text = "user" + matches = completer.find_matches(text, collection) + + assert matches[1].priority > matches[0].priority + + +def test_matching_should_be_case_insensitive(completer): + """Fuzzy matching should keep matches even if letter casing doesn't match. + + This test checks that variations of the text which have different casing + are still matched. + """ + + text = "foo" + collection = ["Foo", "FOO", "fOO"] + matches = completer.find_matches(text, collection) + + assert len(matches) == 3 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..cbf20a6 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,490 @@ +import os +import platform +from unittest import mock + +import pytest + +try: + import setproctitle +except ImportError: + setproctitle = None + +from pgcli.main import ( + obfuscate_process_password, + format_output, + PGCli, + OutputSettings, + COLOR_CODE_REGEX, +) +from pgcli.pgexecute import PGExecute +from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS +from utils import dbtest, run +from collections import namedtuple + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows") +@pytest.mark.skipif(not setproctitle, reason="setproctitle not available") +def test_obfuscate_process_password(): + original_title = setproctitle.getproctitle() + + setproctitle.setproctitle("pgcli user=root password=secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret host=localhost") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx host=localhost" + assert title == expected + + setproctitle.setproctitle("pgcli user=root password=top secret") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli user=root password=xxxx" + assert title == expected + + setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db") + obfuscate_process_password() + title = setproctitle.getproctitle() + expected = "pgcli postgres://root:xxxx@localhost/db" + assert title == expected + + setproctitle.setproctitle(original_title) + + +def test_format_output(): + settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g") + results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + expected = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(results) == expected + + +def test_format_output_truncate_on(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10 + ) + results = format_output( + None, + [("first field value", "second field value")], + ["head1", "head2"], + None, + settings, + ) + expected = [ + "+------------+------------+", + "| head1 | head2 |", + "|------------+------------|", + "| first f... | second ... |", + "+------------+------------+", + ] + assert list(results) == expected + + +def test_format_output_truncate_off(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None + ) + long_field_value = ("first field " * 100).strip() + results = format_output(None, [(long_field_value,)], ["head1"], None, settings) + lines = list(results) + assert lines[3] == f"| {long_field_value} |" + + +@dbtest +def test_format_array_output(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement) + expected = [ + "+--------------+----------------------+--------------+", + "| bigint_array | nested_numeric_array | 配列 |", + "|--------------+----------------------+--------------|", + "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |", + "| {} | <null> | {<null>} |", + "+--------------+----------------------+--------------+", + "SELECT 2", + ] + assert list(results) == expected + + +@dbtest +def test_format_array_output_expanded(executor): + statement = """ + SELECT + array[1, 2, 3]::bigint[] as bigint_array, + '{{1,2},{3,4}}'::numeric[] as nested_numeric_array, + '{å,魚,текст}'::text[] as 配列 + UNION ALL + SELECT '{}', NULL, array[NULL] + """ + results = run(executor, statement, expanded=True) + expected = [ + "-[ RECORD 1 ]-------------------------", + "bigint_array | {1,2,3}", + "nested_numeric_array | {{1,2},{3,4}}", + "配列 | {å,魚,текст}", + "-[ RECORD 2 ]-------------------------", + "bigint_array | {}", + "nested_numeric_array | <null>", + "配列 | {<null>}", + "SELECT 2", + ] + assert "\n".join(results) == "\n".join(expected) + + +def test_format_output_auto_expand(): + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100 + ) + table_results = format_output( + "Title", [("abc", "def")], ["head1", "head2"], "test status", settings + ) + table = [ + "Title", + "+-------+-------+", + "| head1 | head2 |", + "|-------+-------|", + "| abc | def |", + "+-------+-------+", + "test status", + ] + assert list(table_results) == table + expanded_results = format_output( + "Title", + [("abc", "def")], + ["head1", "head2"], + "test status", + settings._replace(max_width=1), + ) + expanded = [ + "Title", + "-[ RECORD 1 ]-------------------------", + "head1 | abc", + "head2 | def", + "test status", + ] + assert "\n".join(expanded_results) == "\n".join(expanded) + + +termsize = namedtuple("termsize", ["rows", "columns"]) +test_line = "-" * 10 +test_data = [ + (10, 10, "\n".join([test_line] * 7)), + (10, 10, "\n".join([test_line] * 6)), + (10, 10, "\n".join([test_line] * 5)), + (10, 10, "-" * 11), + (10, 10, "-" * 10), + (10, 10, "-" * 9), +] + +# 4 lines are reserved at the bottom of the terminal for pgcli's prompt +use_pager_when_on = [True, True, False, True, False, False] + +# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL +test_ids = [ + "Output longer than terminal height", + "Output equal to terminal height", + "Output shorter than terminal height", + "Output longer than terminal width", + "Output equal to terminal width", + "Output shorter than terminal width", +] + + +@pytest.fixture +def pset_pager_mocks(): + cli = PGCli() + cli.watch_command = None + with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch( + "pgcli.main.click.echo_via_pager" + ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app: + yield cli, mock_echo, mock_echo_via_pager, mock_app + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF): + cli.echo_via_pager(text) + + mock_echo.assert_called() + mock_echo_via_pager.assert_not_called() + + +@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids) +def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS): + cli.echo_via_pager(text) + + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + + +pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)] + + +@pytest.mark.parametrize( + "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids +) +def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks): + cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks + mock_cli.output.get_size.return_value = termsize( + rows=term_height, columns=term_width + ) + + with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT): + cli.echo_via_pager(text) + + if use_pager: + mock_echo.assert_not_called() + mock_echo_via_pager.assert_called() + else: + mock_echo_via_pager.assert_not_called() + mock_echo.assert_called() + + +@pytest.mark.parametrize( + "text,expected_length", + [ + ( + "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", + 78, + ), + ("=\u001b[m=", 2), + ("-\u001b]23\u0007-", 2), + ], +) +def test_color_pattern(text, expected_length, pset_pager_mocks): + cli = pset_pager_mocks[0] + assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length + + +@dbtest +def test_i_works(tmpdir, executor): + sqlfile = tmpdir.join("test.sql") + sqlfile.write("SELECT NOW()") + rcfile = str(tmpdir.join("rcfile")) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + statement = r"\i {0}".format(sqlfile) + run(executor, statement, pgspecial=cli.pgspecial) + + +@dbtest +def test_echo_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\echo asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_qecho_works(executor): + cli = PGCli(pgexecute=executor) + statement = r"\qecho asdf" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result == ["asdf"] + + +@dbtest +def test_watch_works(executor): + cli = PGCli(pgexecute=executor) + + def run_with_watch( + query, target_call_count=1, expected_output="", expected_timing=None + ): + """ + :param query: Input to the CLI + :param target_call_count: Number of times the user lets the command run before Ctrl-C + :param expected_output: Substring expected to be found for each executed query + :param expected_timing: value `time.sleep` expected to be called with on every invocation + """ + with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch( + "pgcli.main.sleep" + ) as mock_sleep: + mock_sleep.side_effect = [None] * (target_call_count - 1) + [ + KeyboardInterrupt + ] + cli.handle_watch_command(query) + # Validate that sleep was called with the right timing + for i in range(target_call_count - 1): + assert mock_sleep.call_args_list[i][0][0] == expected_timing + # Validate that the output of the query was expected + assert mock_echo.call_count == target_call_count + for i in range(target_call_count): + assert expected_output in mock_echo.call_args_list[i][0][0] + + # With no history, it errors. + with mock.patch("pgcli.main.click.secho") as mock_secho: + cli.handle_watch_command(r"\watch 2") + mock_secho.assert_called() + assert ( + r"\watch cannot be used with an empty query" + in mock_secho.call_args_list[0][0][0] + ) + + # Usage 1: Run a query and then re-run it with \watch across two prompts. + run_with_watch("SELECT 111", expected_output="111") + run_with_watch( + "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10 + ) + + # Usage 2: Run a query and \watch via the same prompt. + run_with_watch( + "SELECT 222; \\watch 4", + target_call_count=3, + expected_output="222", + expected_timing=4, + ) + + # Usage 3: Re-run the last watched command with a new timing + run_with_watch( + "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5 + ) + + +def test_missing_rc_dir(tmpdir): + rcfile = str(tmpdir.join("subdir").join("rcfile")) + + PGCli(pgclirc_file=rcfile) + assert os.path.exists(rcfile) + + +def test_quoted_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") + mock_connect.assert_called_with( + database="testdb[", host="baz.com", user="bar^", passwd="]foo" + ) + + +def test_pg_service_file(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf: + service_conf.write( + """File begins with a comment + that is not a comment + # or maybe a comment after all + because psql is crazy + + [myservice] + host=a_host + user=a_user + port=5433 + password=much_secure + dbname=a_dbname + + [my_other_service] + host=b_host + user=b_user + port=5435 + dbname=b_dbname + """ + ) + os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath + cli.connect_service("myservice", "another_user") + mock_connect.assert_called_with( + database="a_dbname", + host="a_host", + user="another_user", + port="5433", + passwd="much_secure", + ) + + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + os.environ["PGPASSWORD"] = "very_secure" + cli.connect_service("my_other_service", None) + mock_pgexecute.assert_called_with( + "b_dbname", + "b_user", + "very_secure", + "b_host", + "5435", + "", + application_name="pgcli", + ) + del os.environ["PGPASSWORD"] + del os.environ["PGSERVICEFILE"] + + +def test_ssl_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?" + "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + ) + mock_connect.assert_called_with( + database="testdb[", + host="baz.com", + user="bar^", + passwd="]foo", + sslmode="verify-full", + sslcert="my.pem", + sslkey="my-key.pem", + sslrootcert="ca.pem", + ) + + +def test_port_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") + mock_connect.assert_called_with( + database="testdb", host="baz.com", user="bar", passwd="foo", port="2543" + ) + + +def test_multihost_db_uri(tmpdir): + with mock.patch.object(PGCli, "connect") as mock_connect: + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri( + "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + ) + mock_connect.assert_called_with( + database="testdb", + host="baz1.com,baz2.com,baz3.com", + user="bar", + passwd="foo", + port="2543,2543,2543", + ) + + +def test_application_name_db_uri(tmpdir): + with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: + mock_pgexecute.return_value = None + cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) + cli.connect_uri("postgres://bar@baz.com/?application_name=cow") + mock_pgexecute.assert_called_with( + "bar", "bar", "", "baz.com", "", "", application_name="cow" + ) diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py new file mode 100644 index 0000000..5b93661 --- /dev/null +++ b/tests/test_naive_completion.py @@ -0,0 +1,133 @@ +import pytest +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +from utils import completions_to_set + + +@pytest.fixture +def completer(): + import pgcli.pgcompleter as pgcompleter + + return pgcompleter.PGCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set([Completion(text="SELECT", start_position=-3)]) + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set( + [ + Completion(text="MATERIALIZED VIEW", start_position=-2), + Completion(text="MAX", start_position=-2), + Completion(text="MAXEXTENTS", start_position=-2), + Completion(text="MAKE_DATE", start_position=-2), + Completion(text="MAKE_TIME", start_position=-2), + Completion(text="MAKE_TIMESTAMPTZ", start_position=-2), + Completion(text="MAKE_INTERVAL", start_position=-2), + Completion(text="MASKLEN", start_position=-2), + Completion(text="MAKE_TIMESTAMP", start_position=-2), + ] + ) + + +def test_column_name_completion(completer, complete_event): + text = "SELECT FROM users" + position = len("SELECT ") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + assert result == completions_to_set(map(Completion, completer.all_completions)) + + +def test_alter_well_known_keywords_completion(completer, complete_event): + text = "ALTER " + position = len(text) + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result > completions_to_set( + [ + Completion(text="DATABASE", display_meta="keyword"), + Completion(text="TABLE", display_meta="keyword"), + Completion(text="SYSTEM", display_meta="keyword"), + ] + ) + assert ( + completions_to_set([Completion(text="CREATE", display_meta="keyword")]) + not in result + ) + + +def test_special_name_completion(completer, complete_event): + text = "\\" + position = len("\\") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), complete_event + ) + ) + # Special commands will NOT be suggested during naive completion mode. + assert result == completions_to_set([]) + + +def test_datatype_name_completion(completer, complete_event): + text = "SELECT price::IN" + position = len("SELECT price::IN") + result = completions_to_set( + completer.get_completions( + Document(text=text, cursor_position=position), + complete_event, + smart_completion=True, + ) + ) + assert result == completions_to_set( + [ + Completion(text="INET", display_meta="datatype"), + Completion(text="INT", display_meta="datatype"), + Completion(text="INT2", display_meta="datatype"), + Completion(text="INT4", display_meta="datatype"), + Completion(text="INT8", display_meta="datatype"), + Completion(text="INTEGER", display_meta="datatype"), + Completion(text="INTERNAL", display_meta="datatype"), + Completion(text="INTERVAL", display_meta="datatype"), + ] + ) diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py new file mode 100644 index 0000000..909fa0b --- /dev/null +++ b/tests/test_pgcompleter.py @@ -0,0 +1,76 @@ +import pytest +from pgcli import pgcompleter + + +def test_load_alias_map_file_missing_file(): + with pytest.raises( + pgcompleter.InvalidMapFile, + match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$", + ): + pgcompleter.load_alias_map_file("/path/to/non-existent/file.json") + + +def test_load_alias_map_file_invalid_json(tmp_path): + fpath = tmp_path / "foo.json" + fpath.write_text("this is not valid json") + with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"): + pgcompleter.load_alias_map_file(str(fpath)) + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("SomE_Table", "SET"), + ("SOmeTabLe", "SOTL"), + ("someTable", "T"), + ], +) +def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("some_tab_le", "stl"), + ("s_ome_table", "sot"), + ("sometable", "s"), + ], +) +def test_generate_alias_uses_first_char_and_every_preceded_by_underscore( + table_name, alias +): + assert pgcompleter.generate_alias(table_name) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("some_table", {"some_table": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_can_use_alias_map(table_name, alias_map, alias): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias_map, alias", + [ + ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"), + ], +) +def test_generate_alias_prefers_alias_over_upper_case_name( + table_name, alias_map, alias +): + assert pgcompleter.generate_alias(table_name, alias_map) == alias + + +@pytest.mark.parametrize( + "table_name, alias", + [ + ("Some_tablE", "SE"), + ("SomeTab_le", "ST"), + ], +) +def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias): + assert pgcompleter.generate_alias(table_name) == alias diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py new file mode 100644 index 0000000..636795b --- /dev/null +++ b/tests/test_pgexecute.py @@ -0,0 +1,773 @@ +from textwrap import dedent + +import psycopg +import pytest +from unittest.mock import patch, MagicMock +from pgspecial.main import PGSpecial, NO_QUERY +from utils import run, dbtest, requires_json, requires_jsonb + +from pgcli.main import PGCli +from pgcli.packages.parseutils.meta import FunctionMetadata + + +def function_meta_data( + func_name, + schema_name="public", + arg_names=None, + arg_types=None, + arg_modes=None, + return_type=None, + is_aggregate=False, + is_window=False, + is_set_returning=False, + is_extension=False, + arg_defaults=None, +): + return FunctionMetadata( + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ) + + +@dbtest +def test_conn(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_copy(executor): + executor_copy = executor.copy() + run(executor_copy, """create table test(a text)""") + run(executor_copy, """insert into test values('abc')""") + assert run(executor_copy, """select * from test""", join=True) == dedent( + """\ + +-----+ + | a | + |-----| + | abc | + +-----+ + SELECT 1""" + ) + + +@dbtest +def test_bools_are_treated_as_strings(executor): + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + assert run(executor, """select * from test""", join=True) == dedent( + """\ + +------+ + | a | + |------| + | True | + +------+ + SELECT 1""" + ) + + +@dbtest +def test_expanded_slash_G(executor, pgspecial): + # Tests whether we reset the expanded output after a \G. + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, r"""select * from test \G""", pgspecial=pgspecial) + assert pgspecial.expanded_output == False + + +@dbtest +def test_schemata_table_views_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + run(executor, "create view d as select 1 as e") + run(executor, "create schema schema1") + run(executor, "create table schema1.c (w text DEFAULT 'meow')") + run(executor, "create schema schema2") + + # schemata + # don't enforce all members of the schemas since they may include postgres + # temporary schemas + assert set(executor.schemata()) >= { + "public", + "pg_catalog", + "information_schema", + "schema1", + "schema2", + } + assert executor.search_path() == ["pg_catalog", "public"] + + # tables + assert set(executor.tables()) >= { + ("public", "a"), + ("public", "b"), + ("schema1", "c"), + } + + assert set(executor.table_columns()) >= { + ("public", "a", "x", "text", False, None), + ("public", "a", "y", "text", False, None), + ("public", "b", "z", "text", False, None), + ("schema1", "c", "w", "text", True, "'meow'::text"), + } + + # views + assert set(executor.views()) >= {("public", "d")} + + assert set(executor.view_columns()) >= { + ("public", "d", "e", "integer", False, None) + } + + +@dbtest +def test_foreign_key_query(executor): + run(executor, "create schema schema1") + run(executor, "create schema schema2") + run(executor, "create table schema1.parent(parentid int PRIMARY KEY)") + run( + executor, + "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)", + ) + + assert set(executor.foreignkeys()) >= { + ("schema1", "parent", "parentid", "schema2", "child", "motherid") + } + + +@dbtest +def test_functions_query(executor): + run( + executor, + """create function func1() returns int + language sql as $$select 1$$""", + ) + run(executor, "create schema schema1") + run( + executor, + """create function schema1.func2() returns int + language sql as $$select 2$$""", + ) + + run( + executor, + """create function func3() + returns table(x int, y int) language sql + as $$select 1, 2 from generate_series(1,5)$$;""", + ) + + run( + executor, + """create function func4(x int) returns setof int language sql + as $$select generate_series(1,5)$$;""", + ) + + funcs = set(executor.functions()) + assert funcs >= { + function_meta_data(func_name="func1", return_type="integer"), + function_meta_data( + func_name="func3", + arg_names=["x", "y"], + arg_types=["integer", "integer"], + arg_modes=["t", "t"], + return_type="record", + is_set_returning=True, + ), + function_meta_data( + schema_name="public", + func_name="func4", + arg_names=("x",), + arg_types=("integer",), + return_type="integer", + is_set_returning=True, + ), + function_meta_data( + schema_name="schema1", func_name="func2", return_type="integer" + ), + } + + +@dbtest +def test_datatypes_query(executor): + run(executor, "create type foo AS (a int, b text)") + + types = list(executor.datatypes()) + assert types == [("public", "foo")] + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert "_test_db" in databases + + +@dbtest +def test_invalid_syntax(executor, exception_formatter): + result = run(executor, "invalid syntax!", exception_formatter=exception_formatter) + assert 'syntax error at or near "invalid"' in result[0] + + +@dbtest +def test_invalid_column_name(executor, exception_formatter): + result = run( + executor, "select invalid command", exception_formatter=exception_formatter + ) + assert 'column "invalid" does not exist' in result[0] + + +@pytest.fixture(params=[True, False]) +def expanded(request): + return request.param + + +@dbtest +def test_unicode_support_in_output(executor, expanded): + run(executor, "create table unicodechars(t text)") + run(executor, "insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + assert "é" in run( + executor, "select * from unicodechars", join=True, expanded=expanded + ) + + +@dbtest +def test_not_is_special(executor, pgspecial): + """is_special is set to false for database queries.""" + query = "select 1" + result = list(executor.run(query, pgspecial=pgspecial)) + success, is_special = result[0][5:] + assert success == True + assert is_special == False + + +@dbtest +def test_execute_from_file_no_arg(executor, pgspecial): + r"""\i without a filename returns an error.""" + result = list(executor.run(r"\i", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert "missing required argument" in status + assert success == False + assert is_special == True + + +@dbtest +@patch("pgcli.main.os") +def test_execute_from_file_io_error(os, executor, pgspecial): + r"""\i with an os_error returns an error.""" + # Inject an OSError. + os.path.expanduser.side_effect = OSError("test") + + # Check the result. + result = list(executor.run(r"\i test", pgspecial=pgspecial)) + status, sql, success, is_special = result[0][3:] + assert status == "test" + assert success == False + assert is_special == True + + +@dbtest +def test_execute_from_commented_file_that_executes_another_file( + executor, pgspecial, tmpdir +): + # https://github.com/dbcli/pgcli/issues/1336 + sqlfile1 = tmpdir.join("test01.sql") + sqlfile1.write("-- asdf \n\\h") + sqlfile2 = tmpdir.join("test00.sql") + sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment") + + rcfile = str(tmpdir.join("rcfile")) + print(rcfile) + cli = PGCli(pgexecute=executor, pgclirc_file=rcfile) + assert cli != None + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=cli.pgspecial) + assert result != None + assert result[0].find("ALTER TABLE") + + +@dbtest +def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir): + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # https://github.com/dbcli/pgcli/issues/1362 + statement = "--comment\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "--comment1\n--comment2\n\\h" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + \h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """/*comment1 + comment2*/ + /*comment 3 + comment4*/ + \\h""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = "/*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = " /*comment\ncomment line2*/\n\h;" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("ALTER") >= 0 + assert result[1].find("ABORT") >= 0 + + statement = """\\h /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + print(result) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: we probably don't want to do this but sqlparse is not parsing things well + # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/ + # style comments after command + + statement = """/*comment1*/ + \h + /*comment4 */""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + # TODO: same for this one + statement = """/*comment1 + comment3 + comment2*/ + \\h + /*comment4 + comment5 + comment6*/""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[0].find("No help") >= 0 + + +@dbtest +def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir): + # https://github.com/dbcli/pgcli/issues/1403 + + # just some base cases that should work also + statement = "--comment\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + statement = "/*comment*/\nselect now();" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[1].find("now") >= 0 + + # this simulates the original error (1403) without having to add/drop tables + # since it was just an error on reading input files and not the actual + # command itself + + # test that the statement works + statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a \n in the middle + statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # test the statement with a newline in the middle + statement = """VALUES (1, 'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # now add a single comment line + statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + VALUES (1,'one'), + (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # two comment lines + statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # doing without special char \n + statement = """--comment + --comment2 + VALUES (1,'one'), (2, 'two'), (3, 'three'); + """ + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three');""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + # multiline comment + newline in middle of the statement + # + comments after the statement + statement = """/*comment +comment2 +comment3*/ +VALUES (1,'one'), +(2, 'two'), (3, 'three'); +--comment4 +--comment5""" + result = run(executor, statement, pgspecial=pgspecial) + assert result != None + assert result[5].find("three") >= 0 + + +@dbtest +def test_multiple_queries_same_line(executor): + result = run(executor, "select 'foo'; select 'bar'") + assert len(result) == 12 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + assert "bar" in result[9] + + +@dbtest +def test_multiple_queries_with_special_command_same_line(executor, pgspecial): + result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial) + assert len(result) == 11 # 2 * (output+status) * 3 lines + assert "foo" in result[3] + # This is a lame check. :( + assert "Schema" in result[7] + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter): + result = run( + executor, + "select 'fooé'; invalid syntax é", + exception_formatter=exception_formatter, + ) + assert "fooé" in result[3] + assert 'syntax error at or near "invalid"' in result[-1] + + +@pytest.fixture +def pgspecial(): + return PGCli().pgspecial + + +@dbtest +def test_special_command_help(executor, pgspecial): + result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|") + assert "Command" in result[1] + assert "Description" in result[2] + + +@dbtest +def test_bytea_field_support_in_output(executor): + run(executor, "create table binarydata(c bytea)") + run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))") + + assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True) + + +@dbtest +def test_unicode_support_in_unknown_type(executor): + assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True) + + +@dbtest +def test_unicode_support_in_enum_type(executor): + run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')") + run(executor, "CREATE TABLE person (name TEXT, current_mood mood)") + run(executor, "INSERT INTO person VALUES ('Moe', '日本語')") + assert "日本語" in run(executor, "SELECT * FROM person", join=True) + + +@requires_json +def test_json_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsontest(d json)") + run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@requires_jsonb +def test_jsonb_renders_without_u_prefix(executor, expanded): + run(executor, "create table jsonbtest(d jsonb)") + run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""") + result = run( + executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded + ) + + assert '{"name": "Éowyn"}' in result + + +@dbtest +def test_date_time_types(executor): + run(executor, "SET TIME ZONE UTC") + assert ( + run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] + == "| 00:00:00 |" + ) + assert ( + run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split( + "\n" + )[3] + == "| 00:00:00+14:59 |" + ) + assert ( + run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[ + 3 + ] + == "| 4713-01-01 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True + ).split("\n")[3] + == "| 4713-01-01 00:00:00 BC |" + ) + assert ( + run( + executor, + "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))", + join=True, + ).split("\n")[3] + == "| 4713-01-01 00:00:00+00 BC |" + ) + assert ( + run( + executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True + ).split("\n")[3] + == "| -123456789 days, 12:23:56 |" + ) + + +@dbtest +@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"]) +def test_large_numbers_render_directly(executor, value): + run(executor, "create table numbertest(a numeric)") + run(executor, f"insert into numbertest (a) values ({value})") + assert value in run(executor, "select * from numbertest", join=True) + + +@dbtest +@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"]) +@pytest.mark.parametrize("verbose", ["", "+"]) +@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"]) +def test_describe_special(executor, command, verbose, pattern, pgspecial): + # We don't have any tests for the output of any of the special commands, + # but we can at least make sure they run without error + sql = r"\{command}{verbose} {pattern}".format(**locals()) + list(executor.run(sql, pgspecial=pgspecial)) + + +@dbtest +@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"]) +def test_raises_with_no_formatter(executor, sql): + with pytest.raises(psycopg.ProgrammingError): + list(executor.run(sql)) + + +@dbtest +def test_on_error_resume(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter) + ) + assert len(result) == 3 + + +@dbtest +def test_on_error_stop(executor, exception_formatter): + sql = "select 1; error; select 1;" + result = list( + executor.run( + sql, on_error_resume=False, exception_formatter=exception_formatter + ) + ) + assert len(result) == 2 + + +# @dbtest +# def test_unicode_notices(executor): +# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;" +# result = list(executor.run(sql)) +# assert result[0][0] == u'NOTICE: 有人更改\n' + + +@dbtest +def test_nonexistent_function_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_function") + + +@dbtest +def test_function_definition(executor): + run( + executor, + """ + CREATE OR REPLACE FUNCTION public.the_number_three() + RETURNS int + LANGUAGE sql + AS $function$ + select 3; + $function$ + """, + ) + result = executor.function_definition("the_number_three") + + +@dbtest +def test_view_definition(executor): + run(executor, "create table tbl1 (a text, b numeric)") + run(executor, "create view vw1 AS SELECT * FROM tbl1") + run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1") + result = executor.view_definition("vw1") + assert 'VIEW "public"."vw1" AS' in result + assert "FROM tbl1" in result + # import pytest; pytest.set_trace() + result = executor.view_definition("mvw1") + assert "MATERIALIZED VIEW" in result + + +@dbtest +def test_nonexistent_view_definition(executor): + with pytest.raises(RuntimeError): + result = executor.view_definition("there_is_no_such_view") + with pytest.raises(RuntimeError): + result = executor.view_definition("mvw1") + + +@dbtest +def test_short_host(executor): + with patch.object(executor, "host", "localhost"): + assert executor.short_host == "localhost" + with patch.object(executor, "host", "localhost.example.org"): + assert executor.short_host == "localhost" + with patch.object( + executor, "host", "localhost1.example.org,localhost2.example.org" + ): + assert executor.short_host == "localhost1" + + +class VirtualCursor: + """Mock a cursor to virtual database like pgbouncer.""" + + def __init__(self): + self.protocol_error = False + self.protocol_message = "" + self.description = None + self.status = None + self.statusmessage = "Error" + + def execute(self, *args, **kwargs): + self.protocol_error = True + self.protocol_message = "Command not supported" + + +@dbtest +def test_exit_without_active_connection(executor): + quit_handler = MagicMock() + pgspecial = PGSpecial() + pgspecial.register( + quit_handler, + "\\q", + "\\q", + "Quit pgcli.", + arg_type=NO_QUERY, + case_sensitive=True, + aliases=(":q",), + ) + + with patch.object( + executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!") + ): + # we should be able to quit the app, even without active connection + run(executor, "\\q", pgspecial=pgspecial) + quit_handler.assert_called_once() + + # an exception should be raised when running a query without active connection + with pytest.raises(psycopg.InterfaceError): + run(executor, "select 1", pgspecial=pgspecial) + + +@dbtest +def test_virtual_database(executor): + virtual_connection = MagicMock() + virtual_connection.cursor.return_value = VirtualCursor() + with patch.object(executor, "conn", virtual_connection): + result = run(executor, "select 1") + assert "Command not supported" in result diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py new file mode 100644 index 0000000..cd99e32 --- /dev/null +++ b/tests/test_pgspecial.py @@ -0,0 +1,78 @@ +import pytest +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + View, + Function, + Datatype, +) + + +def test_slash_suggests_special(): + suggestions = suggest_type("\\", "\\") + assert set(suggestions) == {Special()} + + +def test_slash_d_suggests_special(): + suggestions = suggest_type("\\d", "\\d") + assert set(suggestions) == {Special()} + + +def test_dn_suggests_schemata(): + suggestions = suggest_type("\\dn ", "\\dn ") + assert suggestions == (Schema(),) + + suggestions = suggest_type("\\dn xxx", "\\dn xxx") + assert suggestions == (Schema(),) + + +def test_d_suggests_tables_views_and_schemas(): + suggestions = suggest_type(r"\d ", r"\d ") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + suggestions = suggest_type(r"\d xxx", r"\d xxx") + assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)} + + +def test_d_dot_suggests_schema_qualified_tables_or_views(): + suggestions = suggest_type(r"\d myschema.", r"\d myschema.") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx") + assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")} + + +def test_df_suggests_schema_or_function(): + suggestions = suggest_type("\\df xxx", "\\df xxx") + assert set(suggestions) == {Function(schema=None, usage="special"), Schema()} + + suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx") + assert suggestions == (Function(schema="myschema", usage="special"),) + + +def test_leading_whitespace_ok(): + cmd = "\\dn " + whitespace = " " + suggestions = suggest_type(whitespace + cmd, whitespace + cmd) + assert suggestions == suggest_type(cmd, cmd) + + +def test_dT_suggests_schema_or_datatypes(): + text = "\\dT " + suggestions = suggest_type(text, text) + assert set(suggestions) == {Schema(), Datatype(schema=None)} + + +def test_schema_qualified_dT_suggests_datatypes(): + text = "\\dT foo." + suggestions = suggest_type(text, text) + assert suggestions == (Datatype(schema="foo"),) + + +@pytest.mark.parametrize("command", ["\\c ", "\\connect "]) +def test_c_suggests_databases(command): + suggestions = suggest_type(command, command) + assert suggestions == (Database(),) diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py new file mode 100644 index 0000000..f5b6700 --- /dev/null +++ b/tests/test_prioritization.py @@ -0,0 +1,20 @@ +from pgcli.packages.prioritization import PrevalenceCounter + + +def test_prevalence_counter(): + counter = PrevalenceCounter() + sql = """SELECT * FROM foo WHERE bar GROUP BY baz; + select * from foo; + SELECT * FROM foo WHERE bar GROUP + BY baz""" + counter.update(sql) + + keywords = ["SELECT", "FROM", "GROUP BY"] + expected = [3, 3, 2] + kw_counts = [counter.keyword_count(x) for x in keywords] + assert kw_counts == expected + assert counter.keyword_count("NOSUCHKEYWORD") == 0 + + names = ["foo", "bar", "baz"] + name_counts = [counter.name_count(x) for x in names] + assert name_counts == [3, 2, 2] diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py new file mode 100644 index 0000000..91abe37 --- /dev/null +++ b/tests/test_prompt_utils.py @@ -0,0 +1,17 @@ +import click + +from pgcli.packages.prompt_utils import confirm_destructive_query + + +def test_confirm_destructive_query_notty(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, [], None) is None + + +def test_confirm_destructive_query_with_alias(): + stdin = click.get_text_stream("stdin") + if not stdin.isatty(): + sql = "drop database foo;" + assert confirm_destructive_query(sql, ["drop"], "test") is None diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py new file mode 100644 index 0000000..da916b4 --- /dev/null +++ b/tests/test_rowlimit.py @@ -0,0 +1,79 @@ +import pytest +from unittest.mock import Mock + +from pgcli.main import PGCli + + +# We need this fixtures because we need PGCli object to be created +# after test collection so it has config loaded from temp directory + + +@pytest.fixture(scope="module") +def default_pgcli_obj(): + return PGCli() + + +@pytest.fixture(scope="module") +def DEFAULT(default_pgcli_obj): + return default_pgcli_obj.row_limit + + +@pytest.fixture(scope="module") +def LIMIT(DEFAULT): + return DEFAULT + 1000 + + +@pytest.fixture(scope="module") +def over_default(DEFAULT): + over_default_cursor = Mock() + over_default_cursor.configure_mock(rowcount=DEFAULT + 10) + return over_default_cursor + + +@pytest.fixture(scope="module") +def over_limit(LIMIT): + over_limit_cursor = Mock() + over_limit_cursor.configure_mock(rowcount=LIMIT + 10) + return over_limit_cursor + + +@pytest.fixture(scope="module") +def low_count(): + low_count_cursor = Mock() + low_count_cursor.configure_mock(rowcount=1) + return low_count_cursor + + +def test_row_limit_with_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students LIMIT 1000" + + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_without_LIMIT_clause(LIMIT, over_limit): + cli = PGCli(row_limit=LIMIT) + stmt = "SELECT * FROM students" + + result = cli._should_limit_output(stmt, over_limit) + assert result is True + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + +def test_row_limit_on_non_select(over_limit): + cli = PGCli() + stmt = "UPDATE students SET name='Boby'" + result = cli._should_limit_output(stmt, over_limit) + assert result is False + + cli = PGCli(row_limit=0) + result = cli._should_limit_output(stmt, over_limit) + assert result is False diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py new file mode 100644 index 0000000..5c9c9af --- /dev/null +++ b/tests/test_smart_completion_multiple_schemata.py @@ -0,0 +1,757 @@ +import itertools +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + schema, + table, + function, + wildcard_expansion, + column, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from utils import completions_to_set + +metadata = { + "tables": { + "public": { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status", "datestamp"], + "select": ["id", "localtime", "ABC"], + }, + "custom": { + "users": ["id", "phone_number"], + "Users": ["userid", "username"], + "products": ["id", "product_name", "price"], + "shipments": ["id", "address", "user_id"], + }, + "Custom": {"projects": ["projectid", "name"]}, + "blog": { + "entries": ["entryid", "entrytitle", "entrytext"], + "tags": ["tagid", "name"], + "entrytags": ["entryid", "tagid"], + "entacclog": ["entryid", "username", "datestamp"], + }, + }, + "functions": { + "public": [ + ["func1", [], [], [], "", False, False, False, False], + ["func2", [], [], [], "", False, False, False, False], + ], + "custom": [ + ["func3", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x"], + ["integer"], + ["o"], + "integer", + False, + False, + True, + False, + ], + ], + "Custom": [["func4", [], [], [], "", False, False, False, False]], + "blog": [ + [ + "extract_entry_symbols", + ["_entryid", "symbol"], + ["integer", "text"], + ["i", "o"], + "", + False, + False, + True, + False, + ], + [ + "enter_entry", + ["_title", "_text", "entryid"], + ["text", "text", "integer"], + ["i", "i", "o"], + "", + False, + False, + False, + False, + ], + ], + }, + "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]}, + "foreignkeys": { + "custom": [("public", "users", "id", "custom", "shipments", "user_id")], + "blog": [ + ("blog", "entries", "entryid", "blog", "entacclog", "entryid"), + ("blog", "entries", "entryid", "blog", "entrytags", "entryid"), + ("blog", "tags", "tagid", "blog", "entrytags", "tagid"), + ], + }, + "defaults": { + "public": { + ("orders", "id"): "nextval('orders_id_seq'::regclass)", + ("orders", "datestamp"): "now()", + ("orders", "status"): "'PENDING'::text", + } + }, +} + +testdata = MetaData(metadata) +cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')] +casing = ( + "SELECT", + "Orders", + "User_Emails", + "CUSTOM", + "Func1", + "Entries", + "Tags", + "EntryTags", + "EntAccLog", + "EntryID", + "EntryTitle", + "EntryText", +) +completers = testdata.get_completers(casing) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("table", ["users", '"users"']) +def test_suggested_column_names_from_shadowed_visible_table(completer, table): + result = get_result(completer, "SELECT FROM " + table, len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT from custom.users", + "WITH users as (SELECT 1 AS foo) SELECT from custom.users", + ], +) +def test_suggested_column_names_from_qualified_shadowed_table(completer, text): + result = get_result(completer, text, position=text.find(" ") + 1) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"]) +def test_suggested_column_names_from_cte(completer, text): + result = completions_to_set(get_result(completer, text, text.find(" ") + 1)) + assert result == completions_to_set( + [column("foo")] + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users JOIN custom.shipments ON ", + """SELECT * + FROM public.users + JOIN custom.shipments ON """, + ], +) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + alias("users"), + alias("shipments"), + name_join("shipments.id = users.id"), + fk_join("shipments.user_id = users.id"), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize( + ("query", "tbl"), + itertools.product( + ( + "SELECT * FROM public.{0} RIGHT OUTER JOIN ", + """SELECT * + FROM {0} + JOIN """, + ), + ("users", '"users"', "Users"), + ), +) +def test_suggested_joins(completer, query, tbl): + result = get_result(completer, query.format(tbl)) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_from_schema_qualifed_table(completer): + result = get_result(completer, "SELECT from custom.products", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize( + "text", + [ + "INSERT INTO orders(", + "INSERT INTO orders (", + "INSERT INTO public.orders(", + "INSERT INTO public.orders (", + ], +) +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_columns_with_insert(completer, text): + assert completions_to_set(get_result(completer, text)) == completions_to_set( + testdata.columns("orders") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result( + completer, "SELECT MAX( from custom.products", len("SELECT MAX(") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'], +) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("custom", start_position) + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ['SELECT * FROM "Custom".']) +@parametrize("use_leading_double_quote", [False, True]) +def test_suggested_table_names_with_schema_dot2( + completer, text, use_leading_double_quote +): + if use_leading_double_quote: + text += '"' + start_position = -1 + else: + start_position = 0 + + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.from_clause_items("Custom", start_position) + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_column_names_with_qualified_alias(completer): + result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p.")) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result( + completer, "SELECT id, from custom.products", len("SELECT id, ") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("products", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ", + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id", + ], +) +def test_suggestions_after_on(completer, text): + position = len( + "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("x"), + alias("y"), + name_join("y.price = x.price"), + name_join("y.product_name = x.product_name"), + name_join("y.id = x.id"), + ] + ) + + +@parametrize("completer", completers()) +def test_suggested_aliases_after_on_right_side(completer): + text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_table_names_after_from(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_schema_qualified_function_name(completer): + text = "SELECT custom.func" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("func3()", -len("func")), + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_schema_qualified_function_name_after_from(completer): + text = "SELECT * FROM custom.set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_not_returned(completer): + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set([]) + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +def test_unqualified_function_name_in_search_path(completer): + completer.search_path = ["public", "custom"] + text = "SELECT * FROM set_r" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + function("set_returning_func()", -len("func")), + ] + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT 1::custom.", + "CREATE TABLE foo (bar custom.", + "CREATE FUNCTION foo (bar INT, baz custom.", + "ALTER TABLE foo ALTER COLUMN bar TYPE custom.", + ], +) +def test_schema_qualified_type_name(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(testdata.types("custom")) + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from custom.set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", "custom", "functions") + ) + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM custom.set_returning_func()", + "SELECT * FROM Custom.set_returning_func()", + "SELECT * FROM Custom.Set_Returning_Func()", + ], +) +def test_wildcard_column_expansion_with_function(completer, text): + position = len("SELECT *") + + completions = get_result(completer, text, position) + + col_list = "x" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_alias_qualifier(completer): + text = "SELECT p.* FROM custom.products p" + position = len("SELECT p.*") + + completions = get_result(completer, text, position) + + col_list = "id, p.product_name, p.price" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + """ + SELECT count(1) FROM users; + CREATE FUNCTION foo(custom.products _products) returns custom.shipments + LANGUAGE SQL + AS $foo$ + SELECT 1 FROM custom.shipments; + INSERT INTO public.orders(*) values(-1, now(), 'preliminary'); + SELECT 2 FROM custom.users; + $foo$; + SELECT count(1) FROM custom.shipments; + """, + "INSERT INTO public.orders(*", + "INSERT INTO public.Orders(*", + "INSERT INTO public.orders (*", + "INSERT INTO public.Orders (*", + "INSERT INTO orders(*", + "INSERT INTO Orders(*", + "INSERT INTO orders (*", + "INSERT INTO Orders (*", + "INSERT INTO public.orders(*)", + "INSERT INTO public.Orders(*)", + "INSERT INTO public.orders (*)", + "INSERT INTO public.Orders (*)", + "INSERT INTO orders(*)", + "INSERT INTO Orders(*)", + "INSERT INTO orders (*)", + "INSERT INTO Orders (*)", + ], +) +def test_wildcard_column_expansion_with_insert(completer, text): + position = text.index("*") + 1 + completions = get_result(completer, text, position) + + expected = [wildcard_expansion("ordered_date, status")] + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_table_qualifier(completer): + text = 'SELECT "select".* FROM public."select"' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM public."select" JOIN custom.users ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select"."localtime", "select"."ABC", ' + "users.id, users.phone_number" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(filtr=True, casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select"."localtime", "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", + [ + "SELECT U. FROM custom.Users U", + "SELECT U. FROM custom.USERS U", + "SELECT U. FROM custom.users U", + 'SELECT U. FROM "custom".Users U', + 'SELECT U. FROM "custom".USERS U', + 'SELECT U. FROM "custom".users U', + ], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("users", "custom") + ) + + +@parametrize("completer", completers(filtr=True, casing=False)) +@parametrize( + "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'] +) +def test_suggest_columns_from_quoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + testdata.columns("Users", "custom") + ) + + +texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "] + + +@parametrize("completer", completers(filtr=True, casing=False, aliasing=False)) +@parametrize("text", texts) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(aliasing=True, casing=False, filtr=True)) +@parametrize("text", texts) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("users u"), + table("orders o" if text == "SELECT * FROM " else "orders o2"), + table('"select" s'), + function("func1() f"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=True, casing=True, filtr=True)) +@parametrize("text", texts) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users u"), + table("Orders O" if text == "SELECT * FROM " else "Orders O2"), + table('"select" s'), + function("Func1() F"), + function("func2() f"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True, filtr=True)) +@parametrize("text", texts) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + cased_schemas + + [ + table("users"), + table("Orders"), + table('"select"'), + function("Func1()"), + function("func2()"), + ] + ) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.et" + result = get_result(completer, text) + assert result[0] == table("EntryTags ET", -2) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.e" + result = get_result(completer, text) + assert result[0] == table("Entries E", -1) + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases1(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries E2", -1), + join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases1(completer): + text = "SELECT * FROM blog.Entries JOIN blog.e" + result = get_result(completer, text) + assert result[:2] == [ + table("Entries", -1), + join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1), + ] + + +@parametrize("completer", completers(aliasing=True, casing=True)) +def test_join_alias_search_with_aliases2(completer): + text = "SELECT * FROM blog.Entries E JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2) + + +@parametrize("completer", completers(aliasing=False, casing=True)) +def test_join_alias_search_without_aliases2(completer): + text = "SELECT * FROM blog.Entries JOIN blog.et" + result = get_result(completer, text) + assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2) + + +@parametrize("completer", completers()) +def test_function_alias_search_without_aliases(completer): + text = "SELECT blog.ees" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -3 + assert first.text == "extract_entry_symbols()" + assert first.display_text == "extract_entry_symbols(_entryid)" + + +@parametrize("completer", completers()) +def test_function_alias_search_with_aliases(completer): + text = "SELECT blog.ee" + result = get_result(completer, text) + first = result[0] + assert first.start_position == -2 + assert first.text == "enter_entry(_title := , _text := )" + assert first.display_text == "enter_entry(_title, _text)" + + +@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual)) +def test_column_alias_search(completer): + result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et")) + cols = ("EntryText", "EntryTitle", "EntryID") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=True)) +def test_column_alias_search_qualified(completer): + result = get_result( + completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei") + ) + cols = ("EntryID", "EntryTitle") + assert result[:3] == [column(c, -2) for c in cols] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_schema_object_order(completer): + result = get_result(completer, "SELECT * FROM u") + assert result[:3] == [ + table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users") + ] + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=False)) +def test_all_schema_objects(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders", '"select"', "custom.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(filtr=False, aliasing=False, casing=True)) +def test_all_schema_objects_with_casing(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] + + [function(x + "()") for x in ("func2",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_all_schema_objects_with_aliases(completer): + text = "SELECT * FROM " + result = get_result(completer, text) + assert completions_to_set(result) >= completions_to_set( + [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + + [function(x) for x in ("func2() f",)] + ) + + +@parametrize("completer", completers(casing=False, filtr=False, aliasing=True)) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")] + ) diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py new file mode 100644 index 0000000..db1fe0a --- /dev/null +++ b/tests/test_smart_completion_public_schema_only.py @@ -0,0 +1,1112 @@ +from metadata import ( + MetaData, + alias, + name_join, + fk_join, + join, + keyword, + schema, + table, + view, + function, + column, + wildcard_expansion, + get_result, + result_set, + qual, + no_qual, + parametrize, +) +from prompt_toolkit.completion import Completion +from utils import completions_to_set + + +metadata = { + "tables": { + "users": ["id", "parentid", "email", "first_name", "last_name"], + "Users": ["userid", "username"], + "orders": ["id", "ordered_date", "status", "email"], + "select": ["id", "insert", "ABC"], + }, + "views": {"user_emails": ["id", "email"], "functions": ["function"]}, + "functions": [ + ["custom_fun", [], [], [], "", False, False, False, False], + ["_custom_fun", [], [], [], "", False, False, False, False], + ["custom_func1", [], [], [], "", False, False, False, False], + ["custom_func2", [], [], [], "", False, False, False, False], + [ + "set_returning_func", + ["x", "y"], + ["integer", "integer"], + ["b", "b"], + "", + False, + False, + True, + False, + ], + ], + "datatypes": ["custom_type1", "custom_type2"], + "foreignkeys": [ + ("public", "users", "id", "public", "users", "parentid"), + ("public", "users", "id", "public", "Users", "userid"), + ], +} + +metadata = {k: {"public": v} for k, v in metadata.items()} + +testdata = MetaData(metadata) + +cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"] +cased_users2_col_names = ["UserID", "UserName"] +cased_func_names = [ + "Custom_Fun", + "_custom_fun", + "Custom_Func1", + "custom_func2", + "set_returning_func", +] +cased_tbls = ["Users", "Orders"] +cased_views = ["User_Emails", "Functions"] +casing = ( + ["SELECT", "PUBLIC"] + + cased_func_names + + cased_tbls + + cased_views + + cased_users_col_names + + cased_users2_col_names +) +# Lists for use in assertions +cased_funcs = [ + function(f) + for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()") +] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")] +cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])] +cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls +cased_users_cols = [column(c) for c in cased_users_col_names] +aliased_rels = ( + [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')] + + [view("user_emails ue"), view("functions f")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "custom_fun() cf", + "custom_func1() cf", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +cased_aliased_rels = ( + [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')] + + [view("User_Emails UE"), view("Functions F")] + + [ + function(f) + for f in ( + "_custom_fun() cf", + "Custom_Fun() CF", + "Custom_Func1() CF", + "custom_func2() cf", + ) + ] + + [ + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ) + ] +) +completers = testdata.get_completers(casing) + + +# Just to make sure that this doesn't crash +@parametrize("completer", completers()) +def test_function_column_name(completer): + for l in range( + len("SELECT * FROM Functions WHERE function:"), + len("SELECT * FROM Functions WHERE function:text") + 1, + ): + assert [] == get_result( + completer, "SELECT * FROM Functions WHERE function:text"[:l] + ) + + +@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"]) +@parametrize("completer", completers()) +def test_drop_alter_function(completer, action): + assert get_result(completer, action + " FUNCTION set_ret") == [ + function("set_returning_func(x integer, y integer)", -len("set_ret")) + ] + + +@parametrize("completer", completers()) +def test_empty_string_completion(completer): + result = get_result(completer, "") + assert completions_to_set( + testdata.keywords() + testdata.specials() + ) == completions_to_set(result) + + +@parametrize("completer", completers()) +def test_select_keyword_completion(completer): + result = get_result(completer, "SEL") + assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)]) + + +@parametrize("completer", completers()) +def test_builtin_function_name_completion(completer): + result = get_result(completer, "SELECT MA") + assert completions_to_set(result) == completions_to_set( + [ + function("MAKE_DATE", -2), + function("MAKE_INTERVAL", -2), + function("MAKE_TIME", -2), + function("MAKE_TIMESTAMP", -2), + function("MAKE_TIMESTAMPTZ", -2), + function("MASKLEN", -2), + function("MAX", -2), + keyword("MAXEXTENTS", -2), + keyword("MATERIALIZED VIEW", -2), + ] + ) + + +@parametrize("completer", completers()) +def test_builtin_function_matches_only_at_start(completer): + text = "SELECT IN" + + result = [c.text for c in get_result(completer, text)] + + assert "MIN" not in result + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion(completer): + result = get_result(completer, "SELECT cu") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + function("CURRENT_DATE", -2), + function("CURRENT_TIMESTAMP", -2), + function("CUME_DIST", -2), + function("CURRENT_TIME", -2), + keyword("CURRENT", -2), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_user_function_name_completion_matches_anywhere(completer): + result = get_result(completer, "SELECT om") + assert completions_to_set(result) == completions_to_set( + [ + function("custom_fun()", -2), + function("_custom_fun()", -2), + function("custom_func1()", -2), + function("custom_func2()", -2), + ] + ) + + +@parametrize("completer", completers(casing=True)) +def test_list_functions_for_special(completer): + result = get_result(completer, r"\df ") + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + [function(f) for f in cased_func_names] + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_from_visible_table(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=True, qualify=no_qual)) +def test_suggested_cased_column_names(completer): + result = get_result(completer, "SELECT from users", len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + cased_funcs + + cased_users_cols + + testdata.builtin_functions() + + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"]) +def test_suggested_auto_qualified_column_names(text, completer): + position = text.index(" ") + 1 + cols = [column(c.lower()) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=qual)) +@parametrize( + "text", + [ + 'SELECT from users U NATURAL JOIN "Users"', + 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"', + ], +) +def test_suggested_auto_qualified_column_names_two_tables(text, completer): + position = text.index(" ") + 1 + cols = [column("U." + c.lower()) for c in cased_users_col_names] + cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cols + testdata.functions_and_keywords() + ) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("]) +def test_no_column_qualification(text, completer): + cols = [column(c) for c in cased_users_col_names] + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set(cols) + + +@parametrize("completer", completers(casing=True, qualify=["always"])) +def test_suggested_cased_always_qualified_column_names(completer): + text = "SELECT from users" + position = len("SELECT ") + cols = [column("users." + c) for c in cased_users_col_names] + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + cased_funcs + cols + testdata.builtin_functions() + testdata.keywords() + ) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_column_names_in_function(completer): + result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX(")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_table_dot(completer): + result = get_result(completer, "SELECT users. from users", len("SELECT users.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_column_names_with_alias(completer): + result = get_result(completer, "SELECT u. from users u", len("SELECT u.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggested_multiple_column_names(completer): + result = get_result(completer, "SELECT id, from users u", len("SELECT id, ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("users") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=True)) +def test_suggested_cased_column_names_with_alias(completer): + result = get_result( + completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.") + ) + assert completions_to_set(result) == completions_to_set(cased_users_cols) + + +@parametrize("completer", completers(casing=False)) +def test_suggested_multiple_column_names_with_dot(completer): + result = get_result( + completer, + "SELECT users.id, users. from users u", + len("SELECT users.id, users."), + ) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_after_three_way_join(completer): + text = """SELECT * FROM users u1 + INNER JOIN users u2 ON u1.id = u2.id + INNER JOIN users u3 ON u2.id = u3.""" + result = get_result(completer, text) + assert column("id") in result + + +join_condition_texts = [ + 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ', + """INSERT INTO public.orders(orderid) + SELECT * FROM users U JOIN "Users" U2 ON """, + 'SELECT * FROM users U JOIN "Users" U2 ON ', + 'SELECT * FROM users U INNER join "Users" U2 ON ', + 'SELECT * FROM USERS U right JOIN "Users" U2 ON ', + 'SELECT * FROM users U LEFT JOIN "Users" U2 ON ', + 'SELECT * FROM Users U FULL JOIN "Users" U2 ON ', + 'SELECT * FROM users U right outer join "Users" U2 ON ', + 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ', + 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ', + """SELECT * + FROM users U + FULL OUTER JOIN "Users" U2 ON + """, +] + + +@parametrize("completer", completers(casing=False)) +@parametrize("text", join_condition_texts) +def test_suggested_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.userid = U.id")] + ) + + +@parametrize("completer", completers(casing=True)) +@parametrize("text", join_condition_texts) +def test_cased_join_conditions(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + """SELECT * + FROM users + CROSS JOIN "Users" + NATURAL JOIN users u + JOIN "Users" u2 ON + """ + ], +) +def test_suggested_join_conditions_with_same_table_twice(completer, text): + result = get_result(completer, text) + assert result == [ + fk_join("u2.userid = u.id"), + fk_join("u2.userid = users.id"), + name_join('u2.userid = "Users".userid'), + name_join('u2.username = "Users".username'), + alias("u"), + alias("u2"), + alias("users"), + alias('"Users"'), + ] + + +@parametrize("completer", completers()) +@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."]) +def test_suggested_join_conditions_with_invalid_qualifier(completer, text): + result = get_result(completer, text) + assert result == [] + + +@parametrize("completer", completers(casing=False)) +@parametrize( + ("text", "ref"), + [ + ("SELECT * FROM users JOIN NonTable on ", "NonTable"), + ("SELECT * FROM users JOIN nontable nt on ", "nt"), + ], +) +def test_suggested_join_conditions_with_invalid_table(completer, text, ref): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias(ref)] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM "Users" u JOIN u', + 'SELECT * FROM "Users" u JOIN uid', + 'SELECT * FROM "Users" u JOIN userid', + 'SELECT * FROM "Users" u JOIN id', + ], +) +def test_suggested_joins_fuzzy(completer, text): + result = get_result(completer, text) + last_word = text.split()[-1] + expected = join("users ON users.id = u.userid", -len(last_word)) + assert expected in result + + +join_texts = [ + "SELECT * FROM Users JOIN ", + """INSERT INTO "Users" + SELECT * + FROM Users + INNER JOIN """, + """INSERT INTO public."Users"(username) + SELECT * + FROM Users + INNER JOIN """, + """SELECT * + FROM Users + INNER JOIN """, +] + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", join_texts) +def test_suggested_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [ + join('"Users" ON "Users".userid = Users.id'), + join("users users2 ON users2.id = Users.parentid"), + join("users users2 ON users2.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", join_texts) +def test_cased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + + cased_rels + + [ + join('"Users" ON "Users".UserID = Users.ID'), + join("Users Users2 ON Users2.ID = Users.PARENTID"), + join("Users Users2 ON Users2.PARENTID = Users.ID"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", join_texts) +def test_aliased_joins(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + aliased_rels + + [ + join('"Users" U ON U.userid = Users.id'), + join("users u ON u.id = Users.parentid"), + join("users u ON u.parentid = Users.id"), + ] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + 'SELECT * FROM public."Users" JOIN ', + 'SELECT * FROM public."Users" RIGHT OUTER JOIN ', + """SELECT * + FROM public."Users" + LEFT JOIN """, + ], +) +def test_suggested_joins_quoted_schema_qualified_table(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + + [join('public.users ON users.id = "Users".userid')] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON ", + "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + alias("u"), + name_join("o.id = u.id"), + name_join("o.email = u.email"), + alias("o"), + ] + ) + + +@parametrize("completer", completers()) +@parametrize( + "text", + [ + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ", + "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON", + ], +) +def test_suggested_aliases_after_on_right_side(completer, text): + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")]) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON ", + "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON", + ], +) +def test_suggested_tables_after_on(completer, text): + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [ + name_join("orders.id = users.id"), + name_join("orders.email = users.email"), + alias("users"), + alias("orders"), + ] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON", + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ", + ], +) +def test_suggested_tables_after_on_right_side(completer, text): + position = len( + "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + ) + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [alias("users"), alias("orders")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (", + "SELECT * FROM users INNER JOIN orders USING(", + ], +) +def test_join_using_suggests_common_columns(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()", + "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()", + "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)", + "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)", + ], +) +def test_join_using_suggests_from_last_table(completer, text): + position = text.index("()") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT * FROM users INNER JOIN orders USING (id,", + "SELECT * FROM users INNER JOIN orders USING(id,", + ], +) +def test_join_using_suggests_columns_after_first_column(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [column("id"), column("email")] + ) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize( + "text", + [ + "SELECT * FROM ", + "SELECT * FROM users CROSS JOIN ", + "SELECT * FROM users natural join ", + ], +) +def test_table_names_after_from(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + assert [c.text for c in result] == [ + "public", + "orders", + '"select"', + "users", + '"Users"', + "functions", + "user_emails", + "_custom_fun()", + "custom_fun()", + "custom_func1()", + "custom_func2()", + "set_returning_func(x := , y := )", + ] + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_auto_escaped_col_names(completer): + result = get_result(completer, 'SELECT from "select"', len("SELECT ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("select") + ) + + +@parametrize("completer", completers(aliasing=False)) +def test_allow_leading_double_quote_in_last_word(completer): + result = get_result(completer, 'SELECT * from "sele') + + expected = table('"select"', -5) + + assert expected in result + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT 1::", + "CREATE TABLE foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "ALTER TABLE foo ALTER COLUMN bar TYPE ", + ], +) +def test_suggest_datatype(text, completer): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + testdata.types() + testdata.builtin_datatypes() + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_escaped_table_alias(completer): + result = get_result(completer, 'select * from "select" s where s.') + assert completions_to_set(result) == completions_to_set(testdata.columns("select")) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_set_returning_function(completer): + result = get_result(completer, "select from set_returning_func()", len("select ")) + assert completions_to_set(result) == completions_to_set( + testdata.columns_functions_and_keywords("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_aliased_set_returning_function(completer): + result = get_result( + completer, "select f. from set_returning_func() f", len("select f.") + ) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_using_suggests_common_columns(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 USING (""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers(casing=False)) +def test_join_functions_on_suggests_columns_and_join_conditions(completer): + text = """SELECT * FROM set_returning_func() f1 + INNER JOIN set_returning_func() f2 ON f1.""" + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [name_join("y = f2.y"), name_join("x = f2.x")] + + testdata.columns("set_returning_func", typ="functions") + ) + + +@parametrize("completer", completers()) +def test_learn_keywords(completer): + history = "CREATE VIEW v AS SELECT 1" + completer.extend_query_history(history) + + # Now that we've used `VIEW` once, it should be suggested ahead of other + # keywords starting with v. + text = "create v" + completions = get_result(completer, text) + assert completions[0].text == "VIEW" + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_learn_table_names(completer): + history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users" + completer.extend_query_history(history) + + text = "SELECT * FROM " + completions = get_result(completer, text) + + # `users` should be higher priority than `orders` (used more often) + users = table("users") + orders = table("orders") + + assert completions.index(users) < completions.index(orders) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_columns_before_keywords(completer): + text = "SELECT * FROM orders WHERE s" + completions = get_result(completer, text) + + col = column("status", -1) + kw = keyword("SELECT", -1) + + assert completions.index(col) < completions.index(kw) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "SELECT * FROM users", + "INSERT INTO users SELECT * FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT * + FROM users u""", + ], +) +def test_wildcard_column_expansion(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, parentid, email, first_name, last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "SELECT u.* FROM users u", + "INSERT INTO public.users SELECT u.* FROM users u", + """INSERT INTO users(id, parentid, email, first_name, last_name) + SELECT u.* + FROM users u""", + ], +) +def test_wildcard_column_expansion_with_alias(completer, text): + position = text.find("*") + 1 + + completions = get_result(completer, text, position) + + col_list = "id, u.parentid, u.email, u.first_name, u.last_name" + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text,expected", + [ + ( + "SELECT users.* FROM users", + "id, users.parentid, users.email, users.first_name, users.last_name", + ), + ( + "SELECT Users.* FROM Users", + "id, Users.parentid, Users.email, Users.first_name, Users.last_name", + ), + ], +) +def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected): + position = len("SELECT users.*") + + completions = get_result(completer, text, position) + + expected = [wildcard_expansion(expected)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False, qualify=qual)) +def test_wildcard_column_expansion_with_two_tables(completer): + text = 'SELECT * FROM "select" JOIN users u ON true' + position = len("SELECT *") + + completions = get_result(completer, text, position) + + cols = ( + '"select".id, "select".insert, "select"."ABC", ' + "u.id, u.parentid, u.email, u.first_name, u.last_name" + ) + expected = [wildcard_expansion(cols)] + assert completions == expected + + +@parametrize("completer", completers(casing=False)) +def test_wildcard_column_expansion_with_two_tables_and_parent(completer): + text = 'SELECT "select".* FROM "select" JOIN users u ON true' + position = len('SELECT "select".*') + + completions = get_result(completer, text, position) + + col_list = 'id, "select".insert, "select"."ABC"' + expected = [wildcard_expansion(col_list)] + + assert expected == completions + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. FROM users U"], +) +def test_suggest_columns_from_unquoted_table(completer, text): + position = len("SELECT U.") + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False)) +def test_suggest_columns_from_quoted_table(completer): + result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U.")) + assert completions_to_set(result) == completions_to_set(testdata.columns("Users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "]) +def test_schema_or_visible_table_completion(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas_and_from_clause_items() + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + aliased_rels + ) + + +@parametrize("completer", completers(casing=False, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_table_aliases(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + testdata.schemas() + + [ + table("orders o2"), + table("users u"), + table('"Users" U'), + table('"select" s'), + view("user_emails ue"), + view("functions f"), + function("_custom_fun() cf"), + function("custom_fun() cf"), + function("custom_func1() cf"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "]) +def test_duplicate_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [ + schema("PUBLIC"), + table("Orders O2"), + table("Users U"), + table('"Users" U'), + table('"select" s'), + view("User_Emails UE"), + view("Functions F"), + function("_custom_fun() cf"), + function("Custom_Fun() CF"), + function("Custom_Func1() CF"), + function("custom_func2() cf"), + function( + "set_returning_func(x := , y := ) srf", + display="set_returning_func(x, y) srf", + ), + ] + ) + + +@parametrize("completer", completers(casing=True, aliasing=True)) +@parametrize("text", ["SELECT * FROM "]) +def test_aliases_with_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_aliased_rels + ) + + +@parametrize("completer", completers(casing=True, aliasing=False)) +@parametrize("text", ["SELECT * FROM "]) +def test_table_casing(completer, text): + result = get_result(completer, text) + assert completions_to_set(result) == completions_to_set( + [schema("PUBLIC")] + cased_rels + ) + + +@parametrize("completer", completers(casing=False)) +@parametrize( + "text", + [ + "INSERT INTO users ()", + "INSERT INTO users()", + "INSERT INTO users () SELECT * FROM orders;", + "INSERT INTO users() SELECT * FROM users u cross join orders o", + ], +) +def test_insert(completer, text): + position = text.find("(") + 1 + result = get_result(completer, text, position) + assert completions_to_set(result) == completions_to_set(testdata.columns("users")) + + +@parametrize("completer", completers(casing=False, aliasing=False)) +def test_suggest_cte_names(completer): + text = """ + WITH cte1 AS (SELECT a, b, c FROM foo), + cte2 AS (SELECT d, e, f FROM bar) + SELECT * FROM + """ + result = get_result(completer, text) + expected = completions_to_set( + [ + Completion("cte1", 0, display_meta="table"), + Completion("cte2", 0, display_meta="table"), + ] + ) + assert expected <= completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +def test_suggest_columns_from_cte(completer): + result = get_result( + completer, + "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte", + len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "), + ) + expected = [ + Completion("foo", 0, display_meta="column"), + Completion("bar", 0, display_meta="column"), + ] + testdata.functions_and_keywords() + + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize("completer", completers(casing=False, qualify=no_qual)) +@parametrize( + "text", + [ + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.", + "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.", + ], +) +def test_cte_qualified_columns(completer, text): + result = get_result(completer, text) + expected = [Completion("foo", 0, display_meta="column")] + assert completions_to_set(expected) == completions_to_set(result) + + +@parametrize( + "keyword_casing,expected,texts", + [ + ("upper", "SELECT", ("", "s", "S", "Sel")), + ("lower", "select", ("", "s", "S", "Sel")), + ("auto", "SELECT", ("", "S", "SEL", "seL")), + ("auto", "select", ("s", "sel", "SEl")), + ], +) +def test_keyword_casing_upper(keyword_casing, expected, texts): + for text in texts: + completer = testdata.get_completer({"keyword_casing": keyword_casing}) + completions = get_result(completer, text) + assert expected in [cpl.text for cpl in completions] + + +@parametrize("completer", completers()) +def test_keyword_after_alter(completer): + text = "ALTER TABLE users ALTER " + expected = Completion("COLUMN", start_position=0, display_meta="keyword") + completions = get_result(completer, text) + assert expected in completions + + +@parametrize("completer", completers()) +def test_set_schema(completer): + text = "SET SCHEMA " + result = get_result(completer, text) + expected = completions_to_set([schema("'public'")]) + assert completions_to_set(result) == expected + + +@parametrize("completer", completers()) +def test_special_name_completion(completer): + result = get_result(completer, "\\t") + assert completions_to_set(result) == completions_to_set( + [ + Completion( + text="\\timing", + start_position=-2, + display_meta="Toggle timing of commands.", + ) + ] + ) diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py new file mode 100644 index 0000000..1034bbe --- /dev/null +++ b/tests/test_sqlcompletion.py @@ -0,0 +1,964 @@ +from pgcli.packages.sqlcompletion import ( + suggest_type, + Special, + Database, + Schema, + Table, + Column, + View, + Keyword, + FromClauseItem, + Function, + Datatype, + Alias, + JoinCondition, + Join, +) +from pgcli.packages.parseutils.tables import TableReference +import pytest + + +def cols_etc( + table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None +): + """Returns the expected select-clause suggestions for a single-table + select.""" + return { + Column( + table_refs=(TableReference(schema, table, alias, is_function),), + qualifiable=True, + ), + Function(schema=parent), + Keyword(last_keyword), + } + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT") + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT") + + +def test_cte_does_not_crash(): + sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;" + for i in range(len(sql)): + suggestions = suggest_type(sql[: i + 1], sql[: i + 1]) + + +@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE ']) +def test_where_suggests_columns_functions_quoted_table(expression): + expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE") + suggestions = suggest_type(expression, expression) + assert expected == set(suggestions) + + +@pytest.mark.parametrize( + "expression", + [ + "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ", + "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE foo = ", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "expression", + ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "], +) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "]) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = "SELECT * FROM tabl WHERE foo = ANY(" + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_lparen_suggests_cols_and_funcs(): + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert set(suggestion) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("("), + } + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type("SELECT ", "SELECT ") + assert set(suggestions) == { + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "] +) +def test_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM "]) +def test_suggest_tables_views_schemas_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ", + "SELECT * FROM foo JOIN bar USING (barid) JOIN ", + ], +) +def test_suggest_after_join_with_two_tables(expression): + suggestions = suggest_type(expression, expression) + tables = tuple([(None, "foo", None, False), (None, "bar", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(tables, None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"] +) +def test_suggest_after_join_with_one_table(expression): + suggestions = suggest_type(expression, expression) + tables = ((None, "foo", None, False),) + assert set(suggestions) == { + FromClauseItem(schema=None, table_refs=tables), + Join(((None, "foo", None, False),), None), + Schema(), + } + + +@pytest.mark.parametrize( + "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."] +) +def test_suggest_qualified_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize("expression", ["UPDATE sch."]) +def test_suggest_qualified_aliasable_tables_and_views(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {Table(schema="sch"), View(schema="sch")} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + 'SELECT * FROM sch."', + 'SELECT * FROM sch."foo', + 'SELECT * FROM "sch".', + 'SELECT * FROM "sch"."', + ], +) +def test_suggest_qualified_tables_views_and_functions(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == {FromClauseItem(schema="sch")} + + +@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."]) +def test_suggest_qualified_tables_views_functions_and_joins(expression): + suggestions = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestions) == { + FromClauseItem(schema="sch", table_refs=tbls), + Join(tbls, "sch"), + } + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert set(suggestions) == {Table(schema=None), Schema()} + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert set(suggestions) == {Table(schema="sch")} + + +@pytest.mark.parametrize( + "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "] +) +def test_distinct_suggests_cols(text): + suggestions = suggest_type(text, text) + assert set(suggestions) == { + Column(table_refs=(), local_tables=(), qualifiable=True), + Function(schema=None), + Keyword("DISTINCT"), + } + + +@pytest.mark.parametrize( + "text, text_before, last_keyword", + [ + ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ", + "ORDER BY", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_aliases( + text, text_before, last_keyword +): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=( + TableReference(None, "tbl", "x", False), + TableReference(None, "tbl1", "y", False), + ), + local_tables=(), + qualifiable=True, + ), + Function(schema=None), + Keyword(last_keyword), + } + + +@pytest.mark.parametrize( + "text, text_before", + [ + ("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."), + ( + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.", + ), + ], +) +def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before): + suggestions = suggest_type(text, text_before) + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_function_arguments_with_alias_given(): + suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.") + + assert set(suggestions) == { + Column( + table_refs=(TableReference(None, "tbl", "x", False),), + local_tables=(), + qualifiable=False, + ), + Table(schema="x"), + View(schema="x"), + Function(schema="x"), + } + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert set(suggestions) == { + Column(table_refs=((None, "tbl", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()} + + +@pytest.mark.parametrize( + "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"] +) +def test_insert_into_lparen_suggests_cols(text): + suggestions = suggest_type(text, "INSERT INTO abc (") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == ( + Column(table_refs=((None, "abc", None, False),), context="insert"), + ) + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type( + "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n" + ) + assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE") + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert set(suggestions) == { + Column(table_refs=((None, "tabl", None, False),)), + Table(schema="tabl"), + View(schema="tabl"), + Function(schema="tabl"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT t1. FROM tabl1 t1", + "SELECT t1. FROM tabl1 t1, tabl2 t2", + 'SELECT t1. FROM "tabl1" t1', + 'SELECT t1. FROM "tabl1" t1, "tabl2" t2', + ], +) +def test_dot_suggests_cols_of_an_alias(sql): + suggestions = suggest_type(sql, "SELECT t1.") + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM tabl1 t1 WHERE t1.", + "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.", + 'SELECT * FROM "tabl1" t1 WHERE t1.', + 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.', + ], +) +def test_dot_suggests_cols_of_an_alias_where(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == { + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + } + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type( + "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl2", "t2", False),)), + Table(schema="t2"), + View(schema="t2"), + Function(schema="t2"), + } + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + ], +) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == (Keyword(),) + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." + suggestions = suggest_type(q, q) + assert set(suggestions) == { + Column(table_refs=((None, "foo", "f", False),)), + Table(schema="f"), + View(schema="f"), + Function(schema="f"), + } + + +@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "]) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert set(suggestion) == {FromClauseItem(schema=None), Schema()} + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) +def test_sub_select_table_name_completion_with_outer_table(expression): + suggestion = suggest_type(expression, expression) + tbls = tuple([(None, "foo", None, False)]) + assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT " + ) + assert set(suggestions) == { + Column(table_refs=((None, "abc", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, " + ) + assert set(suggestions) == cols_etc("abc") + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type( + "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t." + ) + assert set(suggestions) == { + Column(table_refs=((None, "tabl", "t", False),)), + Table(schema="t"), + View(schema="t"), + Function(schema="t"), + } + + +@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER")) +@pytest.mark.parametrize("tbl_alias", ("", "foo")) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " + suggestion = suggest_type(text, text) + tbls = tuple([(None, "abc", tbl_alias or None, False)]) + assert set(suggestion) == { + FromClauseItem(schema=None, table_refs=tbls), + Schema(), + Join(tbls, None), + } + + +def test_left_join_with_comma(): + text = "select * from foo f left join bar b," + suggestions = suggest_type(text, text) + # tbls should also include (None, 'bar', 'b', False) + # but there's a bug with commas + tbls = tuple([(None, "foo", "f", False)]) + assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()} + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "def", "d", False)) + assert set(suggestions) == { + Column(table_refs=((None, "abc", "a", False),)), + Table(schema="a"), + View(schema="a"), + Function(schema="a"), + JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)), + } + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) +def test_join_alias_dot_suggests_cols2(sql): + suggestion = suggest_type(sql, sql) + assert set(suggestion) == { + Column(table_refs=((None, "def", "d", False),)), + Table(schema="d"), + View(schema="d"), + Function(schema="d"), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + """select a.x, b.y +from abc a +join bcd b on +""", + """select a.x, b.y +from abc a +join bcd b +on """, + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + ], +) +def test_on_suggests_aliases_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", "a", False), (None, "bcd", "b", False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("a", "b")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == (Alias(aliases=("a", "b")),) + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + "select abc.x, bcd.y from abc join bcd on ", + ], +) +def test_on_suggests_tables_and_join_conditions_right_side(sql): + suggestions = suggest_type(sql, sql) + tables = ((None, "abc", None, False), (None, "bcd", None, False)) + assert set(suggestions) == { + JoinCondition(table_refs=tables, parent=None), + Alias(aliases=("abc", "bcd")), + } + + +@pytest.mark.parametrize( + "text", + ( + "select * from abc inner join def using (", + "select * from abc inner join def using (col1, ", + "insert into hij select * from abc inner join def using (", + """insert into hij(x, y, z) + select * from abc inner join def using (col1, """, + """insert into hij (a,b,c) + select * from abc inner join def using (col1, """, + ), +) +def test_join_using_suggests_common_columns(text): + tables = ((None, "abc", None, False), (None, "def", None, False)) + assert set(suggest_type(text, text)) == { + Column(table_refs=tables, require_last_table=True) + } + + +def test_suggest_columns_after_multiple_joins(): + sql = """select * from t1 + inner join t2 ON + t1.id = t2.t1_id + inner join t3 ON + t2.id = t3.""" + suggestions = suggest_type(sql, sql) + assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions) + + +def test_2_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ", "select * from a; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b", "select * from a; select " + ) + assert set(suggestions) == { + Column(table_refs=((None, "b", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + # Should work even if first statement is invalid + suggestions = suggest_type( + "select * from; select * from ", "select * from; select * from " + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + +def test_2_statements_1st_current(): + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type("select from a; select * from b", "select ") + assert set(suggestions) == cols_etc("a", last_keyword="SELECT") + + +def test_3_statements_2nd_current(): + suggestions = suggest_type( + "select * from a; select * from ; select * from c", + "select * from a; select * from ", + ) + assert set(suggestions) == {FromClauseItem(schema=None), Schema()} + + suggestions = suggest_type( + "select * from a; select from b; select * from c", "select * from a; select " + ) + assert set(suggestions) == cols_etc("b", last_keyword="SELECT") + + +@pytest.mark.parametrize( + "text", + [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT 3 FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 2 FROM bar; +SELECT FROM foo; +$func$ + """, + """ +SELECT * FROM baz; +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$ +SELECT FROM foo; +SELECT 2 FROM bar; +$$ language sql; +create function func2(int, varchar) +RETURNS text +language sql AS +$func$ +SELECT 3 FROM bar; +SELECT FROM foo; +$func$ +SELECT * FROM qux; + """, + ], +) +def test_statements_in_function_body(text): + suggestions = suggest_type(text, text[: text.find(" ") + 1]) + assert set(suggestions) == { + Column(table_refs=((None, "foo", None, False),), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + } + + +functions = [ + """ +CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$ +SELECT 1 FROM foo; +SELECT 2 FROM bar; +$$ language sql; + """, + """ +create function func2(int, varchar) +RETURNS text +language sql AS +' +SELECT 2 FROM bar; +SELECT 1 FROM foo; +'; + """, +] + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_after_function_body(text): + suggestions = suggest_type(text, text[: text.find("; ") + 1]) + assert set(suggestions) == {Keyword(), Special()} + + +@pytest.mark.parametrize("text", functions) +def test_statements_with_cursor_before_function_body(text): + suggestions = suggest_type(text, "") + assert set(suggestions) == {Keyword(), Special()} + + +def test_create_db_with_template(): + suggestions = suggest_type( + "create database foo with template ", "create database foo with template " + ) + + assert set(suggestions) == {Database()} + + +@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n")) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert set(suggestions) == {Keyword(), Special()} + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = "DROP TABLE schema_name.table_name" + suggestions = suggest_type(text, text) + assert suggestions == (Table(schema="schema_name"),) + + +@pytest.mark.parametrize("text", (",", " ,", "sel ,")) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_drop_schema_suggests_schemas(): + sql = "DROP SCHEMA " + assert suggest_type(sql, sql) == (Schema(),) + + +@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"]) +def test_cast_operator_suggests_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."] +) +def test_cast_operator_suggests_schema_qualified_types(text): + assert set(suggest_type(text, text)) == { + Datatype(schema="bar"), + Table(schema="bar"), + } + + +def test_alter_column_type_suggests_types(): + q = "ALTER TABLE foo ALTER COLUMN bar TYPE " + assert set(suggest_type(q, q)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "CREATE TABLE foo (bar ", + "CREATE TABLE foo (bar DOU", + "CREATE TABLE foo (bar INT, baz ", + "CREATE TABLE foo (bar INT, baz TEXT, qux ", + "CREATE FUNCTION foo (bar ", + "CREATE FUNCTION foo (bar INT, baz ", + "SELECT * FROM foo() AS bar (baz ", + "SELECT * FROM foo() AS bar (baz INT, qux ", + # make sure this doesn't trigger special completion + "CREATE TABLE foo (dt d", + ], +) +def test_identifier_suggests_types_in_parentheses(text): + assert set(suggest_type(text, text)) == { + Datatype(schema=None), + Table(schema=None), + Schema(), + } + + +@pytest.mark.parametrize( + "text", + [ + "SELECT foo ", + "SELECT foo FROM bar ", + "SELECT foo AS bar ", + "SELECT foo bar ", + "SELECT * FROM foo AS bar ", + "SELECT * FROM foo bar ", + "SELECT foo FROM (SELECT bar ", + ], +) +def test_alias_suggests_keywords(text): + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +def test_invalid_sql(): + # issue 317 + text = "selt *" + suggestions = suggest_type(text, text) + assert suggestions == (Keyword(),) + + +@pytest.mark.parametrize( + "text", + ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "], +) +def test_suggest_where_keyword(text): + # https://github.com/dbcli/mycli/issues/135 + suggestions = suggest_type(text, text) + assert set(suggestions) == cols_etc("foo", last_keyword="WHERE") + + +@pytest.mark.parametrize( + "text, before, expected", + [ + ( + "\\ns abc SELECT ", + "SELECT ", + [ + Column(table_refs=(), qualifiable=True), + Function(schema=None), + Keyword("SELECT"), + ], + ), + ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)), + ( + "\\ns abc SELECT t1. FROM tabl1 t1", + "SELECT t1.", + [ + Table(schema="t1"), + View(schema="t1"), + Column(table_refs=((None, "tabl1", "t1", False),)), + Function(schema="t1"), + ], + ), + ], +) +def test_named_query_completion(text, before, expected): + suggestions = suggest_type(text, before) + assert set(expected) == set(suggestions) + + +def test_select_suggests_fields_from_function(): + suggestions = suggest_type("SELECT FROM func()", "SELECT ") + assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT") + + +@pytest.mark.parametrize("sql", ["("]) +def test_leading_parenthesis(sql): + # No assertion for now; just make sure it doesn't crash + suggest_type(sql, sql) + + +@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo']) +def test_ignore_leading_double_quotes(sql): + suggestions = suggest_type(sql, sql) + assert FromClauseItem(schema=None) in set(suggestions) + + +@pytest.mark.parametrize( + "sql", + [ + "ALTER TABLE foo ALTER COLUMN ", + "ALTER TABLE foo ALTER COLUMN bar", + "ALTER TABLE foo DROP COLUMN ", + "ALTER TABLE foo DROP COLUMN bar", + ], +) +def test_column_keyword_suggests_columns(sql): + suggestions = suggest_type(sql, sql) + assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))} + + +def test_handle_unrecognized_kw_generously(): + sql = "SELECT * FROM sessions WHERE session = 1 AND " + suggestions = suggest_type(sql, sql) + expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True) + + assert expected in set(suggestions) + + +@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "]) +def test_keyword_after_alter(sql): + assert Keyword("ALTER") in set(suggest_type(sql, sql)) diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py new file mode 100644 index 0000000..ae865f4 --- /dev/null +++ b/tests/test_ssh_tunnel.py @@ -0,0 +1,188 @@ +import os +from unittest.mock import patch, MagicMock, ANY + +import pytest +from configobj import ConfigObj +from click.testing import CliRunner +from sshtunnel import SSHTunnelForwarder + +from pgcli.main import cli, PGCli +from pgcli.pgexecute import PGExecute + + +@pytest.fixture +def mock_ssh_tunnel_forwarder() -> MagicMock: + mock_ssh_tunnel_forwarder = MagicMock( + SSHTunnelForwarder, local_bind_ports=[1111], autospec=True + ) + with patch( + "pgcli.main.sshtunnel.SSHTunnelForwarder", + return_value=mock_ssh_tunnel_forwarder, + ) as mock: + yield mock + + +@pytest.fixture +def mock_pgexecute() -> MagicMock: + with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute: + yield mock_pgexecute + + +def test_ssh_tunnel( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + # Test with just a host + tunnel_url = "some.host" + db_params = { + "database": "dbname", + "host": "db.host", + "user": "db_user", + "passwd": "db_passwd", + } + expected_tunnel_params = { + "local_bind_address": ("127.0.0.1",), + "remote_bind_address": (db_params["host"], 5432), + "ssh_address_or_host": (tunnel_url, 22), + "logger": ANY, + } + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with a full url and with a specific db port + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "some.other.host" + tunnel_port = 1022 + tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + db_params["port"] = 1234 + + expected_tunnel_params["remote_bind_address"] = ( + db_params["host"], + db_params["port"], + ) + expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port) + expected_tunnel_params["ssh_username"] = tunnel_user + expected_tunnel_params["ssh_password"] = tunnel_passwd + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(**db_params) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert call_args == ( + db_params["database"], + db_params["user"], + db_params["passwd"], + "127.0.0.1", + pgcli.ssh_tunnel.local_bind_ports[0], + "", + ) + mock_ssh_tunnel_forwarder.reset_mock() + mock_pgexecute.reset_mock() + + # Test with DSN + dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host={db_params['host']} port={db_params['port']}" + ) + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + expected_dsn = ( + f"user={db_params['user']} password={db_params['passwd']} " + f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" + ) + + mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + assert expected_dsn in call_args + + +def test_cli_with_tunnel() -> None: + runner = CliRunner() + tunnel_url = "mytunnel" + with patch.object( + PGCli, "__init__", autospec=True, return_value=None + ) as mock_pgcli: + runner.invoke(cli, ["--ssh-tunnel", tunnel_url]) + mock_pgcli.assert_called_once() + call_args, call_kwargs = mock_pgcli.call_args + assert call_kwargs["ssh_tunnel_url"] == tunnel_url + + +def test_config( + tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + pgclirc = str(tmpdir.join("rcfile")) + + tunnel_user = "tunnel_user" + tunnel_passwd = "tunnel_pass" + tunnel_host = "tunnel.host" + tunnel_port = 1022 + tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}" + + tunnel2_url = "tunnel2.host" + + config = ConfigObj() + config.filename = pgclirc + config["ssh tunnels"] = {} + config["ssh tunnels"][r"\.com$"] = tunnel_url + config["ssh tunnels"][r"^hello-"] = tunnel2_url + config.write() + + # Unmatched host + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="unmatched.host") + mock_ssh_tunnel_forwarder.assert_not_called() + + # Host matching first tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="matched.host.com") + mock_ssh_tunnel_forwarder.assert_called_once() + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching second tunnel + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22) + mock_ssh_tunnel_forwarder.reset_mock() + + # Host matching both tunnels (will use the first one matched) + pgcli = PGCli(pgclirc_file=pgclirc) + pgcli.connect(host="hello-i-am-matched.com") + mock_ssh_tunnel_forwarder.assert_called_once() + + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) + assert call_kwargs["ssh_username"] == tunnel_user + assert call_kwargs["ssh_password"] == tunnel_passwd diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..67d769f --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,92 @@ +import pytest +import psycopg +from pgcli.main import format_output, OutputSettings +from os import getenv + +POSTGRES_USER = getenv("PGUSER", "postgres") +POSTGRES_HOST = getenv("PGHOST", "localhost") +POSTGRES_PORT = getenv("PGPORT", 5432) +POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres") + + +def db_connection(dbname=None): + conn = psycopg.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + port=POSTGRES_PORT, + dbname=dbname, + ) + conn.autocommit = True + return conn + + +try: + conn = db_connection() + CAN_CONNECT_TO_DB = True + SERVER_VERSION = conn.info.parameter_status("server_version") + JSON_AVAILABLE = True + JSONB_AVAILABLE = True +except Exception as x: + CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False + SERVER_VERSION = 0 + + +dbtest = pytest.mark.skipif( + not CAN_CONNECT_TO_DB, + reason="Need a postgres instance at localhost accessible by user 'postgres'", +) + + +requires_json = pytest.mark.skipif( + not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined" +) + + +requires_jsonb = pytest.mark.skipif( + not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined" +) + + +def create_db(dbname): + with db_connection().cursor() as cur: + try: + cur.execute("""CREATE DATABASE _test_db""") + except: + pass + + +def drop_tables(conn): + with conn.cursor() as cur: + cur.execute( + """ + DROP SCHEMA public CASCADE; + CREATE SCHEMA public; + DROP SCHEMA IF EXISTS schema1 CASCADE; + DROP SCHEMA IF EXISTS schema2 CASCADE""" + ) + + +def run( + executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None +): + "Return string output for the sql to be run" + + results = executor.run(sql, pgspecial, exception_formatter) + formatted = [] + settings = OutputSettings( + table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded + ) + for title, rows, headers, status, sql, success, is_special in results: + formatted.extend(format_output(title, rows, headers, status, settings)) + if join: + formatted = "\n".join(formatted) + + return formatted + + +def completions_to_set(completions): + return { + (completion.display_text, completion.display_meta_text) + for completion in completions + } |