summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 03:06:41 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 03:06:41 +0000
commit708c091a8b4db6a55be1c96ae33ee0da632b269f (patch)
treeaac9e87c59cb8bc7e3cd429e9200c3ca017cb591 /tests
parentInitial commit. (diff)
downloadpgcli-708c091a8b4db6a55be1c96ae33ee0da632b269f.tar.xz
pgcli-708c091a8b4db6a55be1c96ae33ee0da632b269f.zip
Adding upstream version 4.0.1.upstream/4.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r--tests/conftest.py52
-rw-r--r--tests/features/__init__.py0
-rw-r--r--tests/features/auto_vertical.feature12
-rw-r--r--tests/features/basic_commands.feature81
-rw-r--r--tests/features/crud_database.feature17
-rw-r--r--tests/features/crud_table.feature45
-rw-r--r--tests/features/db_utils.py87
-rw-r--r--tests/features/environment.py227
-rw-r--r--tests/features/expanded.feature29
-rw-r--r--tests/features/fixture_data/help.txt25
-rw-r--r--tests/features/fixture_data/help_commands.txt64
-rw-r--r--tests/features/fixture_data/mock_pg_service.conf4
-rw-r--r--tests/features/fixture_utils.py28
-rw-r--r--tests/features/iocommands.feature17
-rw-r--r--tests/features/named_queries.feature10
-rw-r--r--tests/features/pgbouncer.feature12
-rw-r--r--tests/features/specials.feature6
-rw-r--r--tests/features/steps/__init__.py0
-rw-r--r--tests/features/steps/auto_vertical.py99
-rw-r--r--tests/features/steps/basic_commands.py231
-rw-r--r--tests/features/steps/crud_database.py93
-rw-r--r--tests/features/steps/crud_table.py185
-rw-r--r--tests/features/steps/expanded.py70
-rw-r--r--tests/features/steps/iocommands.py80
-rw-r--r--tests/features/steps/named_queries.py57
-rw-r--r--tests/features/steps/pgbouncer.py22
-rw-r--r--tests/features/steps/specials.py31
-rw-r--r--tests/features/steps/wrappers.py71
-rw-r--r--tests/features/wrappager.py16
-rw-r--r--tests/formatter/__init__.py1
-rw-r--r--tests/formatter/test_sqlformatter.py111
-rw-r--r--tests/metadata.py255
-rw-r--r--tests/parseutils/test_ctes.py137
-rw-r--r--tests/parseutils/test_function_metadata.py19
-rw-r--r--tests/parseutils/test_parseutils.py310
-rw-r--r--tests/pytest.ini2
-rw-r--r--tests/test_auth.py40
-rw-r--r--tests/test_completion_refresher.py95
-rw-r--r--tests/test_config.py43
-rw-r--r--tests/test_fuzzy_completion.py87
-rw-r--r--tests/test_main.py490
-rw-r--r--tests/test_naive_completion.py133
-rw-r--r--tests/test_pgcompleter.py76
-rw-r--r--tests/test_pgexecute.py773
-rw-r--r--tests/test_pgspecial.py78
-rw-r--r--tests/test_prioritization.py20
-rw-r--r--tests/test_prompt_utils.py17
-rw-r--r--tests/test_rowlimit.py79
-rw-r--r--tests/test_smart_completion_multiple_schemata.py757
-rw-r--r--tests/test_smart_completion_public_schema_only.py1112
-rw-r--r--tests/test_sqlcompletion.py964
-rw-r--r--tests/test_ssh_tunnel.py188
-rw-r--r--tests/utils.py92
53 files changed, 7550 insertions, 0 deletions
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..33cddf2
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,52 @@
+import os
+import pytest
+from utils import (
+ POSTGRES_HOST,
+ POSTGRES_PORT,
+ POSTGRES_USER,
+ POSTGRES_PASSWORD,
+ create_db,
+ db_connection,
+ drop_tables,
+)
+import pgcli.pgexecute
+
+
+@pytest.fixture(scope="function")
+def connection():
+ create_db("_test_db")
+ connection = db_connection("_test_db")
+ yield connection
+
+ drop_tables(connection)
+ connection.close()
+
+
+@pytest.fixture
+def cursor(connection):
+ with connection.cursor() as cur:
+ return cur
+
+
+@pytest.fixture
+def executor(connection):
+ return pgcli.pgexecute.PGExecute(
+ database="_test_db",
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ port=POSTGRES_PORT,
+ dsn=None,
+ )
+
+
+@pytest.fixture
+def exception_formatter():
+ return lambda e: str(e)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def temp_config(tmpdir_factory):
+ # this function runs on start of test session.
+ # use temporary directory for config home so user config will not be used
+ os.environ["XDG_CONFIG_HOME"] = str(tmpdir_factory.mktemp("data"))
diff --git a/tests/features/__init__.py b/tests/features/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/features/__init__.py
diff --git a/tests/features/auto_vertical.feature b/tests/features/auto_vertical.feature
new file mode 100644
index 0000000..aa95718
--- /dev/null
+++ b/tests/features/auto_vertical.feature
@@ -0,0 +1,12 @@
+Feature: auto_vertical mode:
+ on, off
+
+ Scenario: auto_vertical on with small query
+ When we run dbcli with --auto-vertical-output
+ and we execute a small query
+ then we see small results in horizontal format
+
+ Scenario: auto_vertical on with large query
+ When we run dbcli with --auto-vertical-output
+ and we execute a large query
+ then we see large results in vertical format
diff --git a/tests/features/basic_commands.feature b/tests/features/basic_commands.feature
new file mode 100644
index 0000000..ee497b9
--- /dev/null
+++ b/tests/features/basic_commands.feature
@@ -0,0 +1,81 @@
+Feature: run the cli,
+ call the help command,
+ exit the cli
+
+ Scenario: run "\?" command
+ When we send "\?" command
+ then we see help output
+
+ Scenario: run source command
+ When we send source command
+ then we see help output
+
+ Scenario: run partial select command
+ When we send partial select command
+ then we see error message
+ then we see dbcli prompt
+
+ Scenario: check our application_name
+ When we run query to check application_name
+ then we see found
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
+
+ Scenario: confirm exit when a transaction is ongoing
+ When we begin transaction
+ and we try to send "ctrl + d"
+ then we see ongoing transaction message
+ when we send "c"
+ then dbcli exits
+
+ Scenario: cancel exit when a transaction is ongoing
+ When we begin transaction
+ and we try to send "ctrl + d"
+ then we see ongoing transaction message
+ when we send "a"
+ then we see dbcli prompt
+ when we rollback transaction
+ when we send "ctrl + d"
+ then dbcli exits
+
+ Scenario: interrupt current query via "ctrl + c"
+ When we send sleep query
+ and we send "ctrl + c"
+ then we see cancelled query warning
+ when we check for any non-idle sleep queries
+ then we don't see any non-idle sleep queries
+
+ Scenario: list databases
+ When we list databases
+ then we see list of databases
+
+ Scenario: run the cli with --username
+ When we launch dbcli using --username
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --user
+ When we launch dbcli using --user
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --port
+ When we launch dbcli using --port
+ and we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with --password
+ When we launch dbcli using --password
+ then we send password
+ and we see dbcli prompt
+ when we send "\?" command
+ then we see help output
+
+ Scenario: run the cli with dsn and password
+ When we launch dbcli using dsn_password
+ then we send password
+ and we see dbcli prompt
+ when we send "\?" command
+ then we see help output
diff --git a/tests/features/crud_database.feature b/tests/features/crud_database.feature
new file mode 100644
index 0000000..87da4e3
--- /dev/null
+++ b/tests/features/crud_database.feature
@@ -0,0 +1,17 @@
+Feature: manipulate databases:
+ create, drop, connect, disconnect
+
+ Scenario: create and drop temporary database
+ When we create database
+ then we see database created
+ when we drop database
+ then we respond to the destructive warning: y
+ then we see database dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: connect and disconnect from test database
+ When we connect to test database
+ then we see database connected
+ when we connect to dbserver
+ then we see database connected
diff --git a/tests/features/crud_table.feature b/tests/features/crud_table.feature
new file mode 100644
index 0000000..8a43c5c
--- /dev/null
+++ b/tests/features/crud_table.feature
@@ -0,0 +1,45 @@
+Feature: manipulate tables:
+ create, insert, update, select, delete from, drop
+
+ Scenario: create, insert, select from, update, drop table
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we insert into table
+ then we see record inserted
+ when we select from table
+ then we see data selected: initial
+ when we update table
+ then we see record updated
+ when we select from table
+ then we see data selected: updated
+ when we delete from table
+ then we respond to the destructive warning: y
+ then we see record deleted
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+ when we connect to dbserver
+ then we see database connected
+
+ Scenario: transaction handling, with cancelling on a destructive warning.
+ When we connect to test database
+ then we see database connected
+ when we create table
+ then we see table created
+ when we begin transaction
+ then we see transaction began
+ when we insert into table
+ then we see record inserted
+ when we delete from table
+ then we respond to the destructive warning: n
+ when we select from table
+ then we see data selected: initial
+ when we rollback transaction
+ then we see transaction rolled back
+ when we select from table
+ then we see select output without data
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py
new file mode 100644
index 0000000..595c6c2
--- /dev/null
+++ b/tests/features/db_utils.py
@@ -0,0 +1,87 @@
+from psycopg import connect
+
+
+def create_db(
+ hostname="localhost", username=None, password=None, dbname=None, port=None
+):
+ """Create test database.
+
+ :param hostname: string
+ :param username: string
+ :param password: string
+ :param dbname: string
+ :param port: int
+ :return:
+
+ """
+ cn = create_cn(hostname, password, username, "postgres", port)
+
+ cn.autocommit = True
+ with cn.cursor() as cr:
+ cr.execute(f"drop database if exists {dbname}")
+ cr.execute(f"create database {dbname}")
+
+ cn.close()
+
+ cn = create_cn(hostname, password, username, dbname, port)
+ return cn
+
+
+def create_cn(hostname, password, username, dbname, port):
+ """
+ Open connection to database.
+ :param hostname:
+ :param password:
+ :param username:
+ :param dbname: string
+ :return: psycopg2.connection
+ """
+ cn = connect(
+ host=hostname, user=username, dbname=dbname, password=password, port=port
+ )
+
+ print(f"Created connection: {cn.info.get_parameters()}.")
+ return cn
+
+
+def pgbouncer_available(hostname="localhost", password=None, username="postgres"):
+ cn = None
+ try:
+ cn = create_cn(hostname, password, username, "pgbouncer", 6432)
+ return True
+ except:
+ print("Pgbouncer is not available.")
+ finally:
+ if cn:
+ cn.close()
+ return False
+
+
+def drop_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
+ """
+ Drop database.
+ :param hostname: string
+ :param username: string
+ :param password: string
+ :param dbname: string
+ """
+ cn = create_cn(hostname, password, username, "postgres", port)
+
+ # Needed for DB drop.
+ cn.autocommit = True
+
+ with cn.cursor() as cr:
+ cr.execute(f"drop database if exists {dbname}")
+
+ close_cn(cn)
+
+
+def close_cn(cn=None):
+ """
+ Close connection.
+ :param connection: psycopg2.connection
+ """
+ if cn:
+ cn_params = cn.info.get_parameters()
+ cn.close()
+ print(f"Closed connection: {cn_params}.")
diff --git a/tests/features/environment.py b/tests/features/environment.py
new file mode 100644
index 0000000..50ac5fa
--- /dev/null
+++ b/tests/features/environment.py
@@ -0,0 +1,227 @@
+import copy
+import os
+import sys
+import db_utils as dbutils
+import fixture_utils as fixutils
+import pexpect
+import tempfile
+import shutil
+import signal
+
+
+from steps import wrappers
+
+
+def before_all(context):
+ """Set env parameters."""
+ env_old = copy.deepcopy(dict(os.environ))
+ os.environ["LINES"] = "100"
+ os.environ["COLUMNS"] = "100"
+ os.environ["PAGER"] = "cat"
+ os.environ["EDITOR"] = "ex"
+ os.environ["VISUAL"] = "ex"
+ os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
+
+ context.package_root = os.path.abspath(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+ )
+ fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
+
+ print("package root:", context.package_root)
+ print("fixture dir:", fixture_dir)
+
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(
+ context.package_root, ".coveragerc"
+ )
+
+ context.exit_sent = False
+
+ vi = "_".join([str(x) for x in sys.version_info[:3]])
+ db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
+ db_name_full = f"{db_name}_{vi}"
+
+ # Store get params from config.
+ context.conf = {
+ "host": context.config.userdata.get(
+ "pg_test_host", os.getenv("PGHOST", "localhost")
+ ),
+ "user": context.config.userdata.get(
+ "pg_test_user", os.getenv("PGUSER", "postgres")
+ ),
+ "pass": context.config.userdata.get(
+ "pg_test_pass", os.getenv("PGPASSWORD", None)
+ ),
+ "port": context.config.userdata.get(
+ "pg_test_port", os.getenv("PGPORT", "5432")
+ ),
+ "cli_command": (
+ context.config.userdata.get("pg_cli_command", None)
+ or '{python} -c "{startup}"'.format(
+ python=sys.executable,
+ startup="; ".join(
+ [
+ "import coverage",
+ "coverage.process_startup()",
+ "import pgcli.main",
+ "pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
+ ]
+ ),
+ )
+ ),
+ "dbname": db_name_full,
+ "dbname_tmp": db_name_full + "_tmp",
+ "vi": vi,
+ "pager_boundary": "---boundary---",
+ }
+ os.environ["PAGER"] = "{0} {1} {2}".format(
+ sys.executable,
+ os.path.join(context.package_root, "tests/features/wrappager.py"),
+ context.conf["pager_boundary"],
+ )
+
+ # Store old env vars.
+ context.pgenv = {
+ "PGDATABASE": os.environ.get("PGDATABASE", None),
+ "PGUSER": os.environ.get("PGUSER", None),
+ "PGHOST": os.environ.get("PGHOST", None),
+ "PGPASSWORD": os.environ.get("PGPASSWORD", None),
+ "PGPORT": os.environ.get("PGPORT", None),
+ "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None),
+ "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None),
+ }
+
+ # Set new env vars.
+ os.environ["PGDATABASE"] = context.conf["dbname"]
+ os.environ["PGUSER"] = context.conf["user"]
+ os.environ["PGHOST"] = context.conf["host"]
+ os.environ["PGPORT"] = context.conf["port"]
+ os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf")
+
+ if context.conf["pass"]:
+ os.environ["PGPASSWORD"] = context.conf["pass"]
+ else:
+ if "PGPASSWORD" in os.environ:
+ del os.environ["PGPASSWORD"]
+ os.environ["BEHAVE_WARN"] = "moderate"
+
+ context.cn = dbutils.create_db(
+ context.conf["host"],
+ context.conf["user"],
+ context.conf["pass"],
+ context.conf["dbname"],
+ context.conf["port"],
+ )
+ context.pgbouncer_available = dbutils.pgbouncer_available(
+ hostname=context.conf["host"],
+ password=context.conf["pass"],
+ username=context.conf["user"],
+ )
+ context.fixture_data = fixutils.read_fixture_files()
+
+ # use temporary directory as config home
+ context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_")
+ os.environ["XDG_CONFIG_HOME"] = context.env_config_home
+ show_env_changes(env_old, dict(os.environ))
+
+
+def show_env_changes(env_old, env_new):
+ """Print out all test-specific env values."""
+ print("--- os.environ changed values: ---")
+ all_keys = env_old.keys() | env_new.keys()
+ for k in sorted(all_keys):
+ old_value = env_old.get(k, "")
+ new_value = env_new.get(k, "")
+ if new_value and old_value != new_value:
+ print(f'{k}="{new_value}"')
+ print("-" * 20)
+
+
+def after_all(context):
+ """
+ Unset env parameters.
+ """
+ dbutils.close_cn(context.cn)
+ dbutils.drop_db(
+ context.conf["host"],
+ context.conf["user"],
+ context.conf["pass"],
+ context.conf["dbname"],
+ context.conf["port"],
+ )
+
+ # Remove temp config directory
+ shutil.rmtree(context.env_config_home)
+
+ # Restore env vars.
+ for k, v in context.pgenv.items():
+ if k in os.environ and v is None:
+ del os.environ[k]
+ elif v:
+ os.environ[k] = v
+
+
+def before_step(context, _):
+ context.atprompt = False
+
+
+def is_known_problem(scenario):
+ """TODO: why is this not working in 3.12?"""
+ if sys.version_info >= (3, 12):
+ return scenario.name in (
+ 'interrupt current query via "ctrl + c"',
+ "run the cli with --username",
+ "run the cli with --user",
+ "run the cli with --port",
+ )
+ return False
+
+
+def before_scenario(context, scenario):
+ if scenario.name == "list databases":
+ # not using the cli for that
+ return
+ if is_known_problem(scenario):
+ scenario.skip()
+ currentdb = None
+ if "pgbouncer" in scenario.feature.tags:
+ if context.pgbouncer_available:
+ os.environ["PGDATABASE"] = "pgbouncer"
+ os.environ["PGPORT"] = "6432"
+ currentdb = "pgbouncer"
+ else:
+ scenario.skip()
+ else:
+ # set env vars back to normal test database
+ os.environ["PGDATABASE"] = context.conf["dbname"]
+ os.environ["PGPORT"] = context.conf["port"]
+ wrappers.run_cli(context, currentdb=currentdb)
+ wrappers.wait_prompt(context)
+
+
+def after_scenario(context, scenario):
+ """Cleans up after each scenario completes."""
+ if hasattr(context, "cli") and context.cli and not context.exit_sent:
+ # Quit nicely.
+ if not getattr(context, "atprompt", False):
+ dbname = context.currentdb
+ context.cli.expect_exact(f"{dbname}>", timeout=5)
+ try:
+ context.cli.sendcontrol("c")
+ context.cli.sendcontrol("d")
+ except Exception as x:
+ print("Failed cleanup after scenario:")
+ print(x)
+ try:
+ context.cli.expect_exact(pexpect.EOF, timeout=5)
+ except pexpect.TIMEOUT:
+ print(f"--- after_scenario {scenario.name}: kill cli")
+ context.cli.kill(signal.SIGKILL)
+ if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
+ context.tmpfile_sql_help.close()
+ context.tmpfile_sql_help = None
+
+
+# # TODO: uncomment to debug a failure
+# def after_step(context, step):
+# if step.status == "failed":
+# import pdb; pdb.set_trace()
diff --git a/tests/features/expanded.feature b/tests/features/expanded.feature
new file mode 100644
index 0000000..e486048
--- /dev/null
+++ b/tests/features/expanded.feature
@@ -0,0 +1,29 @@
+Feature: expanded mode:
+ on, off, auto
+
+ Scenario: expanded on
+ When we prepare the test data
+ and we set expanded on
+ and we select from table
+ then we see expanded data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+
+ Scenario: expanded off
+ When we prepare the test data
+ and we set expanded off
+ and we select from table
+ then we see nonexpanded data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
+
+ Scenario: expanded auto
+ When we prepare the test data
+ and we set expanded auto
+ and we select from table
+ then we see auto data selected
+ when we drop table
+ then we respond to the destructive warning: y
+ then we see table dropped
diff --git a/tests/features/fixture_data/help.txt b/tests/features/fixture_data/help.txt
new file mode 100644
index 0000000..bebb976
--- /dev/null
+++ b/tests/features/fixture_data/help.txt
@@ -0,0 +1,25 @@
++--------------------------+------------------------------------------------+
+| Command | Description |
+|--------------------------+------------------------------------------------|
+| \# | Refresh auto-completions. |
+| \? | Show Help. |
+| \T [format] | Change the table format used to output results |
+| \c[onnect] database_name | Change to a new database. |
+| \d [pattern] | List or describe tables, views and sequences. |
+| \dT[S+] [pattern] | List data types |
+| \df[+] [pattern] | List functions. |
+| \di[+] [pattern] | List indexes. |
+| \dn[+] [pattern] | List schemas. |
+| \ds[+] [pattern] | List sequences. |
+| \dt[+] [pattern] | List tables. |
+| \du[+] [pattern] | List roles. |
+| \dv[+] [pattern] | List views. |
+| \e [file] | Edit the query with external editor. |
+| \l | List databases. |
+| \n[+] [name] | List or execute named queries. |
+| \nd [name [query]] | Delete a named query. |
+| \ns name query | Save a named query. |
+| \refresh | Refresh auto-completions. |
+| \timing | Toggle timing of commands. |
+| \x | Toggle expanded output. |
++--------------------------+------------------------------------------------+
diff --git a/tests/features/fixture_data/help_commands.txt b/tests/features/fixture_data/help_commands.txt
new file mode 100644
index 0000000..e076661
--- /dev/null
+++ b/tests/features/fixture_data/help_commands.txt
@@ -0,0 +1,64 @@
+Command
+Description
+\#
+Refresh auto-completions.
+\?
+Show Commands.
+\T [format]
+Change the table format used to output results
+\c[onnect] database_name
+Change to a new database.
+\copy [tablename] to/from [filename]
+Copy data between a file and a table.
+\d[+] [pattern]
+List or describe tables, views and sequences.
+\dT[S+] [pattern]
+List data types
+\db[+] [pattern]
+List tablespaces.
+\df[+] [pattern]
+List functions.
+\di[+] [pattern]
+List indexes.
+\dm[+] [pattern]
+List materialized views.
+\dn[+] [pattern]
+List schemas.
+\ds[+] [pattern]
+List sequences.
+\dt[+] [pattern]
+List tables.
+\du[+] [pattern]
+List roles.
+\dv[+] [pattern]
+List views.
+\dx[+] [pattern]
+List extensions.
+\e [file]
+Edit the query with external editor.
+\h
+Show SQL syntax and help.
+\i filename
+Execute commands from file.
+\l
+List databases.
+\n[+] [name] [param1 param2 ...]
+List or execute named queries.
+\nd [name]
+Delete a named query.
+\ns name query
+Save a named query.
+\o [filename]
+Send all query results to file.
+\pager [command]
+Set PAGER. Print the query results via PAGER.
+\pset [key] [value]
+A limited version of traditional \pset
+\refresh
+Refresh auto-completions.
+\sf[+] FUNCNAME
+Show a function's definition.
+\timing
+Toggle timing of commands.
+\x
+Toggle expanded output.
diff --git a/tests/features/fixture_data/mock_pg_service.conf b/tests/features/fixture_data/mock_pg_service.conf
new file mode 100644
index 0000000..15f9811
--- /dev/null
+++ b/tests/features/fixture_data/mock_pg_service.conf
@@ -0,0 +1,4 @@
+[mock_postgres]
+dbname=postgres
+host=localhost
+user=postgres
diff --git a/tests/features/fixture_utils.py b/tests/features/fixture_utils.py
new file mode 100644
index 0000000..70b603d
--- /dev/null
+++ b/tests/features/fixture_utils.py
@@ -0,0 +1,28 @@
+import os
+import codecs
+
+
+def read_fixture_lines(filename):
+ """
+ Read lines of text from file.
+ :param filename: string name
+ :return: list of strings
+ """
+ lines = []
+ for line in codecs.open(filename, "rb", encoding="utf-8"):
+ lines.append(line.strip())
+ return lines
+
+
+def read_fixture_files():
+ """Read all files inside fixture_data directory."""
+ current_dir = os.path.dirname(__file__)
+ fixture_dir = os.path.join(current_dir, "fixture_data/")
+ print(f"reading fixture data: {fixture_dir}")
+ fixture_dict = {}
+ for filename in os.listdir(fixture_dir):
+ if filename not in [".", ".."]:
+ fullname = os.path.join(fixture_dir, filename)
+ fixture_dict[filename] = read_fixture_lines(fullname)
+
+ return fixture_dict
diff --git a/tests/features/iocommands.feature b/tests/features/iocommands.feature
new file mode 100644
index 0000000..dad7d10
--- /dev/null
+++ b/tests/features/iocommands.feature
@@ -0,0 +1,17 @@
+Feature: I/O commands
+
+ Scenario: edit sql in file with external editor
+ When we start external editor providing a file name
+ and we type sql in the editor
+ and we exit the editor
+ then we see dbcli prompt
+ and we see the sql in prompt
+
+ Scenario: tee output from query
+ When we tee output
+ and we wait for prompt
+ and we query "select 123456"
+ and we wait for prompt
+ and we stop teeing output
+ and we wait for prompt
+ then we see 123456 in tee output
diff --git a/tests/features/named_queries.feature b/tests/features/named_queries.feature
new file mode 100644
index 0000000..74201b9
--- /dev/null
+++ b/tests/features/named_queries.feature
@@ -0,0 +1,10 @@
+Feature: named queries:
+ save, use and delete named queries
+
+ Scenario: save, use and delete named queries
+ When we connect to test database
+ then we see database connected
+ when we save a named query
+ then we see the named query saved
+ when we delete a named query
+ then we see the named query deleted
diff --git a/tests/features/pgbouncer.feature b/tests/features/pgbouncer.feature
new file mode 100644
index 0000000..14cc5ad
--- /dev/null
+++ b/tests/features/pgbouncer.feature
@@ -0,0 +1,12 @@
+@pgbouncer
+Feature: run pgbouncer,
+ call the help command,
+ exit the cli
+
+ Scenario: run "show help" command
+ When we send "show help" command
+ then we see the pgbouncer help output
+
+ Scenario: run the cli and exit
+ When we send "ctrl + d"
+ then dbcli exits
diff --git a/tests/features/specials.feature b/tests/features/specials.feature
new file mode 100644
index 0000000..63c5cdc
--- /dev/null
+++ b/tests/features/specials.feature
@@ -0,0 +1,6 @@
+Feature: Special commands
+
+ Scenario: run refresh command
+ When we refresh completions
+ and we wait for prompt
+ then we see completions refresh started
diff --git a/tests/features/steps/__init__.py b/tests/features/steps/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/features/steps/__init__.py
diff --git a/tests/features/steps/auto_vertical.py b/tests/features/steps/auto_vertical.py
new file mode 100644
index 0000000..d7cdccd
--- /dev/null
+++ b/tests/features/steps/auto_vertical.py
@@ -0,0 +1,99 @@
+from textwrap import dedent
+from behave import then, when
+import wrappers
+
+
+@when("we run dbcli with {arg}")
+def step_run_cli_with_arg(context, arg):
+ wrappers.run_cli(context, run_args=arg.split("="))
+
+
+@when("we execute a small query")
+def step_execute_small_query(context):
+ context.cli.sendline("select 1")
+
+
+@when("we execute a large query")
+def step_execute_large_query(context):
+ context.cli.sendline("select {}".format(",".join([str(n) for n in range(1, 50)])))
+
+
+@then("we see small results in horizontal format")
+def step_see_small_results(context):
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +----------+\r
+ | ?column? |\r
+ |----------|\r
+ | 1 |\r
+ +----------+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=5,
+ )
+
+
+@then("we see large results in vertical format")
+def step_see_large_results(context):
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ -[ RECORD 1 ]-------------------------\r
+ ?column? | 1\r
+ ?column? | 2\r
+ ?column? | 3\r
+ ?column? | 4\r
+ ?column? | 5\r
+ ?column? | 6\r
+ ?column? | 7\r
+ ?column? | 8\r
+ ?column? | 9\r
+ ?column? | 10\r
+ ?column? | 11\r
+ ?column? | 12\r
+ ?column? | 13\r
+ ?column? | 14\r
+ ?column? | 15\r
+ ?column? | 16\r
+ ?column? | 17\r
+ ?column? | 18\r
+ ?column? | 19\r
+ ?column? | 20\r
+ ?column? | 21\r
+ ?column? | 22\r
+ ?column? | 23\r
+ ?column? | 24\r
+ ?column? | 25\r
+ ?column? | 26\r
+ ?column? | 27\r
+ ?column? | 28\r
+ ?column? | 29\r
+ ?column? | 30\r
+ ?column? | 31\r
+ ?column? | 32\r
+ ?column? | 33\r
+ ?column? | 34\r
+ ?column? | 35\r
+ ?column? | 36\r
+ ?column? | 37\r
+ ?column? | 38\r
+ ?column? | 39\r
+ ?column? | 40\r
+ ?column? | 41\r
+ ?column? | 42\r
+ ?column? | 43\r
+ ?column? | 44\r
+ ?column? | 45\r
+ ?column? | 46\r
+ ?column? | 47\r
+ ?column? | 48\r
+ ?column? | 49\r
+ SELECT 1\r
+ """
+ ),
+ timeout=5,
+ )
diff --git a/tests/features/steps/basic_commands.py b/tests/features/steps/basic_commands.py
new file mode 100644
index 0000000..687bdc0
--- /dev/null
+++ b/tests/features/steps/basic_commands.py
@@ -0,0 +1,231 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+import pexpect
+import subprocess
+import tempfile
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+@when("we list databases")
+def step_list_databases(context):
+ cmd = ["pgcli", "--list"]
+ context.cmd_output = subprocess.check_output(cmd, cwd=context.package_root)
+
+
+@then("we see list of databases")
+def step_see_list_databases(context):
+ assert b"List of databases" in context.cmd_output
+ assert b"postgres" in context.cmd_output
+ context.cmd_output = None
+
+
+@when("we run dbcli")
+def step_run_cli(context):
+ wrappers.run_cli(context)
+
+
+@when("we launch dbcli using {arg}")
+def step_run_cli_using_arg(context, arg):
+ prompt_check = False
+ currentdb = None
+ if arg == "--username":
+ arg = "--username={}".format(context.conf["user"])
+ if arg == "--user":
+ arg = "--user={}".format(context.conf["user"])
+ if arg == "--port":
+ arg = "--port={}".format(context.conf["port"])
+ if arg == "--password":
+ arg = "--password"
+ prompt_check = False
+ # This uses the mock_pg_service.conf file in fixtures folder.
+ if arg == "dsn_password":
+ arg = "service=mock_postgres --password"
+ prompt_check = False
+ currentdb = "postgres"
+ wrappers.run_cli(
+ context, run_args=[arg], prompt_check=prompt_check, currentdb=currentdb
+ )
+
+
+@when("we wait for prompt")
+def step_wait_prompt(context):
+ wrappers.wait_prompt(context)
+
+
+@when('we send "ctrl + d"')
+def step_ctrl_d(context):
+ """
+ Send Ctrl + D to hopefully exit.
+ """
+ step_try_to_ctrl_d(context)
+ context.cli.expect(pexpect.EOF, timeout=5)
+ context.exit_sent = True
+
+
+@when('we try to send "ctrl + d"')
+def step_try_to_ctrl_d(context):
+ """
+ Send Ctrl + D, perhaps exiting, perhaps not (if a transaction is
+ ongoing).
+ """
+ # turn off pager before exiting
+ context.cli.sendcontrol("c")
+ context.cli.sendline(r"\pset pager off")
+ wrappers.wait_prompt(context)
+ context.cli.sendcontrol("d")
+
+
+@when('we send "ctrl + c"')
+def step_ctrl_c(context):
+ """Send Ctrl + c to hopefully interrupt."""
+ context.cli.sendcontrol("c")
+
+
+@then("we see cancelled query warning")
+def step_see_cancelled_query_warning(context):
+ """
+ Make sure we receive the warning that the current query was cancelled.
+ """
+ wrappers.expect_exact(context, "cancelled query", timeout=2)
+
+
+@then("we see ongoing transaction message")
+def step_see_ongoing_transaction_error(context):
+ """
+ Make sure we receive the warning that a transaction is ongoing.
+ """
+ context.cli.expect("A transaction is ongoing.", timeout=2)
+
+
+@when("we send sleep query")
+def step_send_sleep_15_seconds(context):
+ """
+ Send query to sleep for 15 seconds.
+ """
+ context.cli.sendline("select pg_sleep(15)")
+
+
+@when("we check for any non-idle sleep queries")
+def step_check_for_active_sleep_queries(context):
+ """
+ Send query to check for any non-idle pg_sleep queries.
+ """
+ context.cli.sendline(
+ "select state from pg_stat_activity where query not like '%pg_stat_activity%' and query like '%pg_sleep%' and state != 'idle';"
+ )
+
+
+@then("we don't see any non-idle sleep queries")
+def step_no_active_sleep_queries(context):
+ """Confirm that any pg_sleep queries are either idle or not active."""
+ wrappers.expect_exact(
+ context,
+ context.conf["pager_boundary"]
+ + "\r"
+ + dedent(
+ """
+ +-------+\r
+ | state |\r
+ |-------|\r
+ +-------+\r
+ SELECT 0\r
+ """
+ )
+ + context.conf["pager_boundary"],
+ timeout=5,
+ )
+
+
+@when(r'we send "\?" command')
+def step_send_help(context):
+ r"""
+ Send \? to see help.
+ """
+ context.cli.sendline(r"\?")
+
+
+@when("we send partial select command")
+def step_send_partial_select_command(context):
+ """
+ Send `SELECT a` to see completion.
+ """
+ context.cli.sendline("SELECT a")
+
+
+@then("we see error message")
+def step_see_error_message(context):
+ wrappers.expect_exact(context, 'column "a" does not exist', timeout=2)
+
+
+@when("we send source command")
+def step_send_source_command(context):
+ context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
+ context.tmpfile_sql_help.write(rb"\?")
+ context.tmpfile_sql_help.flush()
+ context.cli.sendline(rf"\i {context.tmpfile_sql_help.name}")
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+
+
+@when("we run query to check application_name")
+def step_check_application_name(context):
+ context.cli.sendline(
+ "SELECT 'found' FROM pg_stat_activity WHERE application_name = 'pgcli' HAVING COUNT(*) > 0;"
+ )
+
+
+@then("we see found")
+def step_see_found(context):
+ wrappers.expect_exact(
+ context,
+ context.conf["pager_boundary"]
+ + "\r"
+ + dedent(
+ """
+ +----------+\r
+ | ?column? |\r
+ |----------|\r
+ | found |\r
+ +----------+\r
+ SELECT 1\r
+ """
+ )
+ + context.conf["pager_boundary"],
+ timeout=5,
+ )
+
+
+@then("we respond to the destructive warning: {response}")
+def step_resppond_to_destructive_command(context, response):
+ """Respond to destructive command."""
+ wrappers.expect_exact(
+ context,
+ "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
+ timeout=2,
+ )
+ context.cli.sendline(response.strip())
+
+
+@then("we send password")
+def step_send_password(context):
+ wrappers.expect_exact(context, "Password for", timeout=5)
+ context.cli.sendline(context.conf["pass"] or "DOES NOT MATTER")
+
+
+@when('we send "{text}"')
+def step_send_text(context, text):
+ context.cli.sendline(text)
+ # Try to detect whether we are exiting. If so, set `exit_sent`
+ # so that `after_scenario` correctly cleans up.
+ try:
+ context.cli.expect(pexpect.EOF, timeout=0.2)
+ except pexpect.TIMEOUT:
+ pass
+ else:
+ context.exit_sent = True
diff --git a/tests/features/steps/crud_database.py b/tests/features/steps/crud_database.py
new file mode 100644
index 0000000..87cdc85
--- /dev/null
+++ b/tests/features/steps/crud_database.py
@@ -0,0 +1,93 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+import pexpect
+
+from behave import when, then
+import wrappers
+
+
+@when("we create database")
+def step_db_create(context):
+ """
+ Send create database.
+ """
+ context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
+
+ context.response = {"database_name": context.conf["dbname_tmp"]}
+
+
+@when("we drop database")
+def step_db_drop(context):
+ """
+ Send drop database.
+ """
+ context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
+
+
+@when("we connect to test database")
+def step_db_connect_test(context):
+ """
+ Send connect to database.
+ """
+ db_name = context.conf["dbname"]
+ context.cli.sendline(f"\\connect {db_name}")
+
+
+@when("we connect to dbserver")
+def step_db_connect_dbserver(context):
+ """
+ Send connect to database.
+ """
+ context.cli.sendline("\\connect postgres")
+ context.currentdb = "postgres"
+
+
+@then("dbcli exits")
+def step_wait_exit(context):
+ """
+ Make sure the cli exits.
+ """
+ wrappers.expect_exact(context, pexpect.EOF, timeout=5)
+
+
+@then("we see dbcli prompt")
+def step_see_prompt(context):
+ """
+ Wait to see the prompt.
+ """
+ db_name = getattr(context, "currentdb", context.conf["dbname"])
+ wrappers.expect_exact(context, f"{db_name}>", timeout=5)
+ context.atprompt = True
+
+
+@then("we see help output")
+def step_see_help(context):
+ for expected_line in context.fixture_data["help_commands.txt"]:
+ wrappers.expect_exact(context, expected_line, timeout=2)
+
+
+@then("we see database created")
+def step_see_db_created(context):
+ """
+ Wait to see create database output.
+ """
+ wrappers.expect_pager(context, "CREATE DATABASE\r\n", timeout=5)
+
+
+@then("we see database dropped")
+def step_see_db_dropped(context):
+ """
+ Wait to see drop database output.
+ """
+ wrappers.expect_pager(context, "DROP DATABASE\r\n", timeout=2)
+
+
+@then("we see database connected")
+def step_see_db_connected(context):
+ """
+ Wait to see drop database output.
+ """
+ wrappers.expect_exact(context, "You are now connected to database", timeout=2)
diff --git a/tests/features/steps/crud_table.py b/tests/features/steps/crud_table.py
new file mode 100644
index 0000000..27d543e
--- /dev/null
+++ b/tests/features/steps/crud_table.py
@@ -0,0 +1,185 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+INITIAL_DATA = "xxx"
+UPDATED_DATA = "yyy"
+
+
+@when("we create table")
+def step_create_table(context):
+ """
+ Send create table.
+ """
+ context.cli.sendline("create table a(x text);")
+
+
+@when("we insert into table")
+def step_insert_into_table(context):
+ """
+ Send insert into table.
+ """
+ context.cli.sendline(f"""insert into a(x) values('{INITIAL_DATA}');""")
+
+
+@when("we update table")
+def step_update_table(context):
+ """
+ Send insert into table.
+ """
+ context.cli.sendline(
+ f"""update a set x = '{UPDATED_DATA}' where x = '{INITIAL_DATA}';"""
+ )
+
+
+@when("we select from table")
+def step_select_from_table(context):
+ """
+ Send select from table.
+ """
+ context.cli.sendline("select * from a;")
+
+
+@when("we delete from table")
+def step_delete_from_table(context):
+ """
+ Send deete from table.
+ """
+ context.cli.sendline(f"""delete from a where x = '{UPDATED_DATA}';""")
+
+
+@when("we drop table")
+def step_drop_table(context):
+ """
+ Send drop table.
+ """
+ context.cli.sendline("drop table a;")
+
+
+@when("we alter the table")
+def step_alter_table(context):
+ """
+ Alter the table by adding a column.
+ """
+ context.cli.sendline("""alter table a add column y varchar;""")
+
+
+@when("we begin transaction")
+def step_begin_transaction(context):
+ """
+ Begin transaction
+ """
+ context.cli.sendline("begin;")
+
+
+@when("we rollback transaction")
+def step_rollback_transaction(context):
+ """
+ Rollback transaction
+ """
+ context.cli.sendline("rollback;")
+
+
+@then("we see table created")
+def step_see_table_created(context):
+ """
+ Wait to see create table output.
+ """
+ wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
+
+
+@then("we see record inserted")
+def step_see_record_inserted(context):
+ """
+ Wait to see insert output.
+ """
+ wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
+
+
+@then("we see record updated")
+def step_see_record_updated(context):
+ """
+ Wait to see update output.
+ """
+ wrappers.expect_pager(context, "UPDATE 1\r\n", timeout=2)
+
+
+@then("we see data selected: {data}")
+def step_see_data_selected(context, data):
+ """
+ Wait to see select output with initial or updated data.
+ """
+ x = UPDATED_DATA if data == "updated" else INITIAL_DATA
+ wrappers.expect_pager(
+ context,
+ dedent(
+ f"""\
+ +-----+\r
+ | x |\r
+ |-----|\r
+ | {x} |\r
+ +-----+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
+
+
+@then("we see select output without data")
+def step_see_no_data_selected(context):
+ """
+ Wait to see select output without data.
+ """
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +---+\r
+ | x |\r
+ |---|\r
+ +---+\r
+ SELECT 0\r
+ """
+ ),
+ timeout=1,
+ )
+
+
+@then("we see record deleted")
+def step_see_data_deleted(context):
+ """
+ Wait to see delete output.
+ """
+ wrappers.expect_pager(context, "DELETE 1\r\n", timeout=2)
+
+
+@then("we see table dropped")
+def step_see_table_dropped(context):
+ """
+ Wait to see drop output.
+ """
+ wrappers.expect_pager(context, "DROP TABLE\r\n", timeout=2)
+
+
+@then("we see transaction began")
+def step_see_transaction_began(context):
+ """
+ Wait to see transaction began.
+ """
+ wrappers.expect_pager(context, "BEGIN\r\n", timeout=2)
+
+
+@then("we see transaction rolled back")
+def step_see_transaction_rolled_back(context):
+ """
+ Wait to see transaction rollback.
+ """
+ wrappers.expect_pager(context, "ROLLBACK\r\n", timeout=2)
diff --git a/tests/features/steps/expanded.py b/tests/features/steps/expanded.py
new file mode 100644
index 0000000..302cab9
--- /dev/null
+++ b/tests/features/steps/expanded.py
@@ -0,0 +1,70 @@
+"""Steps for behavioral style tests are defined in this module.
+
+Each step is defined by the string decorating it. This string is used
+to call the step in "*.feature" file.
+
+"""
+
+from behave import when, then
+from textwrap import dedent
+import wrappers
+
+
+@when("we prepare the test data")
+def step_prepare_data(context):
+ """Create table, insert a record."""
+ context.cli.sendline("drop table if exists a;")
+ wrappers.expect_exact(
+ context,
+ "You're about to run a destructive command.\r\nDo you want to proceed? [y/N]:",
+ timeout=2,
+ )
+ context.cli.sendline("y")
+
+ wrappers.wait_prompt(context)
+ context.cli.sendline("create table a(x integer, y real, z numeric(10, 4));")
+ wrappers.expect_pager(context, "CREATE TABLE\r\n", timeout=2)
+ context.cli.sendline("""insert into a(x, y, z) values(1, 1.0, 1.0);""")
+ wrappers.expect_pager(context, "INSERT 0 1\r\n", timeout=2)
+
+
+@when("we set expanded {mode}")
+def step_set_expanded(context, mode):
+ """Set expanded to mode."""
+ context.cli.sendline("\\" + f"x {mode}")
+ wrappers.expect_exact(context, "Expanded display is", timeout=2)
+ wrappers.wait_prompt(context)
+
+
+@then("we see {which} data selected")
+def step_see_data(context, which):
+ """Select data from expanded test table."""
+ if which == "expanded":
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ -[ RECORD 1 ]-------------------------\r
+ x | 1\r
+ y | 1.0\r
+ z | 1.0000\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
+ else:
+ wrappers.expect_pager(
+ context,
+ dedent(
+ """\
+ +---+-----+--------+\r
+ | x | y | z |\r
+ |---+-----+--------|\r
+ | 1 | 1.0 | 1.0000 |\r
+ +---+-----+--------+\r
+ SELECT 1\r
+ """
+ ),
+ timeout=1,
+ )
diff --git a/tests/features/steps/iocommands.py b/tests/features/steps/iocommands.py
new file mode 100644
index 0000000..a614490
--- /dev/null
+++ b/tests/features/steps/iocommands.py
@@ -0,0 +1,80 @@
+import os
+import os.path
+
+from behave import when, then
+import wrappers
+
+
+@when("we start external editor providing a file name")
+def step_edit_file(context):
+ """Edit file with external editor."""
+ context.editor_file_name = os.path.join(
+ context.package_root, "test_file_{0}.sql".format(context.conf["vi"])
+ )
+ if os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
+ wrappers.expect_exact(
+ context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
+ )
+ wrappers.expect_exact(context, ":", timeout=2)
+
+
+@when("we type sql in the editor")
+def step_edit_type_sql(context):
+ context.cli.sendline("i")
+ context.cli.sendline("select * from abc")
+ context.cli.sendline(".")
+ wrappers.expect_exact(context, ":", timeout=2)
+
+
+@when("we exit the editor")
+def step_edit_quit(context):
+ context.cli.sendline("x")
+ wrappers.expect_exact(context, "written", timeout=2)
+
+
+@then("we see the sql in prompt")
+def step_edit_done_sql(context):
+ for match in "select * from abc".split(" "):
+ wrappers.expect_exact(context, match, timeout=1)
+ # Cleanup the command line.
+ context.cli.sendcontrol("c")
+ # Cleanup the edited file.
+ if context.editor_file_name and os.path.exists(context.editor_file_name):
+ os.remove(context.editor_file_name)
+ context.atprompt = True
+
+
+@when("we tee output")
+def step_tee_ouptut(context):
+ context.tee_file_name = os.path.join(
+ context.package_root, "tee_file_{0}.sql".format(context.conf["vi"])
+ )
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+ wrappers.expect_exact(context, "Writing to file", timeout=5)
+ wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
+ wrappers.expect_exact(context, "Time", timeout=5)
+
+
+@when('we query "select 123456"')
+def step_query_select_123456(context):
+ context.cli.sendline("select 123456")
+
+
+@when("we stop teeing output")
+def step_notee_output(context):
+ context.cli.sendline(r"\o")
+ wrappers.expect_exact(context, "Time", timeout=5)
+
+
+@then("we see 123456 in tee output")
+def step_see_123456_in_ouput(context):
+ with open(context.tee_file_name) as f:
+ assert "123456" in f.read()
+ if os.path.exists(context.tee_file_name):
+ os.remove(context.tee_file_name)
+ context.atprompt = True
diff --git a/tests/features/steps/named_queries.py b/tests/features/steps/named_queries.py
new file mode 100644
index 0000000..3f52859
--- /dev/null
+++ b/tests/features/steps/named_queries.py
@@ -0,0 +1,57 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when("we save a named query")
+def step_save_named_query(context):
+ """
+ Send \ns command
+ """
+ context.cli.sendline("\\ns foo SELECT 12345")
+
+
+@when("we use a named query")
+def step_use_named_query(context):
+ """
+ Send \n command
+ """
+ context.cli.sendline("\\n foo")
+
+
+@when("we delete a named query")
+def step_delete_named_query(context):
+ """
+ Send \nd command
+ """
+ context.cli.sendline("\\nd foo")
+
+
+@then("we see the named query saved")
+def step_see_named_query_saved(context):
+ """
+ Wait to see query saved.
+ """
+ wrappers.expect_exact(context, "Saved.", timeout=2)
+
+
+@then("we see the named query executed")
+def step_see_named_query_executed(context):
+ """
+ Wait to see select output.
+ """
+ wrappers.expect_exact(context, "12345", timeout=1)
+ wrappers.expect_exact(context, "SELECT 1", timeout=1)
+
+
+@then("we see the named query deleted")
+def step_see_named_query_deleted(context):
+ """
+ Wait to see query deleted.
+ """
+ wrappers.expect_pager(context, "foo: Deleted\r\n", timeout=1)
diff --git a/tests/features/steps/pgbouncer.py b/tests/features/steps/pgbouncer.py
new file mode 100644
index 0000000..f156982
--- /dev/null
+++ b/tests/features/steps/pgbouncer.py
@@ -0,0 +1,22 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when('we send "show help" command')
+def step_send_help_command(context):
+ context.cli.sendline("show help")
+
+
+@then("we see the pgbouncer help output")
+def see_pgbouncer_help(context):
+ wrappers.expect_exact(
+ context,
+ "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
+ timeout=3,
+ )
diff --git a/tests/features/steps/specials.py b/tests/features/steps/specials.py
new file mode 100644
index 0000000..a85f371
--- /dev/null
+++ b/tests/features/steps/specials.py
@@ -0,0 +1,31 @@
+"""
+Steps for behavioral style tests are defined in this module.
+Each step is defined by the string decorating it.
+This string is used to call the step in "*.feature" file.
+"""
+
+from behave import when, then
+import wrappers
+
+
+@when("we refresh completions")
+def step_refresh_completions(context):
+ """
+ Send refresh command.
+ """
+ context.cli.sendline("\\refresh")
+
+
+@then("we see completions refresh started")
+def step_see_refresh_started(context):
+ """
+ Wait to see refresh output.
+ """
+ wrappers.expect_pager(
+ context,
+ [
+ "Auto-completion refresh started in the background.\r\n",
+ "Auto-completion refresh restarted.\r\n",
+ ],
+ timeout=2,
+ )
diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py
new file mode 100644
index 0000000..3ebcc92
--- /dev/null
+++ b/tests/features/steps/wrappers.py
@@ -0,0 +1,71 @@
+import re
+import pexpect
+from pgcli.main import COLOR_CODE_REGEX
+import textwrap
+
+from io import StringIO
+
+
+def expect_exact(context, expected, timeout):
+ timedout = False
+ try:
+ context.cli.expect_exact(expected, timeout=timeout)
+ except pexpect.TIMEOUT:
+ timedout = True
+ if timedout:
+ # Strip color codes out of the output.
+ actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before)
+ raise Exception(
+ textwrap.dedent(
+ """\
+ Expected:
+ ---
+ {0!r}
+ ---
+ Actual:
+ ---
+ {1!r}
+ ---
+ Full log:
+ ---
+ {2!r}
+ ---
+ """
+ ).format(expected, actual, context.logfile.getvalue())
+ )
+
+
+def expect_pager(context, expected, timeout):
+ formatted = expected if isinstance(expected, list) else [expected]
+ formatted = [
+ f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
+ for t in formatted
+ ]
+
+ expect_exact(
+ context,
+ formatted,
+ timeout=timeout,
+ )
+
+
+def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
+ """Run the process using pexpect."""
+ run_args = run_args or []
+ cli_cmd = context.conf.get("cli_command")
+ cmd_parts = [cli_cmd] + run_args
+ cmd = " ".join(cmd_parts)
+ context.cli = pexpect.spawnu(cmd, cwd=context.package_root)
+ context.logfile = StringIO()
+ context.cli.logfile = context.logfile
+ context.exit_sent = False
+ context.currentdb = currentdb or context.conf["dbname"]
+ context.cli.sendline(r"\pset pager always")
+ if prompt_check:
+ wait_prompt(context)
+
+
+def wait_prompt(context):
+ """Make sure prompt is displayed."""
+ prompt_str = "{0}>".format(context.currentdb)
+ expect_exact(context, [prompt_str + " ", prompt_str, pexpect.EOF], timeout=3)
diff --git a/tests/features/wrappager.py b/tests/features/wrappager.py
new file mode 100644
index 0000000..51d4909
--- /dev/null
+++ b/tests/features/wrappager.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+import sys
+
+
+def wrappager(boundary):
+ print(boundary)
+ while 1:
+ buf = sys.stdin.read(2048)
+ if not buf:
+ break
+ sys.stdout.write(buf)
+ print(boundary)
+
+
+if __name__ == "__main__":
+ wrappager(sys.argv[1])
diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py
new file mode 100644
index 0000000..9bad579
--- /dev/null
+++ b/tests/formatter/__init__.py
@@ -0,0 +1 @@
+# coding=utf-8
diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py
new file mode 100644
index 0000000..016ed95
--- /dev/null
+++ b/tests/formatter/test_sqlformatter.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+
+from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter
+
+
+def test_escape_for_sql_statement_bytes():
+ bts = b"837124ab3e8dc0f"
+ escaped_bytes = escape_for_sql_statement(bts)
+ assert escaped_bytes == "X'383337313234616233653864633066'"
+
+
+def test_escape_for_sql_statement_number():
+ num = 2981
+ escaped_bytes = escape_for_sql_statement(num)
+ assert escaped_bytes == "'2981'"
+
+
+def test_escape_for_sql_statement_str():
+ example_str = "example str"
+ escaped_bytes = escape_for_sql_statement(example_str)
+ assert escaped_bytes == "'example str'"
+
+
+def test_output_sql_insert():
+ global formatter
+ formatter = TabularOutputFormatter
+ register_new_formatter(formatter)
+ data = [
+ [
+ 1,
+ "Jackson",
+ "jackson_test@gmail.com",
+ "132454789",
+ None,
+ "2022-09-09 19:44:32.712343+08",
+ "2022-09-09 19:44:32.712343+08",
+ ]
+ ]
+ header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+ table_format = "sql-insert"
+ kwargs = {
+ "column_types": [int, str, str, str, str, str, str],
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": "<null>",
+ "integer_format": "",
+ "float_format": "",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "max_field_width": 500,
+ }
+ formatter.query = 'SELECT * FROM "user";'
+ output = adapter(data, header, table_format=table_format, **kwargs)
+ output_list = [l for l in output]
+ expected = [
+ 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
+ " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, "
+ + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')",
+ ";",
+ ]
+ assert expected == output_list
+
+
+def test_output_sql_update():
+ global formatter
+ formatter = TabularOutputFormatter
+ register_new_formatter(formatter)
+ data = [
+ [
+ 1,
+ "Jackson",
+ "jackson_test@gmail.com",
+ "132454789",
+ "",
+ "2022-09-09 19:44:32.712343+08",
+ "2022-09-09 19:44:32.712343+08",
+ ]
+ ]
+ header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+ table_format = "sql-update"
+ kwargs = {
+ "column_types": [int, str, str, str, str, str, str],
+ "sep_title": "RECORD {n}",
+ "sep_character": "-",
+ "sep_length": (1, 25),
+ "missing_value": "<null>",
+ "integer_format": "",
+ "float_format": "",
+ "disable_numparse": True,
+ "preserve_whitespace": True,
+ "max_field_width": 500,
+ }
+ formatter.query = 'SELECT * FROM "user";'
+ output = adapter(data, header, table_format=table_format, **kwargs)
+ output_list = [l for l in output]
+ print(output_list)
+ expected = [
+ 'UPDATE "user" SET',
+ " \"name\" = 'Jackson'",
+ ", \"email\" = 'jackson_test@gmail.com'",
+ ", \"phone\" = '132454789'",
+ ", \"description\" = ''",
+ ", \"created_at\" = '2022-09-09 19:44:32.712343+08'",
+ ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'",
+ "WHERE \"id\" = '1';",
+ ]
+ assert expected == output_list
diff --git a/tests/metadata.py b/tests/metadata.py
new file mode 100644
index 0000000..4ebcccd
--- /dev/null
+++ b/tests/metadata.py
@@ -0,0 +1,255 @@
+from functools import partial
+from itertools import product
+from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+from unittest.mock import Mock
+import pytest
+
+parametrize = pytest.mark.parametrize
+
+qual = ["if_more_than_one_table", "always"]
+no_qual = ["if_more_than_one_table", "never"]
+
+
+def escape(name):
+ if not name.islower() or name in ("select", "localtimestamp"):
+ return '"' + name + '"'
+ return name
+
+
+def completion(display_meta, text, pos=0):
+ return Completion(text, start_position=pos, display_meta=display_meta)
+
+
+def function(text, pos=0, display=None):
+ return Completion(
+ text, display=display or text, start_position=pos, display_meta="function"
+ )
+
+
+def get_result(completer, text, position=None):
+ position = len(text) if position is None else position
+ return completer.get_completions(
+ Document(text=text, cursor_position=position), Mock()
+ )
+
+
+def result_set(completer, text, position=None):
+ return set(get_result(completer, text, position))
+
+
+# The code below is quivalent to
+# def schema(text, pos=0):
+# return completion('schema', text, pos)
+# and so on
+schema = partial(completion, "schema")
+table = partial(completion, "table")
+view = partial(completion, "view")
+column = partial(completion, "column")
+keyword = partial(completion, "keyword")
+datatype = partial(completion, "datatype")
+alias = partial(completion, "table alias")
+name_join = partial(completion, "name join")
+fk_join = partial(completion, "fk join")
+join = partial(completion, "join")
+
+
+def wildcard_expansion(cols, pos=-1):
+ return Completion(cols, start_position=pos, display_meta="columns", display="*")
+
+
+class MetaData:
+ def __init__(self, metadata):
+ self.metadata = metadata
+
+ def builtin_functions(self, pos=0):
+ return [function(f, pos) for f in self.completer.functions]
+
+ def builtin_datatypes(self, pos=0):
+ return [datatype(dt, pos) for dt in self.completer.datatypes]
+
+ def keywords(self, pos=0):
+ return [keyword(kw, pos) for kw in self.completer.keywords_tree.keys()]
+
+ def specials(self, pos=0):
+ return [
+ Completion(text=k, start_position=pos, display_meta=v.description)
+ for k, v in self.completer.pgspecial.commands.items()
+ ]
+
+ def columns(self, tbl, parent="public", typ="tables", pos=0):
+ if typ == "functions":
+ fun = [x for x in self.metadata[typ][parent] if x[0] == tbl][0]
+ cols = fun[1]
+ else:
+ cols = self.metadata[typ][parent][tbl]
+ return [column(escape(col), pos) for col in cols]
+
+ def datatypes(self, parent="public", pos=0):
+ return [
+ datatype(escape(x), pos)
+ for x in self.metadata.get("datatypes", {}).get(parent, [])
+ ]
+
+ def tables(self, parent="public", pos=0):
+ return [
+ table(escape(x), pos)
+ for x in self.metadata.get("tables", {}).get(parent, [])
+ ]
+
+ def views(self, parent="public", pos=0):
+ return [
+ view(escape(x), pos) for x in self.metadata.get("views", {}).get(parent, [])
+ ]
+
+ def functions(self, parent="public", pos=0):
+ return [
+ function(
+ escape(x[0])
+ + "("
+ + ", ".join(
+ arg_name + " := "
+ for (arg_name, arg_mode) in zip(x[1], x[3])
+ if arg_mode in ("b", "i")
+ )
+ + ")",
+ pos,
+ escape(x[0])
+ + "("
+ + ", ".join(
+ arg_name
+ for (arg_name, arg_mode) in zip(x[1], x[3])
+ if arg_mode in ("b", "i")
+ )
+ + ")",
+ )
+ for x in self.metadata.get("functions", {}).get(parent, [])
+ ]
+
+ def schemas(self, pos=0):
+ schemas = {sch for schs in self.metadata.values() for sch in schs}
+ return [schema(escape(s), pos=pos) for s in schemas]
+
+ def functions_and_keywords(self, parent="public", pos=0):
+ return (
+ self.functions(parent, pos)
+ + self.builtin_functions(pos)
+ + self.keywords(pos)
+ )
+
+ # Note that the filtering parameters here only apply to the columns
+ def columns_functions_and_keywords(self, tbl, parent="public", typ="tables", pos=0):
+ return self.functions_and_keywords(pos=pos) + self.columns(
+ tbl, parent, typ, pos
+ )
+
+ def from_clause_items(self, parent="public", pos=0):
+ return (
+ self.functions(parent, pos)
+ + self.views(parent, pos)
+ + self.tables(parent, pos)
+ )
+
+ def schemas_and_from_clause_items(self, parent="public", pos=0):
+ return self.from_clause_items(parent, pos) + self.schemas(pos)
+
+ def types(self, parent="public", pos=0):
+ return self.datatypes(parent, pos) + self.tables(parent, pos)
+
+ @property
+ def completer(self):
+ return self.get_completer()
+
+ def get_completers(self, casing):
+ """
+ Returns a function taking three bools `casing`, `filtr`, `aliasing` and
+ the list `qualify`, all defaulting to None.
+ Returns a list of completers.
+ These parameters specify the allowed values for the corresponding
+ completer parameters, `None` meaning any, i.e. (None, None, None, None)
+ results in all 24 possible completers, whereas e.g.
+ (True, False, True, ['never']) results in the one completer with
+ casing, without `search_path` filtering of objects, with table
+ aliasing, and without column qualification.
+ """
+
+ def _cfg(_casing, filtr, aliasing, qualify):
+ cfg = {"settings": {}}
+ if _casing:
+ cfg["casing"] = casing
+ cfg["settings"]["search_path_filter"] = filtr
+ cfg["settings"]["generate_aliases"] = aliasing
+ cfg["settings"]["qualify_columns"] = qualify
+ return cfg
+
+ def _cfgs(casing, filtr, aliasing, qualify):
+ casings = [True, False] if casing is None else [casing]
+ filtrs = [True, False] if filtr is None else [filtr]
+ aliases = [True, False] if aliasing is None else [aliasing]
+ qualifys = qualify or ["always", "if_more_than_one_table", "never"]
+ return [_cfg(*p) for p in product(casings, filtrs, aliases, qualifys)]
+
+ def completers(casing=None, filtr=None, aliasing=None, qualify=None):
+ get_comp = self.get_completer
+ return [get_comp(**c) for c in _cfgs(casing, filtr, aliasing, qualify)]
+
+ return completers
+
+ def _make_col(self, sch, tbl, col):
+ defaults = self.metadata.get("defaults", {}).get(sch, {})
+ return (sch, tbl, col, "text", (tbl, col) in defaults, defaults.get((tbl, col)))
+
+ def get_completer(self, settings=None, casing=None):
+ metadata = self.metadata
+ from pgcli.pgcompleter import PGCompleter
+ from pgspecial import PGSpecial
+
+ comp = PGCompleter(
+ smart_completion=True, settings=settings, pgspecial=PGSpecial()
+ )
+
+ schemata, tables, tbl_cols, views, view_cols = [], [], [], [], []
+
+ for sch, tbls in metadata["tables"].items():
+ schemata.append(sch)
+
+ for tbl, cols in tbls.items():
+ tables.append((sch, tbl))
+ # Let all columns be text columns
+ tbl_cols.extend([self._make_col(sch, tbl, col) for col in cols])
+
+ for sch, tbls in metadata.get("views", {}).items():
+ for tbl, cols in tbls.items():
+ views.append((sch, tbl))
+ # Let all columns be text columns
+ view_cols.extend([self._make_col(sch, tbl, col) for col in cols])
+
+ functions = [
+ FunctionMetadata(sch, *func_meta, arg_defaults=None)
+ for sch, funcs in metadata["functions"].items()
+ for func_meta in funcs
+ ]
+
+ datatypes = [
+ (sch, typ)
+ for sch, datatypes in metadata["datatypes"].items()
+ for typ in datatypes
+ ]
+
+ foreignkeys = [
+ ForeignKey(*fk) for fks in metadata["foreignkeys"].values() for fk in fks
+ ]
+
+ comp.extend_schemata(schemata)
+ comp.extend_relations(tables, kind="tables")
+ comp.extend_relations(views, kind="views")
+ comp.extend_columns(tbl_cols, kind="tables")
+ comp.extend_columns(view_cols, kind="views")
+ comp.extend_functions(functions)
+ comp.extend_datatypes(datatypes)
+ comp.extend_foreignkeys(foreignkeys)
+ comp.set_search_path(["public"])
+ comp.extend_casing(casing or [])
+
+ return comp
diff --git a/tests/parseutils/test_ctes.py b/tests/parseutils/test_ctes.py
new file mode 100644
index 0000000..3e89cca
--- /dev/null
+++ b/tests/parseutils/test_ctes.py
@@ -0,0 +1,137 @@
+import pytest
+from sqlparse import parse
+from pgcli.packages.parseutils.ctes import (
+ token_start_pos,
+ extract_ctes,
+ extract_column_names as _extract_column_names,
+)
+
+
+def extract_column_names(sql):
+ p = parse(sql)[0]
+ return _extract_column_names(p)
+
+
+def test_token_str_pos():
+ sql = "SELECT * FROM xxx"
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")
+
+ sql = "SELECT * FROM \nxxx"
+ p = parse(sql)[0]
+ idx = p.token_index(p.tokens[-1])
+ assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")
+
+
+def test_single_column_name_extraction():
+ sql = "SELECT abc FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+
+def test_aliased_single_column_name_extraction():
+ sql = "SELECT abc def FROM xxx"
+ assert extract_column_names(sql) == ("def",)
+
+
+def test_aliased_expression_name_extraction():
+ sql = "SELECT 99 abc FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+
+def test_multiple_column_name_extraction():
+ sql = "SELECT abc, def FROM xxx"
+ assert extract_column_names(sql) == ("abc", "def")
+
+
+def test_missing_column_name_handled_gracefully():
+ sql = "SELECT abc, 99 FROM xxx"
+ assert extract_column_names(sql) == ("abc",)
+
+ sql = "SELECT abc, 99, def FROM xxx"
+ assert extract_column_names(sql) == ("abc", "def")
+
+
+def test_aliased_multiple_column_name_extraction():
+ sql = "SELECT abc def, ghi jkl FROM xxx"
+ assert extract_column_names(sql) == ("def", "jkl")
+
+
+def test_table_qualified_column_name_extraction():
+ sql = "SELECT abc.def, ghi.jkl FROM xxx"
+ assert extract_column_names(sql) == ("def", "jkl")
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
+ "DELETE FROM foo WHERE x > y RETURNING x, y",
+ "UPDATE foo SET x = 9 RETURNING x, y",
+ ],
+)
+def test_extract_column_names_from_returning_clause(sql):
+ assert extract_column_names(sql) == ("x", "y")
+
+
+def test_simple_cte_extraction():
+ sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
+ start_pos = len("WITH a AS ")
+ stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
+ ctes, remainder = extract_ctes(sql)
+
+ assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
+ assert remainder.strip() == "SELECT * FROM a"
+
+
+def test_cte_extraction_around_comments():
+ sql = """--blah blah blah
+ WITH a AS (SELECT abc def FROM x)
+ SELECT * FROM a"""
+ start_pos = len(
+ """--blah blah blah
+ WITH a AS """
+ )
+ stop_pos = len(
+ """--blah blah blah
+ WITH a AS (SELECT abc def FROM x)"""
+ )
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
+ assert remainder.strip() == "SELECT * FROM a"
+
+
+def test_multiple_cte_extraction():
+ sql = """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)
+ SELECT * FROM a, b"""
+
+ start1 = len(
+ """WITH
+ x AS """
+ )
+
+ stop1 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x)"""
+ )
+
+ start2 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS """
+ )
+
+ stop2 = len(
+ """WITH
+ x AS (SELECT abc, def FROM x),
+ y AS (SELECT ghi, jkl FROM y)"""
+ )
+
+ ctes, remainder = extract_ctes(sql)
+ assert tuple(ctes) == (
+ ("x", ("abc", "def"), start1, stop1),
+ ("y", ("ghi", "jkl"), start2, stop2),
+ )
diff --git a/tests/parseutils/test_function_metadata.py b/tests/parseutils/test_function_metadata.py
new file mode 100644
index 0000000..0350e2a
--- /dev/null
+++ b/tests/parseutils/test_function_metadata.py
@@ -0,0 +1,19 @@
+from pgcli.packages.parseutils.meta import FunctionMetadata
+
+
+def test_function_metadata_eq():
+ f1 = FunctionMetadata(
+ "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ f2 = FunctionMetadata(
+ "s", "f", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ f3 = FunctionMetadata(
+ "s", "g", ["x"], ["integer"], [], "int", False, False, False, False, None
+ )
+ assert f1 == f2
+ assert f1 != f3
+ assert not (f1 != f2)
+ assert not (f1 == f3)
+ assert hash(f1) == hash(f2)
+ assert hash(f1) != hash(f3)
diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py
new file mode 100644
index 0000000..349cbd0
--- /dev/null
+++ b/tests/parseutils/test_parseutils.py
@@ -0,0 +1,310 @@
+import pytest
+from pgcli.packages.parseutils import (
+ is_destructive,
+ parse_destructive_warning,
+ BASE_KEYWORDS,
+ ALL_KEYWORDS,
+)
+from pgcli.packages.parseutils.tables import extract_tables
+from pgcli.packages.parseutils.utils import find_prev_keyword, is_open_quote
+
+
+def test_empty_string():
+ tables = extract_tables("")
+ assert tables == ()
+
+
+def test_simple_select_single_table():
+ tables = extract_tables("select * from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+@pytest.mark.parametrize(
+ "sql", ['select * from "abc"."def"', 'select * from abc."def"']
+)
+def test_simple_select_single_table_schema_qualified_quoted_table(sql):
+ tables = extract_tables(sql)
+ assert tables == (("abc", "def", '"def"', False),)
+
+
+@pytest.mark.parametrize("sql", ["select * from abc.def", 'select * from "abc".def'])
+def test_simple_select_single_table_schema_qualified(sql):
+ tables = extract_tables(sql)
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_select_single_table_double_quoted():
+ tables = extract_tables('select * from "Abc"')
+ assert tables == ((None, "Abc", None, False),)
+
+
+def test_simple_select_multiple_tables():
+ tables = extract_tables("select * from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_simple_select_multiple_tables_double_quoted():
+ tables = extract_tables('select * from "Abc", "Def"')
+ assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
+
+
+def test_simple_select_single_table_deouble_quoted_aliased():
+ tables = extract_tables('select * from "Abc" a')
+ assert tables == ((None, "Abc", "a", False),)
+
+
+def test_simple_select_multiple_tables_deouble_quoted_aliased():
+ tables = extract_tables('select * from "Abc" a, "Def" d')
+ assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
+
+
+def test_simple_select_multiple_tables_schema_qualified():
+ tables = extract_tables("select * from abc.def, ghi.jkl")
+ assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
+
+
+def test_simple_select_with_cols_single_table():
+ tables = extract_tables("select a,b from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_simple_select_with_cols_single_table_schema_qualified():
+ tables = extract_tables("select a,b from abc.def")
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_select_with_cols_multiple_tables():
+ tables = extract_tables("select a,b from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_simple_select_with_cols_multiple_qualified_tables():
+ tables = extract_tables("select a,b from abc.def, def.ghi")
+ assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
+
+
+def test_select_with_hanging_comma_single_table():
+ tables = extract_tables("select a, from abc")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_select_with_hanging_comma_multiple_tables():
+ tables = extract_tables("select a, from abc, def")
+ assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
+
+
+def test_select_with_hanging_period_multiple_tables():
+ tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
+ assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
+
+
+def test_simple_insert_single_table():
+ tables = extract_tables('insert into abc (id, name) values (1, "def")')
+
+ # sqlparse mistakenly assigns an alias to the table
+ # AND mistakenly identifies the field list as
+ # assert tables == ((None, 'abc', 'abc', False),)
+
+ assert tables == ((None, "abc", "abc", False),)
+
+
+@pytest.mark.xfail
+def test_simple_insert_single_table_schema_qualified():
+ tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
+ assert tables == (("abc", "def", None, False),)
+
+
+def test_simple_update_table_no_schema():
+ tables = extract_tables("update abc set id = 1")
+ assert tables == ((None, "abc", None, False),)
+
+
+def test_simple_update_table_with_schema():
+ tables = extract_tables("update abc.def set id = 1")
+ assert tables == (("abc", "def", None, False),)
+
+
+@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
+def test_join_table(join_type):
+ sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
+ tables = extract_tables(sql)
+ assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
+
+
+def test_join_table_schema_qualified():
+ tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
+ assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
+
+
+def test_incomplete_join_clause():
+ sql = """select a.x, b.y
+ from abc a join bcd b
+ on a.id = """
+ tables = extract_tables(sql)
+ assert tables == ((None, "abc", "a", False), (None, "bcd", "b", False))
+
+
+def test_join_as_table():
+ tables = extract_tables("SELECT * FROM my_table AS m WHERE m.a > 5")
+ assert tables == ((None, "my_table", "m", False),)
+
+
+def test_multiple_joins():
+ sql = """select * from t1
+ inner join t2 ON
+ t1.id = t2.t1_id
+ inner join t3 ON
+ t2.id = t3."""
+ tables = extract_tables(sql)
+ assert tables == (
+ (None, "t1", None, False),
+ (None, "t2", None, False),
+ (None, "t3", None, False),
+ )
+
+
+def test_subselect_tables():
+ sql = "SELECT * FROM (SELECT FROM abc"
+ tables = extract_tables(sql)
+ assert tables == ((None, "abc", None, False),)
+
+
+@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
+def test_extract_no_tables(text):
+ tables = extract_tables(text)
+ assert tables == tuple()
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo({arg_list})")
+ assert tables == ((None, "foo", None, True),)
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_schema_qualified_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
+ assert tables == (("foo", "bar", None, True),)
+
+
+@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
+def test_simple_aliased_function_as_table(arg_list):
+ tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
+ assert tables == ((None, "foo", "bar", True),)
+
+
+def test_simple_table_and_function():
+ tables = extract_tables("SELECT * FROM foo JOIN bar()")
+ assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
+
+
+def test_complex_table_and_function():
+ tables = extract_tables(
+ """SELECT * FROM foo.bar baz
+ JOIN bar.qux(x, y, z) quux"""
+ )
+ assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
+
+
+def test_find_prev_keyword_using():
+ q = "select * from tbl1 inner join tbl2 using (col1, "
+ kw, q2 = find_prev_keyword(q)
+ assert kw.value == "(" and q2 == "select * from tbl1 inner join tbl2 using ("
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select * from foo where bar",
+ "select * from foo where bar = 1 and baz or ",
+ "select * from foo where bar = 1 and baz between qux and ",
+ ],
+)
+def test_find_prev_keyword_where(sql):
+ kw, stripped = find_prev_keyword(sql)
+ assert kw.value == "where" and stripped == "select * from foo where"
+
+
+@pytest.mark.parametrize(
+ "sql", ["create table foo (bar int, baz ", "select * from foo() as bar (baz "]
+)
+def test_find_prev_keyword_open_parens(sql):
+ kw, _ = find_prev_keyword(sql)
+ assert kw.value == "("
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "",
+ "$$ foo $$",
+ "$$ 'foo' $$",
+ '$$ "foo" $$',
+ "$$ $a$ $$",
+ "$a$ $$ $a$",
+ "foo bar $$ baz $$",
+ ],
+)
+def test_is_open_quote__closed(sql):
+ assert not is_open_quote(sql)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "$$",
+ ";;;$$",
+ "foo $$ bar $$; foo $$",
+ "$$ foo $a$",
+ "foo 'bar baz",
+ "$a$ foo ",
+ '$$ "foo" ',
+ "$$ $a$ ",
+ "foo bar $$ baz",
+ ],
+)
+def test_is_open_quote__open(sql):
+ assert is_open_quote(sql)
+
+
+@pytest.mark.parametrize(
+ ("sql", "keywords", "expected"),
+ [
+ ("update abc set x = 1", ALL_KEYWORDS, True),
+ ("update abc set x = 1 where y = 2", ALL_KEYWORDS, True),
+ ("update abc set x = 1", BASE_KEYWORDS, True),
+ ("update abc set x = 1 where y = 2", BASE_KEYWORDS, False),
+ ("select x, y, z from abc", ALL_KEYWORDS, False),
+ ("drop abc", ALL_KEYWORDS, True),
+ ("alter abc", ALL_KEYWORDS, True),
+ ("delete abc", ALL_KEYWORDS, True),
+ ("truncate abc", ALL_KEYWORDS, True),
+ ("insert into abc values (1, 2, 3)", ALL_KEYWORDS, False),
+ ("insert into abc values (1, 2, 3)", BASE_KEYWORDS, False),
+ ("insert into abc values (1, 2, 3)", ["insert"], True),
+ ("insert into abc values (1, 2, 3)", ["insert"], True),
+ ],
+)
+def test_is_destructive(sql, keywords, expected):
+ assert is_destructive(sql, keywords) == expected
+
+
+@pytest.mark.parametrize(
+ ("warning_level", "expected"),
+ [
+ ("true", ALL_KEYWORDS),
+ ("false", []),
+ ("all", ALL_KEYWORDS),
+ ("moderate", BASE_KEYWORDS),
+ ("off", []),
+ ("", []),
+ (None, []),
+ (ALL_KEYWORDS, ALL_KEYWORDS),
+ (BASE_KEYWORDS, BASE_KEYWORDS),
+ ("insert", ["insert"]),
+ ("drop,alter,delete", ["drop", "alter", "delete"]),
+ (["drop", "alter", "delete"], ["drop", "alter", "delete"]),
+ ],
+)
+def test_parse_destructive_warning(warning_level, expected):
+ assert parse_destructive_warning(warning_level) == expected
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 0000000..f787740
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts=--capture=sys --showlocals \ No newline at end of file
diff --git a/tests/test_auth.py b/tests/test_auth.py
new file mode 100644
index 0000000..a517a89
--- /dev/null
+++ b/tests/test_auth.py
@@ -0,0 +1,40 @@
+import pytest
+from unittest import mock
+from pgcli import auth
+
+
+@pytest.mark.parametrize("enabled,call_count", [(True, 1), (False, 0)])
+def test_keyring_initialize(enabled, call_count):
+ logger = mock.MagicMock()
+
+ with mock.patch("importlib.import_module", return_value=True) as import_method:
+ auth.keyring_initialize(enabled, logger=logger)
+ assert import_method.call_count == call_count
+
+
+def test_keyring_get_password_ok():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch("pgcli.auth.keyring.get_password", return_value="abc123"):
+ assert auth.keyring_get_password("test") == "abc123"
+
+
+def test_keyring_get_password_exception():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch(
+ "pgcli.auth.keyring.get_password", side_effect=Exception("Boom!")
+ ):
+ assert auth.keyring_get_password("test") == ""
+
+
+def test_keyring_set_password_ok():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch("pgcli.auth.keyring.set_password"):
+ auth.keyring_set_password("test", "abc123")
+
+
+def test_keyring_set_password_exception():
+ with mock.patch("pgcli.auth.keyring", return_value=mock.MagicMock()):
+ with mock.patch(
+ "pgcli.auth.keyring.set_password", side_effect=Exception("Boom!")
+ ):
+ auth.keyring_set_password("test", "abc123")
diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py
new file mode 100644
index 0000000..a5529d6
--- /dev/null
+++ b/tests/test_completion_refresher.py
@@ -0,0 +1,95 @@
+import time
+import pytest
+from unittest.mock import Mock, patch
+
+
+@pytest.fixture
+def refresher():
+ from pgcli.completion_refresher import CompletionRefresher
+
+ return CompletionRefresher()
+
+
+def test_ctor(refresher):
+ """
+ Refresher object should contain a few handlers
+ :param refresher:
+ :return:
+ """
+ assert len(refresher.refreshers) > 0
+ actual_handlers = list(refresher.refreshers.keys())
+ expected_handlers = [
+ "schemata",
+ "tables",
+ "views",
+ "types",
+ "databases",
+ "casing",
+ "functions",
+ ]
+ assert expected_handlers == actual_handlers
+
+
+def test_refresh_called_once(refresher):
+ """
+
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ special = Mock()
+
+ with patch.object(refresher, "_bg_refresh") as bg_refresh:
+ actual = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual) == 1
+ assert len(actual[0]) == 4
+ assert actual[0][3] == "Auto-completion refresh started in the background."
+ bg_refresh.assert_called_with(pgexecute, special, callbacks, None, None)
+
+
+def test_refresh_called_twice(refresher):
+ """
+ If refresh is called a second time, it should be restarted
+ :param refresher:
+ :return:
+ """
+ callbacks = Mock()
+
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ special = Mock()
+
+ def dummy_bg_refresh(*args):
+ time.sleep(3) # seconds
+
+ refresher._bg_refresh = dummy_bg_refresh
+
+ actual1 = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual1) == 1
+ assert len(actual1[0]) == 4
+ assert actual1[0][3] == "Auto-completion refresh started in the background."
+
+ actual2 = refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert len(actual2) == 1
+ assert len(actual2[0]) == 4
+ assert actual2[0][3] == "Auto-completion refresh restarted."
+
+
+def test_refresh_with_callbacks(refresher):
+ """
+ Callbacks must be called
+ :param refresher:
+ """
+ callbacks = [Mock()]
+ pgexecute = Mock(**{"is_virtual_database.return_value": False})
+ pgexecute.extra_args = {}
+ special = Mock()
+
+ # Set refreshers to 0: we're not testing refresh logic here
+ refresher.refreshers = {}
+ refresher.refresh(pgexecute, special, callbacks)
+ time.sleep(1) # Wait for the thread to work.
+ assert callbacks[0].call_count == 1
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..08fe74e
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,43 @@
+import io
+import os
+import stat
+
+import pytest
+
+from pgcli.config import ensure_dir_exists, skip_initial_comment
+
+
+def test_ensure_file_parent(tmpdir):
+ subdir = tmpdir.join("subdir")
+ rcfile = subdir.join("rcfile")
+ ensure_dir_exists(str(rcfile))
+
+
+def test_ensure_existing_dir(tmpdir):
+ rcfile = str(tmpdir.mkdir("subdir").join("rcfile"))
+
+ # should just not raise
+ ensure_dir_exists(rcfile)
+
+
+def test_ensure_other_create_error(tmpdir):
+ subdir = tmpdir.join('subdir"')
+ rcfile = subdir.join("rcfile")
+
+ # trigger an oserror that isn't "directory already exists"
+ os.chmod(str(tmpdir), stat.S_IREAD)
+
+ with pytest.raises(OSError):
+ ensure_dir_exists(str(rcfile))
+
+
+@pytest.mark.parametrize(
+ "text, skipped_lines",
+ (
+ ("abc\n", 1),
+ ("#[section]\ndef\n[section]", 2),
+ ("[section]", 0),
+ ),
+)
+def test_skip_initial_comment(text, skipped_lines):
+ assert skip_initial_comment(io.StringIO(text)) == skipped_lines
diff --git a/tests/test_fuzzy_completion.py b/tests/test_fuzzy_completion.py
new file mode 100644
index 0000000..8f8f2cd
--- /dev/null
+++ b/tests/test_fuzzy_completion.py
@@ -0,0 +1,87 @@
+import pytest
+
+
+@pytest.fixture
+def completer():
+ import pgcli.pgcompleter as pgcompleter
+
+ return pgcompleter.PGCompleter()
+
+
+def test_ranking_ignores_identifier_quotes(completer):
+ """When calculating result rank, identifier quotes should be ignored.
+
+ The result ranking algorithm ignores identifier quotes. Without this
+ correction, the match "user", which Postgres requires to be quoted
+ since it is also a reserved word, would incorrectly fall below the
+ match user_action because the literal quotation marks in "user"
+ alter the position of the match.
+
+ This test checks that the fuzzy ranking algorithm correctly ignores
+ quotation marks when computing match ranks.
+
+ """
+
+ text = "user"
+ collection = ["user_action", '"user"']
+ matches = completer.find_matches(text, collection)
+ assert len(matches) == 2
+
+
+def test_ranking_based_on_shortest_match(completer):
+ """Fuzzy result rank should be based on shortest match.
+
+ Result ranking in fuzzy searching is partially based on the length
+ of matches: shorter matches are considered more relevant than
+ longer ones. When searching for the text 'user', the length
+ component of the match 'user_group' could be either 4 ('user') or
+ 7 ('user_gr').
+
+ This test checks that the fuzzy ranking algorithm uses the shorter
+ match when calculating result rank.
+
+ """
+
+ text = "user"
+ collection = ["api_user", "user_group"]
+ matches = completer.find_matches(text, collection)
+
+ assert matches[1].priority > matches[0].priority
+
+
+@pytest.mark.parametrize(
+ "collection",
+ [["user_action", "user"], ["user_group", "user"], ["user_group", "user_action"]],
+)
+def test_should_break_ties_using_lexical_order(completer, collection):
+ """Fuzzy result rank should use lexical order to break ties.
+
+ When fuzzy matching, if multiple matches have the same match length and
+ start position, present them in lexical (rather than arbitrary) order. For
+ example, if we have tables 'user', 'user_action', and 'user_group', a
+ search for the text 'user' should present these tables in this order.
+
+ The input collections to this test are out of order; each run checks that
+ the search text 'user' results in the input tables being reordered
+ lexically.
+
+ """
+
+ text = "user"
+ matches = completer.find_matches(text, collection)
+
+ assert matches[1].priority > matches[0].priority
+
+
+def test_matching_should_be_case_insensitive(completer):
+ """Fuzzy matching should keep matches even if letter casing doesn't match.
+
+ This test checks that variations of the text which have different casing
+ are still matched.
+ """
+
+ text = "foo"
+ collection = ["Foo", "FOO", "fOO"]
+ matches = completer.find_matches(text, collection)
+
+ assert len(matches) == 3
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..cbf20a6
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,490 @@
+import os
+import platform
+from unittest import mock
+
+import pytest
+
+try:
+ import setproctitle
+except ImportError:
+ setproctitle = None
+
+from pgcli.main import (
+ obfuscate_process_password,
+ format_output,
+ PGCli,
+ OutputSettings,
+ COLOR_CODE_REGEX,
+)
+from pgcli.pgexecute import PGExecute
+from pgspecial.main import PAGER_OFF, PAGER_LONG_OUTPUT, PAGER_ALWAYS
+from utils import dbtest, run
+from collections import namedtuple
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Not applicable in windows")
+@pytest.mark.skipif(not setproctitle, reason="setproctitle not available")
+def test_obfuscate_process_password():
+ original_title = setproctitle.getproctitle()
+
+ setproctitle.setproctitle("pgcli user=root password=secret host=localhost")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx host=localhost"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli user=root password=top secret host=localhost")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx host=localhost"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli user=root password=top secret")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli user=root password=xxxx"
+ assert title == expected
+
+ setproctitle.setproctitle("pgcli postgres://root:secret@localhost/db")
+ obfuscate_process_password()
+ title = setproctitle.getproctitle()
+ expected = "pgcli postgres://root:xxxx@localhost/db"
+ assert title == expected
+
+ setproctitle.setproctitle(original_title)
+
+
+def test_format_output():
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
+ results = format_output(
+ "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
+ )
+ expected = [
+ "Title",
+ "+-------+-------+",
+ "| head1 | head2 |",
+ "|-------+-------|",
+ "| abc | def |",
+ "+-------+-------+",
+ "test status",
+ ]
+ assert list(results) == expected
+
+
+def test_format_output_truncate_on():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10
+ )
+ results = format_output(
+ None,
+ [("first field value", "second field value")],
+ ["head1", "head2"],
+ None,
+ settings,
+ )
+ expected = [
+ "+------------+------------+",
+ "| head1 | head2 |",
+ "|------------+------------|",
+ "| first f... | second ... |",
+ "+------------+------------+",
+ ]
+ assert list(results) == expected
+
+
+def test_format_output_truncate_off():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None
+ )
+ long_field_value = ("first field " * 100).strip()
+ results = format_output(None, [(long_field_value,)], ["head1"], None, settings)
+ lines = list(results)
+ assert lines[3] == f"| {long_field_value} |"
+
+
+@dbtest
+def test_format_array_output(executor):
+ statement = """
+ SELECT
+ array[1, 2, 3]::bigint[] as bigint_array,
+ '{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
+ '{å,魚,текст}'::text[] as 配列
+ UNION ALL
+ SELECT '{}', NULL, array[NULL]
+ """
+ results = run(executor, statement)
+ expected = [
+ "+--------------+----------------------+--------------+",
+ "| bigint_array | nested_numeric_array | 配列 |",
+ "|--------------+----------------------+--------------|",
+ "| {1,2,3} | {{1,2},{3,4}} | {å,魚,текст} |",
+ "| {} | <null> | {<null>} |",
+ "+--------------+----------------------+--------------+",
+ "SELECT 2",
+ ]
+ assert list(results) == expected
+
+
+@dbtest
+def test_format_array_output_expanded(executor):
+ statement = """
+ SELECT
+ array[1, 2, 3]::bigint[] as bigint_array,
+ '{{1,2},{3,4}}'::numeric[] as nested_numeric_array,
+ '{å,魚,текст}'::text[] as 配列
+ UNION ALL
+ SELECT '{}', NULL, array[NULL]
+ """
+ results = run(executor, statement, expanded=True)
+ expected = [
+ "-[ RECORD 1 ]-------------------------",
+ "bigint_array | {1,2,3}",
+ "nested_numeric_array | {{1,2},{3,4}}",
+ "配列 | {å,魚,текст}",
+ "-[ RECORD 2 ]-------------------------",
+ "bigint_array | {}",
+ "nested_numeric_array | <null>",
+ "配列 | {<null>}",
+ "SELECT 2",
+ ]
+ assert "\n".join(results) == "\n".join(expected)
+
+
+def test_format_output_auto_expand():
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
+ )
+ table_results = format_output(
+ "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
+ )
+ table = [
+ "Title",
+ "+-------+-------+",
+ "| head1 | head2 |",
+ "|-------+-------|",
+ "| abc | def |",
+ "+-------+-------+",
+ "test status",
+ ]
+ assert list(table_results) == table
+ expanded_results = format_output(
+ "Title",
+ [("abc", "def")],
+ ["head1", "head2"],
+ "test status",
+ settings._replace(max_width=1),
+ )
+ expanded = [
+ "Title",
+ "-[ RECORD 1 ]-------------------------",
+ "head1 | abc",
+ "head2 | def",
+ "test status",
+ ]
+ assert "\n".join(expanded_results) == "\n".join(expanded)
+
+
+termsize = namedtuple("termsize", ["rows", "columns"])
+test_line = "-" * 10
+test_data = [
+ (10, 10, "\n".join([test_line] * 7)),
+ (10, 10, "\n".join([test_line] * 6)),
+ (10, 10, "\n".join([test_line] * 5)),
+ (10, 10, "-" * 11),
+ (10, 10, "-" * 10),
+ (10, 10, "-" * 9),
+]
+
+# 4 lines are reserved at the bottom of the terminal for pgcli's prompt
+use_pager_when_on = [True, True, False, True, False, False]
+
+# Can be replaced with pytest.param once we can upgrade pytest after Python 3.4 goes EOL
+test_ids = [
+ "Output longer than terminal height",
+ "Output equal to terminal height",
+ "Output shorter than terminal height",
+ "Output longer than terminal width",
+ "Output equal to terminal width",
+ "Output shorter than terminal width",
+]
+
+
+@pytest.fixture
+def pset_pager_mocks():
+ cli = PGCli()
+ cli.watch_command = None
+ with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
+ "pgcli.main.click.echo_via_pager"
+ ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
+ yield cli, mock_echo, mock_echo_via_pager, mock_app
+
+
+@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
+def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
+ cli.echo_via_pager(text)
+
+ mock_echo.assert_called()
+ mock_echo_via_pager.assert_not_called()
+
+
+@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
+def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
+ cli.echo_via_pager(text)
+
+ mock_echo.assert_not_called()
+ mock_echo_via_pager.assert_called()
+
+
+pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
+
+
+@pytest.mark.parametrize(
+ "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
+)
+def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
+ cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
+ mock_cli.output.get_size.return_value = termsize(
+ rows=term_height, columns=term_width
+ )
+
+ with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
+ cli.echo_via_pager(text)
+
+ if use_pager:
+ mock_echo.assert_not_called()
+ mock_echo_via_pager.assert_called()
+ else:
+ mock_echo_via_pager.assert_not_called()
+ mock_echo.assert_called()
+
+
+@pytest.mark.parametrize(
+ "text,expected_length",
+ [
+ (
+ "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
+ 78,
+ ),
+ ("=\u001b[m=", 2),
+ ("-\u001b]23\u0007-", 2),
+ ],
+)
+def test_color_pattern(text, expected_length, pset_pager_mocks):
+ cli = pset_pager_mocks[0]
+ assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
+
+
+@dbtest
+def test_i_works(tmpdir, executor):
+ sqlfile = tmpdir.join("test.sql")
+ sqlfile.write("SELECT NOW()")
+ rcfile = str(tmpdir.join("rcfile"))
+ cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
+ statement = r"\i {0}".format(sqlfile)
+ run(executor, statement, pgspecial=cli.pgspecial)
+
+
+@dbtest
+def test_echo_works(executor):
+ cli = PGCli(pgexecute=executor)
+ statement = r"\echo asdf"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result == ["asdf"]
+
+
+@dbtest
+def test_qecho_works(executor):
+ cli = PGCli(pgexecute=executor)
+ statement = r"\qecho asdf"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result == ["asdf"]
+
+
+@dbtest
+def test_watch_works(executor):
+ cli = PGCli(pgexecute=executor)
+
+ def run_with_watch(
+ query, target_call_count=1, expected_output="", expected_timing=None
+ ):
+ """
+ :param query: Input to the CLI
+ :param target_call_count: Number of times the user lets the command run before Ctrl-C
+ :param expected_output: Substring expected to be found for each executed query
+ :param expected_timing: value `time.sleep` expected to be called with on every invocation
+ """
+ with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch(
+ "pgcli.main.sleep"
+ ) as mock_sleep:
+ mock_sleep.side_effect = [None] * (target_call_count - 1) + [
+ KeyboardInterrupt
+ ]
+ cli.handle_watch_command(query)
+ # Validate that sleep was called with the right timing
+ for i in range(target_call_count - 1):
+ assert mock_sleep.call_args_list[i][0][0] == expected_timing
+ # Validate that the output of the query was expected
+ assert mock_echo.call_count == target_call_count
+ for i in range(target_call_count):
+ assert expected_output in mock_echo.call_args_list[i][0][0]
+
+ # With no history, it errors.
+ with mock.patch("pgcli.main.click.secho") as mock_secho:
+ cli.handle_watch_command(r"\watch 2")
+ mock_secho.assert_called()
+ assert (
+ r"\watch cannot be used with an empty query"
+ in mock_secho.call_args_list[0][0][0]
+ )
+
+ # Usage 1: Run a query and then re-run it with \watch across two prompts.
+ run_with_watch("SELECT 111", expected_output="111")
+ run_with_watch(
+ "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10
+ )
+
+ # Usage 2: Run a query and \watch via the same prompt.
+ run_with_watch(
+ "SELECT 222; \\watch 4",
+ target_call_count=3,
+ expected_output="222",
+ expected_timing=4,
+ )
+
+ # Usage 3: Re-run the last watched command with a new timing
+ run_with_watch(
+ "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5
+ )
+
+
+def test_missing_rc_dir(tmpdir):
+ rcfile = str(tmpdir.join("subdir").join("rcfile"))
+
+ PGCli(pgclirc_file=rcfile)
+ assert os.path.exists(rcfile)
+
+
+def test_quoted_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
+ mock_connect.assert_called_with(
+ database="testdb[", host="baz.com", user="bar^", passwd="]foo"
+ )
+
+
+def test_pg_service_file(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
+ service_conf.write(
+ """File begins with a comment
+ that is not a comment
+ # or maybe a comment after all
+ because psql is crazy
+
+ [myservice]
+ host=a_host
+ user=a_user
+ port=5433
+ password=much_secure
+ dbname=a_dbname
+
+ [my_other_service]
+ host=b_host
+ user=b_user
+ port=5435
+ dbname=b_dbname
+ """
+ )
+ os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath
+ cli.connect_service("myservice", "another_user")
+ mock_connect.assert_called_with(
+ database="a_dbname",
+ host="a_host",
+ user="another_user",
+ port="5433",
+ passwd="much_secure",
+ )
+
+ with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
+ mock_pgexecute.return_value = None
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ os.environ["PGPASSWORD"] = "very_secure"
+ cli.connect_service("my_other_service", None)
+ mock_pgexecute.assert_called_with(
+ "b_dbname",
+ "b_user",
+ "very_secure",
+ "b_host",
+ "5435",
+ "",
+ application_name="pgcli",
+ )
+ del os.environ["PGPASSWORD"]
+ del os.environ["PGSERVICEFILE"]
+
+
+def test_ssl_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri(
+ "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
+ "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
+ )
+ mock_connect.assert_called_with(
+ database="testdb[",
+ host="baz.com",
+ user="bar^",
+ passwd="]foo",
+ sslmode="verify-full",
+ sslcert="my.pem",
+ sslkey="my-key.pem",
+ sslrootcert="ca.pem",
+ )
+
+
+def test_port_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
+ mock_connect.assert_called_with(
+ database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
+ )
+
+
+def test_multihost_db_uri(tmpdir):
+ with mock.patch.object(PGCli, "connect") as mock_connect:
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri(
+ "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
+ )
+ mock_connect.assert_called_with(
+ database="testdb",
+ host="baz1.com,baz2.com,baz3.com",
+ user="bar",
+ passwd="foo",
+ port="2543,2543,2543",
+ )
+
+
+def test_application_name_db_uri(tmpdir):
+ with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
+ mock_pgexecute.return_value = None
+ cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
+ cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
+ mock_pgexecute.assert_called_with(
+ "bar", "bar", "", "baz.com", "", "", application_name="cow"
+ )
diff --git a/tests/test_naive_completion.py b/tests/test_naive_completion.py
new file mode 100644
index 0000000..5b93661
--- /dev/null
+++ b/tests/test_naive_completion.py
@@ -0,0 +1,133 @@
+import pytest
+from prompt_toolkit.completion import Completion
+from prompt_toolkit.document import Document
+from utils import completions_to_set
+
+
+@pytest.fixture
+def completer():
+ import pgcli.pgcompleter as pgcompleter
+
+ return pgcompleter.PGCompleter(smart_completion=False)
+
+
+@pytest.fixture
+def complete_event():
+ from unittest.mock import Mock
+
+ return Mock()
+
+
+def test_empty_string_completion(completer, complete_event):
+ text = ""
+ position = 0
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(map(Completion, completer.all_completions))
+
+
+def test_select_keyword_completion(completer, complete_event):
+ text = "SEL"
+ position = len("SEL")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set([Completion(text="SELECT", start_position=-3)])
+
+
+def test_function_name_completion(completer, complete_event):
+ text = "SELECT MA"
+ position = len("SELECT MA")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(
+ [
+ Completion(text="MATERIALIZED VIEW", start_position=-2),
+ Completion(text="MAX", start_position=-2),
+ Completion(text="MAXEXTENTS", start_position=-2),
+ Completion(text="MAKE_DATE", start_position=-2),
+ Completion(text="MAKE_TIME", start_position=-2),
+ Completion(text="MAKE_TIMESTAMPTZ", start_position=-2),
+ Completion(text="MAKE_INTERVAL", start_position=-2),
+ Completion(text="MASKLEN", start_position=-2),
+ Completion(text="MAKE_TIMESTAMP", start_position=-2),
+ ]
+ )
+
+
+def test_column_name_completion(completer, complete_event):
+ text = "SELECT FROM users"
+ position = len("SELECT ")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ assert result == completions_to_set(map(Completion, completer.all_completions))
+
+
+def test_alter_well_known_keywords_completion(completer, complete_event):
+ text = "ALTER "
+ position = len(text)
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event,
+ smart_completion=True,
+ )
+ )
+ assert result > completions_to_set(
+ [
+ Completion(text="DATABASE", display_meta="keyword"),
+ Completion(text="TABLE", display_meta="keyword"),
+ Completion(text="SYSTEM", display_meta="keyword"),
+ ]
+ )
+ assert (
+ completions_to_set([Completion(text="CREATE", display_meta="keyword")])
+ not in result
+ )
+
+
+def test_special_name_completion(completer, complete_event):
+ text = "\\"
+ position = len("\\")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position), complete_event
+ )
+ )
+ # Special commands will NOT be suggested during naive completion mode.
+ assert result == completions_to_set([])
+
+
+def test_datatype_name_completion(completer, complete_event):
+ text = "SELECT price::IN"
+ position = len("SELECT price::IN")
+ result = completions_to_set(
+ completer.get_completions(
+ Document(text=text, cursor_position=position),
+ complete_event,
+ smart_completion=True,
+ )
+ )
+ assert result == completions_to_set(
+ [
+ Completion(text="INET", display_meta="datatype"),
+ Completion(text="INT", display_meta="datatype"),
+ Completion(text="INT2", display_meta="datatype"),
+ Completion(text="INT4", display_meta="datatype"),
+ Completion(text="INT8", display_meta="datatype"),
+ Completion(text="INTEGER", display_meta="datatype"),
+ Completion(text="INTERNAL", display_meta="datatype"),
+ Completion(text="INTERVAL", display_meta="datatype"),
+ ]
+ )
diff --git a/tests/test_pgcompleter.py b/tests/test_pgcompleter.py
new file mode 100644
index 0000000..909fa0b
--- /dev/null
+++ b/tests/test_pgcompleter.py
@@ -0,0 +1,76 @@
+import pytest
+from pgcli import pgcompleter
+
+
+def test_load_alias_map_file_missing_file():
+ with pytest.raises(
+ pgcompleter.InvalidMapFile,
+ match=r"Cannot read alias_map_file - /path/to/non-existent/file.json does not exist$",
+ ):
+ pgcompleter.load_alias_map_file("/path/to/non-existent/file.json")
+
+
+def test_load_alias_map_file_invalid_json(tmp_path):
+ fpath = tmp_path / "foo.json"
+ fpath.write_text("this is not valid json")
+ with pytest.raises(pgcompleter.InvalidMapFile, match=r".*is not valid json$"):
+ pgcompleter.load_alias_map_file(str(fpath))
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("SomE_Table", "SET"),
+ ("SOmeTabLe", "SOTL"),
+ ("someTable", "T"),
+ ],
+)
+def test_generate_alias_uses_upper_case_letters_from_name(table_name, alias):
+ assert pgcompleter.generate_alias(table_name) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("some_tab_le", "stl"),
+ ("s_ome_table", "sot"),
+ ("sometable", "s"),
+ ],
+)
+def test_generate_alias_uses_first_char_and_every_preceded_by_underscore(
+ table_name, alias
+):
+ assert pgcompleter.generate_alias(table_name) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias_map, alias",
+ [
+ ("some_table", {"some_table": "my_alias"}, "my_alias"),
+ ],
+)
+def test_generate_alias_can_use_alias_map(table_name, alias_map, alias):
+ assert pgcompleter.generate_alias(table_name, alias_map) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias_map, alias",
+ [
+ ("SomeTable", {"SomeTable": "my_alias"}, "my_alias"),
+ ],
+)
+def test_generate_alias_prefers_alias_over_upper_case_name(
+ table_name, alias_map, alias
+):
+ assert pgcompleter.generate_alias(table_name, alias_map) == alias
+
+
+@pytest.mark.parametrize(
+ "table_name, alias",
+ [
+ ("Some_tablE", "SE"),
+ ("SomeTab_le", "ST"),
+ ],
+)
+def test_generate_alias_prefers_upper_case_name_over_underscore_name(table_name, alias):
+ assert pgcompleter.generate_alias(table_name) == alias
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
new file mode 100644
index 0000000..636795b
--- /dev/null
+++ b/tests/test_pgexecute.py
@@ -0,0 +1,773 @@
+from textwrap import dedent
+
+import psycopg
+import pytest
+from unittest.mock import patch, MagicMock
+from pgspecial.main import PGSpecial, NO_QUERY
+from utils import run, dbtest, requires_json, requires_jsonb
+
+from pgcli.main import PGCli
+from pgcli.packages.parseutils.meta import FunctionMetadata
+
+
+def function_meta_data(
+ func_name,
+ schema_name="public",
+ arg_names=None,
+ arg_types=None,
+ arg_modes=None,
+ return_type=None,
+ is_aggregate=False,
+ is_window=False,
+ is_set_returning=False,
+ is_extension=False,
+ arg_defaults=None,
+):
+ return FunctionMetadata(
+ schema_name,
+ func_name,
+ arg_names,
+ arg_types,
+ arg_modes,
+ return_type,
+ is_aggregate,
+ is_window,
+ is_set_returning,
+ is_extension,
+ arg_defaults,
+ )
+
+
+@dbtest
+def test_conn(executor):
+ run(executor, """create table test(a text)""")
+ run(executor, """insert into test values('abc')""")
+ assert run(executor, """select * from test""", join=True) == dedent(
+ """\
+ +-----+
+ | a |
+ |-----|
+ | abc |
+ +-----+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_copy(executor):
+ executor_copy = executor.copy()
+ run(executor_copy, """create table test(a text)""")
+ run(executor_copy, """insert into test values('abc')""")
+ assert run(executor_copy, """select * from test""", join=True) == dedent(
+ """\
+ +-----+
+ | a |
+ |-----|
+ | abc |
+ +-----+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_bools_are_treated_as_strings(executor):
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(True)""")
+ assert run(executor, """select * from test""", join=True) == dedent(
+ """\
+ +------+
+ | a |
+ |------|
+ | True |
+ +------+
+ SELECT 1"""
+ )
+
+
+@dbtest
+def test_expanded_slash_G(executor, pgspecial):
+ # Tests whether we reset the expanded output after a \G.
+ run(executor, """create table test(a boolean)""")
+ run(executor, """insert into test values(True)""")
+ results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
+ assert pgspecial.expanded_output == False
+
+
+@dbtest
+def test_schemata_table_views_and_columns_query(executor):
+ run(executor, "create table a(x text, y text)")
+ run(executor, "create table b(z text)")
+ run(executor, "create view d as select 1 as e")
+ run(executor, "create schema schema1")
+ run(executor, "create table schema1.c (w text DEFAULT 'meow')")
+ run(executor, "create schema schema2")
+
+ # schemata
+ # don't enforce all members of the schemas since they may include postgres
+ # temporary schemas
+ assert set(executor.schemata()) >= {
+ "public",
+ "pg_catalog",
+ "information_schema",
+ "schema1",
+ "schema2",
+ }
+ assert executor.search_path() == ["pg_catalog", "public"]
+
+ # tables
+ assert set(executor.tables()) >= {
+ ("public", "a"),
+ ("public", "b"),
+ ("schema1", "c"),
+ }
+
+ assert set(executor.table_columns()) >= {
+ ("public", "a", "x", "text", False, None),
+ ("public", "a", "y", "text", False, None),
+ ("public", "b", "z", "text", False, None),
+ ("schema1", "c", "w", "text", True, "'meow'::text"),
+ }
+
+ # views
+ assert set(executor.views()) >= {("public", "d")}
+
+ assert set(executor.view_columns()) >= {
+ ("public", "d", "e", "integer", False, None)
+ }
+
+
+@dbtest
+def test_foreign_key_query(executor):
+ run(executor, "create schema schema1")
+ run(executor, "create schema schema2")
+ run(executor, "create table schema1.parent(parentid int PRIMARY KEY)")
+ run(
+ executor,
+ "create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
+ )
+
+ assert set(executor.foreignkeys()) >= {
+ ("schema1", "parent", "parentid", "schema2", "child", "motherid")
+ }
+
+
+@dbtest
+def test_functions_query(executor):
+ run(
+ executor,
+ """create function func1() returns int
+ language sql as $$select 1$$""",
+ )
+ run(executor, "create schema schema1")
+ run(
+ executor,
+ """create function schema1.func2() returns int
+ language sql as $$select 2$$""",
+ )
+
+ run(
+ executor,
+ """create function func3()
+ returns table(x int, y int) language sql
+ as $$select 1, 2 from generate_series(1,5)$$;""",
+ )
+
+ run(
+ executor,
+ """create function func4(x int) returns setof int language sql
+ as $$select generate_series(1,5)$$;""",
+ )
+
+ funcs = set(executor.functions())
+ assert funcs >= {
+ function_meta_data(func_name="func1", return_type="integer"),
+ function_meta_data(
+ func_name="func3",
+ arg_names=["x", "y"],
+ arg_types=["integer", "integer"],
+ arg_modes=["t", "t"],
+ return_type="record",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="public",
+ func_name="func4",
+ arg_names=("x",),
+ arg_types=("integer",),
+ return_type="integer",
+ is_set_returning=True,
+ ),
+ function_meta_data(
+ schema_name="schema1", func_name="func2", return_type="integer"
+ ),
+ }
+
+
+@dbtest
+def test_datatypes_query(executor):
+ run(executor, "create type foo AS (a int, b text)")
+
+ types = list(executor.datatypes())
+ assert types == [("public", "foo")]
+
+
+@dbtest
+def test_database_list(executor):
+ databases = executor.databases()
+ assert "_test_db" in databases
+
+
+@dbtest
+def test_invalid_syntax(executor, exception_formatter):
+ result = run(executor, "invalid syntax!", exception_formatter=exception_formatter)
+ assert 'syntax error at or near "invalid"' in result[0]
+
+
+@dbtest
+def test_invalid_column_name(executor, exception_formatter):
+ result = run(
+ executor, "select invalid command", exception_formatter=exception_formatter
+ )
+ assert 'column "invalid" does not exist' in result[0]
+
+
+@pytest.fixture(params=[True, False])
+def expanded(request):
+ return request.param
+
+
+@dbtest
+def test_unicode_support_in_output(executor, expanded):
+ run(executor, "create table unicodechars(t text)")
+ run(executor, "insert into unicodechars (t) values ('é')")
+
+ # See issue #24, this raises an exception without proper handling
+ assert "é" in run(
+ executor, "select * from unicodechars", join=True, expanded=expanded
+ )
+
+
+@dbtest
+def test_not_is_special(executor, pgspecial):
+ """is_special is set to false for database queries."""
+ query = "select 1"
+ result = list(executor.run(query, pgspecial=pgspecial))
+ success, is_special = result[0][5:]
+ assert success == True
+ assert is_special == False
+
+
+@dbtest
+def test_execute_from_file_no_arg(executor, pgspecial):
+ r"""\i without a filename returns an error."""
+ result = list(executor.run(r"\i", pgspecial=pgspecial))
+ status, sql, success, is_special = result[0][3:]
+ assert "missing required argument" in status
+ assert success == False
+ assert is_special == True
+
+
+@dbtest
+@patch("pgcli.main.os")
+def test_execute_from_file_io_error(os, executor, pgspecial):
+ r"""\i with an os_error returns an error."""
+ # Inject an OSError.
+ os.path.expanduser.side_effect = OSError("test")
+
+ # Check the result.
+ result = list(executor.run(r"\i test", pgspecial=pgspecial))
+ status, sql, success, is_special = result[0][3:]
+ assert status == "test"
+ assert success == False
+ assert is_special == True
+
+
+@dbtest
+def test_execute_from_commented_file_that_executes_another_file(
+ executor, pgspecial, tmpdir
+):
+ # https://github.com/dbcli/pgcli/issues/1336
+ sqlfile1 = tmpdir.join("test01.sql")
+ sqlfile1.write("-- asdf \n\\h")
+ sqlfile2 = tmpdir.join("test00.sql")
+ sqlfile2.write("--An useless comment;\nselect now();\n-- another useless comment")
+
+ rcfile = str(tmpdir.join("rcfile"))
+ print(rcfile)
+ cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
+ assert cli != None
+ statement = "--comment\n\\h"
+ result = run(executor, statement, pgspecial=cli.pgspecial)
+ assert result != None
+ assert result[0].find("ALTER TABLE")
+
+
+@dbtest
+def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
+ # just some base cases that should work also
+ statement = "--comment\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ statement = "/*comment*/\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ # https://github.com/dbcli/pgcli/issues/1362
+ statement = "--comment\n\\h"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "--comment1\n--comment2\n\\h"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "/*comment*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """/*comment1
+ comment2*/
+ \h"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """/*comment1
+ comment2*/
+ /*comment 3
+ comment4*/
+ \\h"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = " /*comment*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = "/*comment\ncomment line2*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = " /*comment\ncomment line2*/\n\h;"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("ALTER") >= 0
+ assert result[1].find("ABORT") >= 0
+
+ statement = """\\h /*comment4 */"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ print(result)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+ # TODO: we probably don't want to do this but sqlparse is not parsing things well
+ # we relly want it to find help but right now, sqlparse isn't dropping the /*comment*/
+ # style comments after command
+
+ statement = """/*comment1*/
+ \h
+ /*comment4 */"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+ # TODO: same for this one
+ statement = """/*comment1
+ comment3
+ comment2*/
+ \\h
+ /*comment4
+ comment5
+ comment6*/"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[0].find("No help") >= 0
+
+
+@dbtest
+def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
+ # https://github.com/dbcli/pgcli/issues/1403
+
+ # just some base cases that should work also
+ statement = "--comment\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ statement = "/*comment*/\nselect now();"
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[1].find("now") >= 0
+
+ # this simulates the original error (1403) without having to add/drop tables
+ # since it was just an error on reading input files and not the actual
+ # command itself
+
+ # test that the statement works
+ statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # test the statement with a \n in the middle
+ statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # test the statement with a newline in the middle
+ statement = """VALUES (1, 'one'),
+ (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # now add a single comment line
+ statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # doing without special char \n
+ statement = """--comment
+ VALUES (1,'one'),
+ (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # two comment lines
+ statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # doing without special char \n
+ statement = """--comment
+ --comment2
+ VALUES (1,'one'), (2, 'two'), (3, 'three');
+ """
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # multiline comment + newline in middle of the statement
+ statement = """/*comment
+comment2
+comment3*/
+VALUES (1,'one'),
+(2, 'two'), (3, 'three');"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+ # multiline comment + newline in middle of the statement
+ # + comments after the statement
+ statement = """/*comment
+comment2
+comment3*/
+VALUES (1,'one'),
+(2, 'two'), (3, 'three');
+--comment4
+--comment5"""
+ result = run(executor, statement, pgspecial=pgspecial)
+ assert result != None
+ assert result[5].find("three") >= 0
+
+
+@dbtest
+def test_multiple_queries_same_line(executor):
+ result = run(executor, "select 'foo'; select 'bar'")
+ assert len(result) == 12 # 2 * (output+status) * 3 lines
+ assert "foo" in result[3]
+ assert "bar" in result[9]
+
+
+@dbtest
+def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
+ result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
+ assert len(result) == 11 # 2 * (output+status) * 3 lines
+ assert "foo" in result[3]
+ # This is a lame check. :(
+ assert "Schema" in result[7]
+
+
+@dbtest
+def test_multiple_queries_same_line_syntaxerror(executor, exception_formatter):
+ result = run(
+ executor,
+ "select 'fooé'; invalid syntax é",
+ exception_formatter=exception_formatter,
+ )
+ assert "fooé" in result[3]
+ assert 'syntax error at or near "invalid"' in result[-1]
+
+
+@pytest.fixture
+def pgspecial():
+ return PGCli().pgspecial
+
+
+@dbtest
+def test_special_command_help(executor, pgspecial):
+ result = run(executor, "\\?", pgspecial=pgspecial)[1].split("|")
+ assert "Command" in result[1]
+ assert "Description" in result[2]
+
+
+@dbtest
+def test_bytea_field_support_in_output(executor):
+ run(executor, "create table binarydata(c bytea)")
+ run(executor, "insert into binarydata (c) values (decode('DEADBEEF', 'hex'))")
+
+ assert "\\xdeadbeef" in run(executor, "select * from binarydata", join=True)
+
+
+@dbtest
+def test_unicode_support_in_unknown_type(executor):
+ assert "日本語" in run(executor, "SELECT '日本語' AS japanese;", join=True)
+
+
+@dbtest
+def test_unicode_support_in_enum_type(executor):
+ run(executor, "CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy', '日本語')")
+ run(executor, "CREATE TABLE person (name TEXT, current_mood mood)")
+ run(executor, "INSERT INTO person VALUES ('Moe', '日本語')")
+ assert "日本語" in run(executor, "SELECT * FROM person", join=True)
+
+
+@requires_json
+def test_json_renders_without_u_prefix(executor, expanded):
+ run(executor, "create table jsontest(d json)")
+ run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""")
+ result = run(
+ executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
+ )
+
+ assert '{"name": "Éowyn"}' in result
+
+
+@requires_jsonb
+def test_jsonb_renders_without_u_prefix(executor, expanded):
+ run(executor, "create table jsonbtest(d jsonb)")
+ run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
+ result = run(
+ executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
+ )
+
+ assert '{"name": "Éowyn"}' in result
+
+
+@dbtest
+def test_date_time_types(executor):
+ run(executor, "SET TIME ZONE UTC")
+ assert (
+ run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
+ == "| 00:00:00 |"
+ )
+ assert (
+ run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
+ "\n"
+ )[3]
+ == "| 00:00:00+14:59 |"
+ )
+ assert (
+ run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
+ 3
+ ]
+ == "| 4713-01-01 BC |"
+ )
+ assert (
+ run(
+ executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
+ ).split("\n")[3]
+ == "| 4713-01-01 00:00:00 BC |"
+ )
+ assert (
+ run(
+ executor,
+ "SELECT (CAST('4713-01-01 00:00:00+00 BC' AS timestamptz))",
+ join=True,
+ ).split("\n")[3]
+ == "| 4713-01-01 00:00:00+00 BC |"
+ )
+ assert (
+ run(
+ executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
+ ).split("\n")[3]
+ == "| -123456789 days, 12:23:56 |"
+ )
+
+
+@dbtest
+@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
+def test_large_numbers_render_directly(executor, value):
+ run(executor, "create table numbertest(a numeric)")
+ run(executor, f"insert into numbertest (a) values ({value})")
+ assert value in run(executor, "select * from numbertest", join=True)
+
+
+@dbtest
+@pytest.mark.parametrize("command", ["di", "dv", "ds", "df", "dT"])
+@pytest.mark.parametrize("verbose", ["", "+"])
+@pytest.mark.parametrize("pattern", ["", "x", "*.*", "x.y", "x.*", "*.y"])
+def test_describe_special(executor, command, verbose, pattern, pgspecial):
+ # We don't have any tests for the output of any of the special commands,
+ # but we can at least make sure they run without error
+ sql = r"\{command}{verbose} {pattern}".format(**locals())
+ list(executor.run(sql, pgspecial=pgspecial))
+
+
+@dbtest
+@pytest.mark.parametrize("sql", ["invalid sql", "SELECT 1; select error;"])
+def test_raises_with_no_formatter(executor, sql):
+ with pytest.raises(psycopg.ProgrammingError):
+ list(executor.run(sql))
+
+
+@dbtest
+def test_on_error_resume(executor, exception_formatter):
+ sql = "select 1; error; select 1;"
+ result = list(
+ executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
+ )
+ assert len(result) == 3
+
+
+@dbtest
+def test_on_error_stop(executor, exception_formatter):
+ sql = "select 1; error; select 1;"
+ result = list(
+ executor.run(
+ sql, on_error_resume=False, exception_formatter=exception_formatter
+ )
+ )
+ assert len(result) == 2
+
+
+# @dbtest
+# def test_unicode_notices(executor):
+# sql = "DO language plpgsql $$ BEGIN RAISE NOTICE '有人更改'; END $$;"
+# result = list(executor.run(sql))
+# assert result[0][0] == u'NOTICE: 有人更改\n'
+
+
+@dbtest
+def test_nonexistent_function_definition(executor):
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("there_is_no_such_function")
+
+
+@dbtest
+def test_function_definition(executor):
+ run(
+ executor,
+ """
+ CREATE OR REPLACE FUNCTION public.the_number_three()
+ RETURNS int
+ LANGUAGE sql
+ AS $function$
+ select 3;
+ $function$
+ """,
+ )
+ result = executor.function_definition("the_number_three")
+
+
+@dbtest
+def test_view_definition(executor):
+ run(executor, "create table tbl1 (a text, b numeric)")
+ run(executor, "create view vw1 AS SELECT * FROM tbl1")
+ run(executor, "create materialized view mvw1 AS SELECT * FROM tbl1")
+ result = executor.view_definition("vw1")
+ assert 'VIEW "public"."vw1" AS' in result
+ assert "FROM tbl1" in result
+ # import pytest; pytest.set_trace()
+ result = executor.view_definition("mvw1")
+ assert "MATERIALIZED VIEW" in result
+
+
+@dbtest
+def test_nonexistent_view_definition(executor):
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("there_is_no_such_view")
+ with pytest.raises(RuntimeError):
+ result = executor.view_definition("mvw1")
+
+
+@dbtest
+def test_short_host(executor):
+ with patch.object(executor, "host", "localhost"):
+ assert executor.short_host == "localhost"
+ with patch.object(executor, "host", "localhost.example.org"):
+ assert executor.short_host == "localhost"
+ with patch.object(
+ executor, "host", "localhost1.example.org,localhost2.example.org"
+ ):
+ assert executor.short_host == "localhost1"
+
+
+class VirtualCursor:
+ """Mock a cursor to virtual database like pgbouncer."""
+
+ def __init__(self):
+ self.protocol_error = False
+ self.protocol_message = ""
+ self.description = None
+ self.status = None
+ self.statusmessage = "Error"
+
+ def execute(self, *args, **kwargs):
+ self.protocol_error = True
+ self.protocol_message = "Command not supported"
+
+
+@dbtest
+def test_exit_without_active_connection(executor):
+ quit_handler = MagicMock()
+ pgspecial = PGSpecial()
+ pgspecial.register(
+ quit_handler,
+ "\\q",
+ "\\q",
+ "Quit pgcli.",
+ arg_type=NO_QUERY,
+ case_sensitive=True,
+ aliases=(":q",),
+ )
+
+ with patch.object(
+ executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!")
+ ):
+ # we should be able to quit the app, even without active connection
+ run(executor, "\\q", pgspecial=pgspecial)
+ quit_handler.assert_called_once()
+
+ # an exception should be raised when running a query without active connection
+ with pytest.raises(psycopg.InterfaceError):
+ run(executor, "select 1", pgspecial=pgspecial)
+
+
+@dbtest
+def test_virtual_database(executor):
+ virtual_connection = MagicMock()
+ virtual_connection.cursor.return_value = VirtualCursor()
+ with patch.object(executor, "conn", virtual_connection):
+ result = run(executor, "select 1")
+ assert "Command not supported" in result
diff --git a/tests/test_pgspecial.py b/tests/test_pgspecial.py
new file mode 100644
index 0000000..cd99e32
--- /dev/null
+++ b/tests/test_pgspecial.py
@@ -0,0 +1,78 @@
+import pytest
+from pgcli.packages.sqlcompletion import (
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ View,
+ Function,
+ Datatype,
+)
+
+
+def test_slash_suggests_special():
+ suggestions = suggest_type("\\", "\\")
+ assert set(suggestions) == {Special()}
+
+
+def test_slash_d_suggests_special():
+ suggestions = suggest_type("\\d", "\\d")
+ assert set(suggestions) == {Special()}
+
+
+def test_dn_suggests_schemata():
+ suggestions = suggest_type("\\dn ", "\\dn ")
+ assert suggestions == (Schema(),)
+
+ suggestions = suggest_type("\\dn xxx", "\\dn xxx")
+ assert suggestions == (Schema(),)
+
+
+def test_d_suggests_tables_views_and_schemas():
+ suggestions = suggest_type(r"\d ", r"\d ")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
+
+ suggestions = suggest_type(r"\d xxx", r"\d xxx")
+ assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
+
+
+def test_d_dot_suggests_schema_qualified_tables_or_views():
+ suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
+
+ suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
+ assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
+
+
+def test_df_suggests_schema_or_function():
+ suggestions = suggest_type("\\df xxx", "\\df xxx")
+ assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
+
+ suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
+ assert suggestions == (Function(schema="myschema", usage="special"),)
+
+
+def test_leading_whitespace_ok():
+ cmd = "\\dn "
+ whitespace = " "
+ suggestions = suggest_type(whitespace + cmd, whitespace + cmd)
+ assert suggestions == suggest_type(cmd, cmd)
+
+
+def test_dT_suggests_schema_or_datatypes():
+ text = "\\dT "
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == {Schema(), Datatype(schema=None)}
+
+
+def test_schema_qualified_dT_suggests_datatypes():
+ text = "\\dT foo."
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Datatype(schema="foo"),)
+
+
+@pytest.mark.parametrize("command", ["\\c ", "\\connect "])
+def test_c_suggests_databases(command):
+ suggestions = suggest_type(command, command)
+ assert suggestions == (Database(),)
diff --git a/tests/test_prioritization.py b/tests/test_prioritization.py
new file mode 100644
index 0000000..f5b6700
--- /dev/null
+++ b/tests/test_prioritization.py
@@ -0,0 +1,20 @@
+from pgcli.packages.prioritization import PrevalenceCounter
+
+
+def test_prevalence_counter():
+ counter = PrevalenceCounter()
+ sql = """SELECT * FROM foo WHERE bar GROUP BY baz;
+ select * from foo;
+ SELECT * FROM foo WHERE bar GROUP
+ BY baz"""
+ counter.update(sql)
+
+ keywords = ["SELECT", "FROM", "GROUP BY"]
+ expected = [3, 3, 2]
+ kw_counts = [counter.keyword_count(x) for x in keywords]
+ assert kw_counts == expected
+ assert counter.keyword_count("NOSUCHKEYWORD") == 0
+
+ names = ["foo", "bar", "baz"]
+ name_counts = [counter.name_count(x) for x in names]
+ assert name_counts == [3, 2, 2]
diff --git a/tests/test_prompt_utils.py b/tests/test_prompt_utils.py
new file mode 100644
index 0000000..91abe37
--- /dev/null
+++ b/tests/test_prompt_utils.py
@@ -0,0 +1,17 @@
+import click
+
+from pgcli.packages.prompt_utils import confirm_destructive_query
+
+
+def test_confirm_destructive_query_notty():
+ stdin = click.get_text_stream("stdin")
+ if not stdin.isatty():
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql, [], None) is None
+
+
+def test_confirm_destructive_query_with_alias():
+ stdin = click.get_text_stream("stdin")
+ if not stdin.isatty():
+ sql = "drop database foo;"
+ assert confirm_destructive_query(sql, ["drop"], "test") is None
diff --git a/tests/test_rowlimit.py b/tests/test_rowlimit.py
new file mode 100644
index 0000000..da916b4
--- /dev/null
+++ b/tests/test_rowlimit.py
@@ -0,0 +1,79 @@
+import pytest
+from unittest.mock import Mock
+
+from pgcli.main import PGCli
+
+
+# We need this fixtures because we need PGCli object to be created
+# after test collection so it has config loaded from temp directory
+
+
+@pytest.fixture(scope="module")
+def default_pgcli_obj():
+ return PGCli()
+
+
+@pytest.fixture(scope="module")
+def DEFAULT(default_pgcli_obj):
+ return default_pgcli_obj.row_limit
+
+
+@pytest.fixture(scope="module")
+def LIMIT(DEFAULT):
+ return DEFAULT + 1000
+
+
+@pytest.fixture(scope="module")
+def over_default(DEFAULT):
+ over_default_cursor = Mock()
+ over_default_cursor.configure_mock(rowcount=DEFAULT + 10)
+ return over_default_cursor
+
+
+@pytest.fixture(scope="module")
+def over_limit(LIMIT):
+ over_limit_cursor = Mock()
+ over_limit_cursor.configure_mock(rowcount=LIMIT + 10)
+ return over_limit_cursor
+
+
+@pytest.fixture(scope="module")
+def low_count():
+ low_count_cursor = Mock()
+ low_count_cursor.configure_mock(rowcount=1)
+ return low_count_cursor
+
+
+def test_row_limit_with_LIMIT_clause(LIMIT, over_limit):
+ cli = PGCli(row_limit=LIMIT)
+ stmt = "SELECT * FROM students LIMIT 1000"
+
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+
+def test_row_limit_without_LIMIT_clause(LIMIT, over_limit):
+ cli = PGCli(row_limit=LIMIT)
+ stmt = "SELECT * FROM students"
+
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is True
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+
+def test_row_limit_on_non_select(over_limit):
+ cli = PGCli()
+ stmt = "UPDATE students SET name='Boby'"
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
+
+ cli = PGCli(row_limit=0)
+ result = cli._should_limit_output(stmt, over_limit)
+ assert result is False
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
new file mode 100644
index 0000000..5c9c9af
--- /dev/null
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -0,0 +1,757 @@
+import itertools
+from metadata import (
+ MetaData,
+ alias,
+ name_join,
+ fk_join,
+ join,
+ schema,
+ table,
+ function,
+ wildcard_expansion,
+ column,
+ get_result,
+ result_set,
+ qual,
+ no_qual,
+ parametrize,
+)
+from utils import completions_to_set
+
+metadata = {
+ "tables": {
+ "public": {
+ "users": ["id", "email", "first_name", "last_name"],
+ "orders": ["id", "ordered_date", "status", "datestamp"],
+ "select": ["id", "localtime", "ABC"],
+ },
+ "custom": {
+ "users": ["id", "phone_number"],
+ "Users": ["userid", "username"],
+ "products": ["id", "product_name", "price"],
+ "shipments": ["id", "address", "user_id"],
+ },
+ "Custom": {"projects": ["projectid", "name"]},
+ "blog": {
+ "entries": ["entryid", "entrytitle", "entrytext"],
+ "tags": ["tagid", "name"],
+ "entrytags": ["entryid", "tagid"],
+ "entacclog": ["entryid", "username", "datestamp"],
+ },
+ },
+ "functions": {
+ "public": [
+ ["func1", [], [], [], "", False, False, False, False],
+ ["func2", [], [], [], "", False, False, False, False],
+ ],
+ "custom": [
+ ["func3", [], [], [], "", False, False, False, False],
+ [
+ "set_returning_func",
+ ["x"],
+ ["integer"],
+ ["o"],
+ "integer",
+ False,
+ False,
+ True,
+ False,
+ ],
+ ],
+ "Custom": [["func4", [], [], [], "", False, False, False, False]],
+ "blog": [
+ [
+ "extract_entry_symbols",
+ ["_entryid", "symbol"],
+ ["integer", "text"],
+ ["i", "o"],
+ "",
+ False,
+ False,
+ True,
+ False,
+ ],
+ [
+ "enter_entry",
+ ["_title", "_text", "entryid"],
+ ["text", "text", "integer"],
+ ["i", "i", "o"],
+ "",
+ False,
+ False,
+ False,
+ False,
+ ],
+ ],
+ },
+ "datatypes": {"public": ["typ1", "typ2"], "custom": ["typ3", "typ4"]},
+ "foreignkeys": {
+ "custom": [("public", "users", "id", "custom", "shipments", "user_id")],
+ "blog": [
+ ("blog", "entries", "entryid", "blog", "entacclog", "entryid"),
+ ("blog", "entries", "entryid", "blog", "entrytags", "entryid"),
+ ("blog", "tags", "tagid", "blog", "entrytags", "tagid"),
+ ],
+ },
+ "defaults": {
+ "public": {
+ ("orders", "id"): "nextval('orders_id_seq'::regclass)",
+ ("orders", "datestamp"): "now()",
+ ("orders", "status"): "'PENDING'::text",
+ }
+ },
+}
+
+testdata = MetaData(metadata)
+cased_schemas = [schema(x) for x in ("public", "blog", "CUSTOM", '"Custom"')]
+casing = (
+ "SELECT",
+ "Orders",
+ "User_Emails",
+ "CUSTOM",
+ "Func1",
+ "Entries",
+ "Tags",
+ "EntryTags",
+ "EntAccLog",
+ "EntryID",
+ "EntryTitle",
+ "EntryText",
+)
+completers = testdata.get_completers(casing)
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize("table", ["users", '"users"'])
+def test_suggested_column_names_from_shadowed_visible_table(completer, table):
+ result = get_result(completer, "SELECT FROM " + table, len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT from custom.users",
+ "WITH users as (SELECT 1 AS foo) SELECT from custom.users",
+ ],
+)
+def test_suggested_column_names_from_qualified_shadowed_table(completer, text):
+ result = get_result(completer, text, position=text.find(" ") + 1)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"])
+def test_suggested_column_names_from_cte(completer, text):
+ result = completions_to_set(get_result(completer, text, text.find(" ") + 1))
+ assert result == completions_to_set(
+ [column("foo")] + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users JOIN custom.shipments ON ",
+ """SELECT *
+ FROM public.users
+ JOIN custom.shipments ON """,
+ ],
+)
+def test_suggested_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("users"),
+ alias("shipments"),
+ name_join("shipments.id = users.id"),
+ fk_join("shipments.user_id = users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+@parametrize(
+ ("query", "tbl"),
+ itertools.product(
+ (
+ "SELECT * FROM public.{0} RIGHT OUTER JOIN ",
+ """SELECT *
+ FROM {0}
+ JOIN """,
+ ),
+ ("users", '"users"', "Users"),
+ ),
+)
+def test_suggested_joins(completer, query, tbl):
+ result = get_result(completer, query.format(tbl))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_column_names_from_schema_qualifed_table(completer):
+ result = get_result(completer, "SELECT from custom.products", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize(
+ "text",
+ [
+ "INSERT INTO orders(",
+ "INSERT INTO orders (",
+ "INSERT INTO public.orders(",
+ "INSERT INTO public.orders (",
+ ],
+)
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_columns_with_insert(completer, text):
+ assert completions_to_set(get_result(completer, text)) == completions_to_set(
+ testdata.columns("orders")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_column_names_in_function(completer):
+ result = get_result(
+ completer, "SELECT MAX( from custom.products", len("SELECT MAX(")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ ["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'],
+)
+@parametrize("use_leading_double_quote", [False, True])
+def test_suggested_table_names_with_schema_dot(
+ completer, text, use_leading_double_quote
+):
+ if use_leading_double_quote:
+ text += '"'
+ start_position = -1
+ else:
+ start_position = 0
+
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.from_clause_items("custom", start_position)
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", ['SELECT * FROM "Custom".'])
+@parametrize("use_leading_double_quote", [False, True])
+def test_suggested_table_names_with_schema_dot2(
+ completer, text, use_leading_double_quote
+):
+ if use_leading_double_quote:
+ text += '"'
+ start_position = -1
+ else:
+ start_position = 0
+
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.from_clause_items("Custom", start_position)
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_column_names_with_qualified_alias(completer):
+ result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p."))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+def test_suggested_multiple_column_names(completer):
+ result = get_result(
+ completer, "SELECT id, from custom.products", len("SELECT id, ")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggested_multiple_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("products", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ",
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON JOIN public.orders z ON z.id > y.id",
+ ],
+)
+def test_suggestions_after_on(completer, text):
+ position = len(
+ "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON "
+ )
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("x"),
+ alias("y"),
+ name_join("y.price = x.price"),
+ name_join("y.product_name = x.product_name"),
+ name_join("y.id = x.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+def test_suggested_aliases_after_on_right_side(completer):
+ text = "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set([alias("x"), alias("y")])
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_table_names_after_from(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_schema_qualified_function_name(completer):
+ text = "SELECT custom.func"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("func3()", -len("func")),
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_schema_qualified_function_name_after_from(completer):
+ text = "SELECT * FROM custom.set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_not_returned(completer):
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set([])
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+def test_unqualified_function_name_in_search_path(completer):
+ completer.search_path = ["public", "custom"]
+ text = "SELECT * FROM set_r"
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("set_returning_func()", -len("func")),
+ ]
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT 1::custom.",
+ "CREATE TABLE foo (bar custom.",
+ "CREATE FUNCTION foo (bar INT, baz custom.",
+ "ALTER TABLE foo ALTER COLUMN bar TYPE custom.",
+ ],
+)
+def test_schema_qualified_type_name(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(testdata.types("custom"))
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_suggest_columns_from_aliased_set_returning_function(completer):
+ result = get_result(
+ completer, "select f. from custom.set_returning_func() f", len("select f.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", "custom", "functions")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM custom.set_returning_func()",
+ "SELECT * FROM Custom.set_returning_func()",
+ "SELECT * FROM Custom.Set_Returning_Func()",
+ ],
+)
+def test_wildcard_column_expansion_with_function(completer, text):
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ col_list = "x"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_alias_qualifier(completer):
+ text = "SELECT p.* FROM custom.products p"
+ position = len("SELECT p.*")
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, p.product_name, p.price"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ """
+ SELECT count(1) FROM users;
+ CREATE FUNCTION foo(custom.products _products) returns custom.shipments
+ LANGUAGE SQL
+ AS $foo$
+ SELECT 1 FROM custom.shipments;
+ INSERT INTO public.orders(*) values(-1, now(), 'preliminary');
+ SELECT 2 FROM custom.users;
+ $foo$;
+ SELECT count(1) FROM custom.shipments;
+ """,
+ "INSERT INTO public.orders(*",
+ "INSERT INTO public.Orders(*",
+ "INSERT INTO public.orders (*",
+ "INSERT INTO public.Orders (*",
+ "INSERT INTO orders(*",
+ "INSERT INTO Orders(*",
+ "INSERT INTO orders (*",
+ "INSERT INTO Orders (*",
+ "INSERT INTO public.orders(*)",
+ "INSERT INTO public.Orders(*)",
+ "INSERT INTO public.orders (*)",
+ "INSERT INTO public.Orders (*)",
+ "INSERT INTO orders(*)",
+ "INSERT INTO Orders(*)",
+ "INSERT INTO orders (*)",
+ "INSERT INTO Orders (*)",
+ ],
+)
+def test_wildcard_column_expansion_with_insert(completer, text):
+ position = text.index("*") + 1
+ completions = get_result(completer, text, position)
+
+ expected = [wildcard_expansion("ordered_date, status")]
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_table_qualifier(completer):
+ text = 'SELECT "select".* FROM public."select"'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select"."localtime", "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False, qualify=qual))
+def test_wildcard_column_expansion_with_two_tables(completer):
+ text = 'SELECT * FROM public."select" JOIN custom.users ON true'
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ cols = (
+ '"select".id, "select"."localtime", "select"."ABC", '
+ "users.id, users.phone_number"
+ )
+ expected = [wildcard_expansion(cols)]
+ assert completions == expected
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
+ text = 'SELECT "select".* FROM public."select" JOIN custom.users u ON true'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select"."localtime", "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT U. FROM custom.Users U",
+ "SELECT U. FROM custom.USERS U",
+ "SELECT U. FROM custom.users U",
+ 'SELECT U. FROM "custom".Users U',
+ 'SELECT U. FROM "custom".USERS U',
+ 'SELECT U. FROM "custom".users U',
+ ],
+)
+def test_suggest_columns_from_unquoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("users", "custom")
+ )
+
+
+@parametrize("completer", completers(filtr=True, casing=False))
+@parametrize(
+ "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U']
+)
+def test_suggest_columns_from_quoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("Users", "custom")
+ )
+
+
+texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "]
+
+
+@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
+@parametrize("text", texts)
+def test_schema_or_visible_table_completion(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(aliasing=True, casing=False, filtr=True))
+@parametrize("text", texts)
+def test_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + [
+ table("users u"),
+ table("orders o" if text == "SELECT * FROM " else "orders o2"),
+ table('"select" s'),
+ function("func1() f"),
+ function("func2() f"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=True, casing=True, filtr=True))
+@parametrize("text", texts)
+def test_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ cased_schemas
+ + [
+ table("users u"),
+ table("Orders O" if text == "SELECT * FROM " else "Orders O2"),
+ table('"select" s'),
+ function("Func1() F"),
+ function("func2() f"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=False, casing=True, filtr=True))
+@parametrize("text", texts)
+def test_table_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ cased_schemas
+ + [
+ table("users"),
+ table("Orders"),
+ table('"select"'),
+ function("Func1()"),
+ function("func2()"),
+ ]
+ )
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_alias_search_without_aliases2(completer):
+ text = "SELECT * FROM blog.et"
+ result = get_result(completer, text)
+ assert result[0] == table("EntryTags", -2)
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_alias_search_without_aliases1(completer):
+ text = "SELECT * FROM blog.e"
+ result = get_result(completer, text)
+ assert result[0] == table("Entries", -1)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_alias_search_with_aliases2(completer):
+ text = "SELECT * FROM blog.et"
+ result = get_result(completer, text)
+ assert result[0] == table("EntryTags ET", -2)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_alias_search_with_aliases1(completer):
+ text = "SELECT * FROM blog.e"
+ result = get_result(completer, text)
+ assert result[0] == table("Entries E", -1)
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_join_alias_search_with_aliases1(completer):
+ text = "SELECT * FROM blog.Entries E JOIN blog.e"
+ result = get_result(completer, text)
+ assert result[:2] == [
+ table("Entries E2", -1),
+ join("EntAccLog EAL ON EAL.EntryID = E.EntryID", -1),
+ ]
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_join_alias_search_without_aliases1(completer):
+ text = "SELECT * FROM blog.Entries JOIN blog.e"
+ result = get_result(completer, text)
+ assert result[:2] == [
+ table("Entries", -1),
+ join("EntAccLog ON EntAccLog.EntryID = Entries.EntryID", -1),
+ ]
+
+
+@parametrize("completer", completers(aliasing=True, casing=True))
+def test_join_alias_search_with_aliases2(completer):
+ text = "SELECT * FROM blog.Entries E JOIN blog.et"
+ result = get_result(completer, text)
+ assert result[0] == join("EntryTags ET ON ET.EntryID = E.EntryID", -2)
+
+
+@parametrize("completer", completers(aliasing=False, casing=True))
+def test_join_alias_search_without_aliases2(completer):
+ text = "SELECT * FROM blog.Entries JOIN blog.et"
+ result = get_result(completer, text)
+ assert result[0] == join("EntryTags ON EntryTags.EntryID = Entries.EntryID", -2)
+
+
+@parametrize("completer", completers())
+def test_function_alias_search_without_aliases(completer):
+ text = "SELECT blog.ees"
+ result = get_result(completer, text)
+ first = result[0]
+ assert first.start_position == -3
+ assert first.text == "extract_entry_symbols()"
+ assert first.display_text == "extract_entry_symbols(_entryid)"
+
+
+@parametrize("completer", completers())
+def test_function_alias_search_with_aliases(completer):
+ text = "SELECT blog.ee"
+ result = get_result(completer, text)
+ first = result[0]
+ assert first.start_position == -2
+ assert first.text == "enter_entry(_title := , _text := )"
+ assert first.display_text == "enter_entry(_title, _text)"
+
+
+@parametrize("completer", completers(filtr=True, casing=True, qualify=no_qual))
+def test_column_alias_search(completer):
+ result = get_result(completer, "SELECT et FROM blog.Entries E", len("SELECT et"))
+ cols = ("EntryText", "EntryTitle", "EntryID")
+ assert result[:3] == [column(c, -2) for c in cols]
+
+
+@parametrize("completer", completers(casing=True))
+def test_column_alias_search_qualified(completer):
+ result = get_result(
+ completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")
+ )
+ cols = ("EntryID", "EntryTitle")
+ assert result[:3] == [column(c, -2) for c in cols]
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
+def test_schema_object_order(completer):
+ result = get_result(completer, "SELECT * FROM u")
+ assert result[:3] == [
+ table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")
+ ]
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
+def test_all_schema_objects(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("orders", '"select"', "custom.shipments")]
+ + [function(x + "()") for x in ("func2",)]
+ )
+
+
+@parametrize("completer", completers(filtr=False, aliasing=False, casing=True))
+def test_all_schema_objects_with_casing(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")]
+ + [function(x + "()") for x in ("func2",)]
+ )
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
+def test_all_schema_objects_with_aliases(completer):
+ text = "SELECT * FROM "
+ result = get_result(completer, text)
+ assert completions_to_set(result) >= completions_to_set(
+ [table(x) for x in ("orders o", '"select" s', "custom.shipments s")]
+ + [function(x) for x in ("func2() f",)]
+ )
+
+
+@parametrize("completer", completers(casing=False, filtr=False, aliasing=True))
+def test_set_schema(completer):
+ text = "SET SCHEMA "
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]
+ )
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
new file mode 100644
index 0000000..db1fe0a
--- /dev/null
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -0,0 +1,1112 @@
+from metadata import (
+ MetaData,
+ alias,
+ name_join,
+ fk_join,
+ join,
+ keyword,
+ schema,
+ table,
+ view,
+ function,
+ column,
+ wildcard_expansion,
+ get_result,
+ result_set,
+ qual,
+ no_qual,
+ parametrize,
+)
+from prompt_toolkit.completion import Completion
+from utils import completions_to_set
+
+
+metadata = {
+ "tables": {
+ "users": ["id", "parentid", "email", "first_name", "last_name"],
+ "Users": ["userid", "username"],
+ "orders": ["id", "ordered_date", "status", "email"],
+ "select": ["id", "insert", "ABC"],
+ },
+ "views": {"user_emails": ["id", "email"], "functions": ["function"]},
+ "functions": [
+ ["custom_fun", [], [], [], "", False, False, False, False],
+ ["_custom_fun", [], [], [], "", False, False, False, False],
+ ["custom_func1", [], [], [], "", False, False, False, False],
+ ["custom_func2", [], [], [], "", False, False, False, False],
+ [
+ "set_returning_func",
+ ["x", "y"],
+ ["integer", "integer"],
+ ["b", "b"],
+ "",
+ False,
+ False,
+ True,
+ False,
+ ],
+ ],
+ "datatypes": ["custom_type1", "custom_type2"],
+ "foreignkeys": [
+ ("public", "users", "id", "public", "users", "parentid"),
+ ("public", "users", "id", "public", "Users", "userid"),
+ ],
+}
+
+metadata = {k: {"public": v} for k, v in metadata.items()}
+
+testdata = MetaData(metadata)
+
+cased_users_col_names = ["ID", "PARENTID", "Email", "First_Name", "last_name"]
+cased_users2_col_names = ["UserID", "UserName"]
+cased_func_names = [
+ "Custom_Fun",
+ "_custom_fun",
+ "Custom_Func1",
+ "custom_func2",
+ "set_returning_func",
+]
+cased_tbls = ["Users", "Orders"]
+cased_views = ["User_Emails", "Functions"]
+casing = (
+ ["SELECT", "PUBLIC"]
+ + cased_func_names
+ + cased_tbls
+ + cased_views
+ + cased_users_col_names
+ + cased_users2_col_names
+)
+# Lists for use in assertions
+cased_funcs = [
+ function(f)
+ for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()")
+] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")]
+cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])]
+cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls
+cased_users_cols = [column(c) for c in cased_users_col_names]
+aliased_rels = (
+ [table(t) for t in ("users u", '"Users" U', "orders o", '"select" s')]
+ + [view("user_emails ue"), view("functions f")]
+ + [
+ function(f)
+ for f in (
+ "_custom_fun() cf",
+ "custom_fun() cf",
+ "custom_func1() cf",
+ "custom_func2() cf",
+ )
+ ]
+ + [
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ )
+ ]
+)
+cased_aliased_rels = (
+ [table(t) for t in ("Users U", '"Users" U', "Orders O", '"select" s')]
+ + [view("User_Emails UE"), view("Functions F")]
+ + [
+ function(f)
+ for f in (
+ "_custom_fun() cf",
+ "Custom_Fun() CF",
+ "Custom_Func1() CF",
+ "custom_func2() cf",
+ )
+ ]
+ + [
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ )
+ ]
+)
+completers = testdata.get_completers(casing)
+
+
+# Just to make sure that this doesn't crash
+@parametrize("completer", completers())
+def test_function_column_name(completer):
+ for l in range(
+ len("SELECT * FROM Functions WHERE function:"),
+ len("SELECT * FROM Functions WHERE function:text") + 1,
+ ):
+ assert [] == get_result(
+ completer, "SELECT * FROM Functions WHERE function:text"[:l]
+ )
+
+
+@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"])
+@parametrize("completer", completers())
+def test_drop_alter_function(completer, action):
+ assert get_result(completer, action + " FUNCTION set_ret") == [
+ function("set_returning_func(x integer, y integer)", -len("set_ret"))
+ ]
+
+
+@parametrize("completer", completers())
+def test_empty_string_completion(completer):
+ result = get_result(completer, "")
+ assert completions_to_set(
+ testdata.keywords() + testdata.specials()
+ ) == completions_to_set(result)
+
+
+@parametrize("completer", completers())
+def test_select_keyword_completion(completer):
+ result = get_result(completer, "SEL")
+ assert completions_to_set(result) == completions_to_set([keyword("SELECT", -3)])
+
+
+@parametrize("completer", completers())
+def test_builtin_function_name_completion(completer):
+ result = get_result(completer, "SELECT MA")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("MAKE_DATE", -2),
+ function("MAKE_INTERVAL", -2),
+ function("MAKE_TIME", -2),
+ function("MAKE_TIMESTAMP", -2),
+ function("MAKE_TIMESTAMPTZ", -2),
+ function("MASKLEN", -2),
+ function("MAX", -2),
+ keyword("MAXEXTENTS", -2),
+ keyword("MATERIALIZED VIEW", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+def test_builtin_function_matches_only_at_start(completer):
+ text = "SELECT IN"
+
+ result = [c.text for c in get_result(completer, text)]
+
+ assert "MIN" not in result
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_user_function_name_completion(completer):
+ result = get_result(completer, "SELECT cu")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ function("CURRENT_DATE", -2),
+ function("CURRENT_TIMESTAMP", -2),
+ function("CUME_DIST", -2),
+ function("CURRENT_TIME", -2),
+ keyword("CURRENT", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_user_function_name_completion_matches_anywhere(completer):
+ result = get_result(completer, "SELECT om")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True))
+def test_list_functions_for_special(completer):
+ result = get_result(completer, r"\df ")
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + [function(f) for f in cased_func_names]
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_column_names_from_visible_table(completer):
+ result = get_result(completer, "SELECT from users", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=True, qualify=no_qual))
+def test_suggested_cased_column_names(completer):
+ result = get_result(completer, "SELECT from users", len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ cased_funcs
+ + cased_users_cols
+ + testdata.builtin_functions()
+ + testdata.keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize("text", ["SELECT from users", "INSERT INTO Orders SELECT from users"])
+def test_suggested_auto_qualified_column_names(text, completer):
+ position = text.index(" ") + 1
+ cols = [column(c.lower()) for c in cased_users_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cols + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=qual))
+@parametrize(
+ "text",
+ [
+ 'SELECT from users U NATURAL JOIN "Users"',
+ 'INSERT INTO Orders SELECT from users U NATURAL JOIN "Users"',
+ ],
+)
+def test_suggested_auto_qualified_column_names_two_tables(text, completer):
+ position = text.index(" ") + 1
+ cols = [column("U." + c.lower()) for c in cased_users_col_names]
+ cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cols + testdata.functions_and_keywords()
+ )
+
+
+@parametrize("completer", completers(casing=True, qualify=["always"]))
+@parametrize("text", ["UPDATE users SET ", "INSERT INTO users("])
+def test_no_column_qualification(text, completer):
+ cols = [column(c) for c in cased_users_col_names]
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(cols)
+
+
+@parametrize("completer", completers(casing=True, qualify=["always"]))
+def test_suggested_cased_always_qualified_column_names(completer):
+ text = "SELECT from users"
+ position = len("SELECT ")
+ cols = [column("users." + c) for c in cased_users_col_names]
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ cased_funcs + cols + testdata.builtin_functions() + testdata.keywords()
+ )
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_column_names_in_function(completer):
+ result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_column_names_with_table_dot(completer):
+ result = get_result(completer, "SELECT users. from users", len("SELECT users."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_column_names_with_alias(completer):
+ result = get_result(completer, "SELECT u. from users u", len("SELECT u."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggested_multiple_column_names(completer):
+ result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("users")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_multiple_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=True))
+def test_suggested_cased_column_names_with_alias(completer):
+ result = get_result(
+ completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
+ )
+ assert completions_to_set(result) == completions_to_set(cased_users_cols)
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggested_multiple_column_names_with_dot(completer):
+ result = get_result(
+ completer,
+ "SELECT users.id, users. from users u",
+ len("SELECT users.id, users."),
+ )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_after_three_way_join(completer):
+ text = """SELECT * FROM users u1
+ INNER JOIN users u2 ON u1.id = u2.id
+ INNER JOIN users u3 ON u2.id = u3."""
+ result = get_result(completer, text)
+ assert column("id") in result
+
+
+join_condition_texts = [
+ 'INSERT INTO orders SELECT * FROM users U JOIN "Users" U2 ON ',
+ """INSERT INTO public.orders(orderid)
+ SELECT * FROM users U JOIN "Users" U2 ON """,
+ 'SELECT * FROM users U JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U INNER join "Users" U2 ON ',
+ 'SELECT * FROM USERS U right JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U LEFT JOIN "Users" U2 ON ',
+ 'SELECT * FROM Users U FULL JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U right outer join "Users" U2 ON ',
+ 'SELECT * FROM Users U LEFT OUTER JOIN "Users" U2 ON ',
+ 'SELECT * FROM users U FULL OUTER JOIN "Users" U2 ON ',
+ """SELECT *
+ FROM users U
+ FULL OUTER JOIN "Users" U2 ON
+ """,
+]
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize("text", join_condition_texts)
+def test_suggested_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("U"), alias("U2"), fk_join("U2.userid = U.id")]
+ )
+
+
+@parametrize("completer", completers(casing=True))
+@parametrize("text", join_condition_texts)
+def test_cased_join_conditions(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ """SELECT *
+ FROM users
+ CROSS JOIN "Users"
+ NATURAL JOIN users u
+ JOIN "Users" u2 ON
+ """
+ ],
+)
+def test_suggested_join_conditions_with_same_table_twice(completer, text):
+ result = get_result(completer, text)
+ assert result == [
+ fk_join("u2.userid = u.id"),
+ fk_join("u2.userid = users.id"),
+ name_join('u2.userid = "Users".userid'),
+ name_join('u2.username = "Users".username'),
+ alias("u"),
+ alias("u2"),
+ alias("users"),
+ alias('"Users"'),
+ ]
+
+
+@parametrize("completer", completers())
+@parametrize("text", ["SELECT * FROM users JOIN users u2 on foo."])
+def test_suggested_join_conditions_with_invalid_qualifier(completer, text):
+ result = get_result(completer, text)
+ assert result == []
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ ("text", "ref"),
+ [
+ ("SELECT * FROM users JOIN NonTable on ", "NonTable"),
+ ("SELECT * FROM users JOIN nontable nt on ", "nt"),
+ ],
+)
+def test_suggested_join_conditions_with_invalid_table(completer, text, ref):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("users"), alias(ref)]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ 'SELECT * FROM "Users" u JOIN u',
+ 'SELECT * FROM "Users" u JOIN uid',
+ 'SELECT * FROM "Users" u JOIN userid',
+ 'SELECT * FROM "Users" u JOIN id',
+ ],
+)
+def test_suggested_joins_fuzzy(completer, text):
+ result = get_result(completer, text)
+ last_word = text.split()[-1]
+ expected = join("users ON users.id = u.userid", -len(last_word))
+ assert expected in result
+
+
+join_texts = [
+ "SELECT * FROM Users JOIN ",
+ """INSERT INTO "Users"
+ SELECT *
+ FROM Users
+ INNER JOIN """,
+ """INSERT INTO public."Users"(username)
+ SELECT *
+ FROM Users
+ INNER JOIN """,
+ """SELECT *
+ FROM Users
+ INNER JOIN """,
+]
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", join_texts)
+def test_suggested_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [
+ join('"Users" ON "Users".userid = Users.id'),
+ join("users users2 ON users2.id = Users.parentid"),
+ join("users users2 ON users2.parentid = Users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=False))
+@parametrize("text", join_texts)
+def test_cased_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")]
+ + cased_rels
+ + [
+ join('"Users" ON "Users".UserID = Users.ID'),
+ join("Users Users2 ON Users2.ID = Users.PARENTID"),
+ join("Users Users2 ON Users2.PARENTID = Users.ID"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", join_texts)
+def test_aliased_joins(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + aliased_rels
+ + [
+ join('"Users" U ON U.userid = Users.id'),
+ join("users u ON u.id = Users.parentid"),
+ join("users u ON u.parentid = Users.id"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ 'SELECT * FROM public."Users" JOIN ',
+ 'SELECT * FROM public."Users" RIGHT OUTER JOIN ',
+ """SELECT *
+ FROM public."Users"
+ LEFT JOIN """,
+ ],
+)
+def test_suggested_joins_quoted_schema_qualified_table(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ + [join('public.users ON users.id = "Users".userid')]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT u.name, o.id FROM users u JOIN orders o ON ",
+ "SELECT u.name, o.id FROM users u JOIN orders o ON JOIN orders o2 ON",
+ ],
+)
+def test_suggested_aliases_after_on(completer, text):
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ alias("u"),
+ name_join("o.id = u.id"),
+ name_join("o.email = u.email"),
+ alias("o"),
+ ]
+ )
+
+
+@parametrize("completer", completers())
+@parametrize(
+ "text",
+ [
+ "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ",
+ "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = JOIN orders o2 ON",
+ ],
+)
+def test_suggested_aliases_after_on_right_side(completer, text):
+ position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set([alias("u"), alias("o")])
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT users.name, orders.id FROM users JOIN orders ON ",
+ "SELECT users.name, orders.id FROM users JOIN orders ON JOIN orders orders2 ON",
+ ],
+)
+def test_suggested_tables_after_on(completer, text):
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ name_join("orders.id = users.id"),
+ name_join("orders.email = users.email"),
+ alias("users"),
+ alias("orders"),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = JOIN orders orders2 ON",
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ",
+ ],
+)
+def test_suggested_tables_after_on_right_side(completer, text):
+ position = len(
+ "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
+ )
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [alias("users"), alias("orders")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users INNER JOIN orders USING (",
+ "SELECT * FROM users INNER JOIN orders USING(",
+ ],
+)
+def test_join_using_suggests_common_columns(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users u1 JOIN users u2 USING (email) JOIN user_emails ue USING()",
+ "SELECT * FROM users u1 JOIN users u2 USING(email) JOIN user_emails ue USING ()",
+ "SELECT * FROM users u1 JOIN user_emails ue USING () JOIN users u2 ue USING(first_name, last_name)",
+ "SELECT * FROM users u1 JOIN user_emails ue USING() JOIN users u2 ue USING (first_name, last_name)",
+ ],
+)
+def test_join_using_suggests_from_last_table(completer, text):
+ position = text.index("()") + 1
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users INNER JOIN orders USING (id,",
+ "SELECT * FROM users INNER JOIN orders USING(id,",
+ ],
+)
+def test_join_using_suggests_columns_after_first_column(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [column("id"), column("email")]
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM ",
+ "SELECT * FROM users CROSS JOIN ",
+ "SELECT * FROM users natural join ",
+ ],
+)
+def test_table_names_after_from(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+ assert [c.text for c in result] == [
+ "public",
+ "orders",
+ '"select"',
+ "users",
+ '"Users"',
+ "functions",
+ "user_emails",
+ "_custom_fun()",
+ "custom_fun()",
+ "custom_func1()",
+ "custom_func2()",
+ "set_returning_func(x := , y := )",
+ ]
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_auto_escaped_col_names(completer):
+ result = get_result(completer, 'SELECT from "select"', len("SELECT "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("select")
+ )
+
+
+@parametrize("completer", completers(aliasing=False))
+def test_allow_leading_double_quote_in_last_word(completer):
+ result = get_result(completer, 'SELECT * from "sele')
+
+ expected = table('"select"', -5)
+
+ assert expected in result
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT 1::",
+ "CREATE TABLE foo (bar ",
+ "CREATE FUNCTION foo (bar INT, baz ",
+ "ALTER TABLE foo ALTER COLUMN bar TYPE ",
+ ],
+)
+def test_suggest_datatype(text, completer):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas() + testdata.types() + testdata.builtin_datatypes()
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_escaped_table_alias(completer):
+ result = get_result(completer, 'select * from "select" s where s.')
+ assert completions_to_set(result) == completions_to_set(testdata.columns("select"))
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggest_columns_from_set_returning_function(completer):
+ result = get_result(completer, "select from set_returning_func()", len("select "))
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns_functions_and_keywords("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_aliased_set_returning_function(completer):
+ result = get_result(
+ completer, "select f. from set_returning_func() f", len("select f.")
+ )
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_join_functions_using_suggests_common_columns(completer):
+ text = """SELECT * FROM set_returning_func() f1
+ INNER JOIN set_returning_func() f2 USING ("""
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers(casing=False))
+def test_join_functions_on_suggests_columns_and_join_conditions(completer):
+ text = """SELECT * FROM set_returning_func() f1
+ INNER JOIN set_returning_func() f2 ON f1."""
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [name_join("y = f2.y"), name_join("x = f2.x")]
+ + testdata.columns("set_returning_func", typ="functions")
+ )
+
+
+@parametrize("completer", completers())
+def test_learn_keywords(completer):
+ history = "CREATE VIEW v AS SELECT 1"
+ completer.extend_query_history(history)
+
+ # Now that we've used `VIEW` once, it should be suggested ahead of other
+ # keywords starting with v.
+ text = "create v"
+ completions = get_result(completer, text)
+ assert completions[0].text == "VIEW"
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_learn_table_names(completer):
+ history = "SELECT * FROM users; SELECT * FROM orders; SELECT * FROM users"
+ completer.extend_query_history(history)
+
+ text = "SELECT * FROM "
+ completions = get_result(completer, text)
+
+ # `users` should be higher priority than `orders` (used more often)
+ users = table("users")
+ orders = table("orders")
+
+ assert completions.index(users) < completions.index(orders)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_columns_before_keywords(completer):
+ text = "SELECT * FROM orders WHERE s"
+ completions = get_result(completer, text)
+
+ col = column("status", -1)
+ kw = keyword("SELECT", -1)
+
+ assert completions.index(col) < completions.index(kw)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "SELECT * FROM users",
+ "INSERT INTO users SELECT * FROM users u",
+ """INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT *
+ FROM users u""",
+ ],
+)
+def test_wildcard_column_expansion(completer, text):
+ position = text.find("*") + 1
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, parentid, email, first_name, last_name"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "SELECT u.* FROM users u",
+ "INSERT INTO public.users SELECT u.* FROM users u",
+ """INSERT INTO users(id, parentid, email, first_name, last_name)
+ SELECT u.*
+ FROM users u""",
+ ],
+)
+def test_wildcard_column_expansion_with_alias(completer, text):
+ position = text.find("*") + 1
+
+ completions = get_result(completer, text, position)
+
+ col_list = "id, u.parentid, u.email, u.first_name, u.last_name"
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text,expected",
+ [
+ (
+ "SELECT users.* FROM users",
+ "id, users.parentid, users.email, users.first_name, users.last_name",
+ ),
+ (
+ "SELECT Users.* FROM Users",
+ "id, Users.parentid, Users.email, Users.first_name, Users.last_name",
+ ),
+ ],
+)
+def test_wildcard_column_expansion_with_table_qualifier(completer, text, expected):
+ position = len("SELECT users.*")
+
+ completions = get_result(completer, text, position)
+
+ expected = [wildcard_expansion(expected)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False, qualify=qual))
+def test_wildcard_column_expansion_with_two_tables(completer):
+ text = 'SELECT * FROM "select" JOIN users u ON true'
+ position = len("SELECT *")
+
+ completions = get_result(completer, text, position)
+
+ cols = (
+ '"select".id, "select".insert, "select"."ABC", '
+ "u.id, u.parentid, u.email, u.first_name, u.last_name"
+ )
+ expected = [wildcard_expansion(cols)]
+ assert completions == expected
+
+
+@parametrize("completer", completers(casing=False))
+def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
+ text = 'SELECT "select".* FROM "select" JOIN users u ON true'
+ position = len('SELECT "select".*')
+
+ completions = get_result(completer, text, position)
+
+ col_list = 'id, "select".insert, "select"."ABC"'
+ expected = [wildcard_expansion(col_list)]
+
+ assert expected == completions
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ ["SELECT U. FROM Users U", "SELECT U. FROM USERS U", "SELECT U. FROM users U"],
+)
+def test_suggest_columns_from_unquoted_table(completer, text):
+ position = len("SELECT U.")
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False))
+def test_suggest_columns_from_quoted_table(completer):
+ result = get_result(completer, 'SELECT U. FROM "Users" U', len("SELECT U."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("Users"))
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "])
+def test_schema_or_visible_table_completion(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas_and_from_clause_items()
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", ["SELECT * FROM "])
+def test_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas() + aliased_rels
+ )
+
+
+@parametrize("completer", completers(casing=False, aliasing=True))
+@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "])
+def test_duplicate_table_aliases(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ testdata.schemas()
+ + [
+ table("orders o2"),
+ table("users u"),
+ table('"Users" U'),
+ table('"select" s'),
+ view("user_emails ue"),
+ view("functions f"),
+ function("_custom_fun() cf"),
+ function("custom_fun() cf"),
+ function("custom_func1() cf"),
+ function("custom_func2() cf"),
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ ),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=True))
+@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "])
+def test_duplicate_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [
+ schema("PUBLIC"),
+ table("Orders O2"),
+ table("Users U"),
+ table('"Users" U'),
+ table('"select" s'),
+ view("User_Emails UE"),
+ view("Functions F"),
+ function("_custom_fun() cf"),
+ function("Custom_Fun() CF"),
+ function("Custom_Func1() CF"),
+ function("custom_func2() cf"),
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ ),
+ ]
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=True))
+@parametrize("text", ["SELECT * FROM "])
+def test_aliases_with_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + cased_aliased_rels
+ )
+
+
+@parametrize("completer", completers(casing=True, aliasing=False))
+@parametrize("text", ["SELECT * FROM "])
+def test_table_casing(completer, text):
+ result = get_result(completer, text)
+ assert completions_to_set(result) == completions_to_set(
+ [schema("PUBLIC")] + cased_rels
+ )
+
+
+@parametrize("completer", completers(casing=False))
+@parametrize(
+ "text",
+ [
+ "INSERT INTO users ()",
+ "INSERT INTO users()",
+ "INSERT INTO users () SELECT * FROM orders;",
+ "INSERT INTO users() SELECT * FROM users u cross join orders o",
+ ],
+)
+def test_insert(completer, text):
+ position = text.find("(") + 1
+ result = get_result(completer, text, position)
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
+
+
+@parametrize("completer", completers(casing=False, aliasing=False))
+def test_suggest_cte_names(completer):
+ text = """
+ WITH cte1 AS (SELECT a, b, c FROM foo),
+ cte2 AS (SELECT d, e, f FROM bar)
+ SELECT * FROM
+ """
+ result = get_result(completer, text)
+ expected = completions_to_set(
+ [
+ Completion("cte1", 0, display_meta="table"),
+ Completion("cte2", 0, display_meta="table"),
+ ]
+ )
+ assert expected <= completions_to_set(result)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+def test_suggest_columns_from_cte(completer):
+ result = get_result(
+ completer,
+ "WITH cte AS (SELECT foo, bar FROM baz) SELECT FROM cte",
+ len("WITH cte AS (SELECT foo, bar FROM baz) SELECT "),
+ )
+ expected = [
+ Completion("foo", 0, display_meta="column"),
+ Completion("bar", 0, display_meta="column"),
+ ] + testdata.functions_and_keywords()
+
+ assert completions_to_set(expected) == completions_to_set(result)
+
+
+@parametrize("completer", completers(casing=False, qualify=no_qual))
+@parametrize(
+ "text",
+ [
+ "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte WHERE cte.",
+ "WITH cte AS (SELECT foo FROM bar) SELECT * FROM cte c WHERE c.",
+ ],
+)
+def test_cte_qualified_columns(completer, text):
+ result = get_result(completer, text)
+ expected = [Completion("foo", 0, display_meta="column")]
+ assert completions_to_set(expected) == completions_to_set(result)
+
+
+@parametrize(
+ "keyword_casing,expected,texts",
+ [
+ ("upper", "SELECT", ("", "s", "S", "Sel")),
+ ("lower", "select", ("", "s", "S", "Sel")),
+ ("auto", "SELECT", ("", "S", "SEL", "seL")),
+ ("auto", "select", ("s", "sel", "SEl")),
+ ],
+)
+def test_keyword_casing_upper(keyword_casing, expected, texts):
+ for text in texts:
+ completer = testdata.get_completer({"keyword_casing": keyword_casing})
+ completions = get_result(completer, text)
+ assert expected in [cpl.text for cpl in completions]
+
+
+@parametrize("completer", completers())
+def test_keyword_after_alter(completer):
+ text = "ALTER TABLE users ALTER "
+ expected = Completion("COLUMN", start_position=0, display_meta="keyword")
+ completions = get_result(completer, text)
+ assert expected in completions
+
+
+@parametrize("completer", completers())
+def test_set_schema(completer):
+ text = "SET SCHEMA "
+ result = get_result(completer, text)
+ expected = completions_to_set([schema("'public'")])
+ assert completions_to_set(result) == expected
+
+
+@parametrize("completer", completers())
+def test_special_name_completion(completer):
+ result = get_result(completer, "\\t")
+ assert completions_to_set(result) == completions_to_set(
+ [
+ Completion(
+ text="\\timing",
+ start_position=-2,
+ display_meta="Toggle timing of commands.",
+ )
+ ]
+ )
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
new file mode 100644
index 0000000..1034bbe
--- /dev/null
+++ b/tests/test_sqlcompletion.py
@@ -0,0 +1,964 @@
+from pgcli.packages.sqlcompletion import (
+ suggest_type,
+ Special,
+ Database,
+ Schema,
+ Table,
+ Column,
+ View,
+ Keyword,
+ FromClauseItem,
+ Function,
+ Datatype,
+ Alias,
+ JoinCondition,
+ Join,
+)
+from pgcli.packages.parseutils.tables import TableReference
+import pytest
+
+
+def cols_etc(
+ table, schema=None, alias=None, is_function=False, parent=None, last_keyword=None
+):
+ """Returns the expected select-clause suggestions for a single-table
+ select."""
+ return {
+ Column(
+ table_refs=(TableReference(schema, table, alias, is_function),),
+ qualifiable=True,
+ ),
+ Function(schema=parent),
+ Keyword(last_keyword),
+ }
+
+
+def test_select_suggests_cols_with_visible_table_scope():
+ suggestions = suggest_type("SELECT FROM tabl", "SELECT ")
+ assert set(suggestions) == cols_etc("tabl", last_keyword="SELECT")
+
+
+def test_select_suggests_cols_with_qualified_table_scope():
+ suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ")
+ assert set(suggestions) == cols_etc("tabl", "sch", last_keyword="SELECT")
+
+
+def test_cte_does_not_crash():
+ sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;"
+ for i in range(len(sql)):
+ suggestions = suggest_type(sql[: i + 1], sql[: i + 1])
+
+
+@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE '])
+def test_where_suggests_columns_functions_quoted_table(expression):
+ expected = cols_etc("tabl", alias='"tabl"', last_keyword="WHERE")
+ suggestions = suggest_type(expression, expression)
+ assert expected == set(suggestions)
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "INSERT INTO OtherTabl(ID, Name) SELECT * FROM tabl WHERE ",
+ "INSERT INTO OtherTabl SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE ",
+ "SELECT * FROM tabl WHERE (",
+ "SELECT * FROM tabl WHERE foo = ",
+ "SELECT * FROM tabl WHERE bar OR ",
+ "SELECT * FROM tabl WHERE foo = 1 AND ",
+ "SELECT * FROM tabl WHERE (bar > 10 AND ",
+ "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (",
+ "SELECT * FROM tabl WHERE 10 < ",
+ "SELECT * FROM tabl WHERE foo BETWEEN ",
+ "SELECT * FROM tabl WHERE foo BETWEEN foo AND ",
+ ],
+)
+def test_where_suggests_columns_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize(
+ "expression",
+ ["SELECT * FROM tabl WHERE foo IN (", "SELECT * FROM tabl WHERE foo IN (bar, "],
+)
+def test_where_in_suggests_columns(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize("expression", ["SELECT 1 AS ", "SELECT 1 FROM tabl AS "])
+def test_after_as(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == set()
+
+
+def test_where_equals_any_suggests_columns_or_keywords():
+ text = "SELECT * FROM tabl WHERE foo = ANY("
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+def test_lparen_suggests_cols_and_funcs():
+ suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
+ assert set(suggestion) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("("),
+ }
+
+
+def test_select_suggests_cols_and_funcs():
+ suggestions = suggest_type("SELECT ", "SELECT ")
+ assert set(suggestions) == {
+ Column(table_refs=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["INSERT INTO ", "COPY ", "UPDATE ", "DESCRIBE "]
+)
+def test_suggests_tables_views_and_schemas(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM "])
+def test_suggest_tables_views_schemas_and_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM foo JOIN bar on bar.barid = foo.barid JOIN ",
+ "SELECT * FROM foo JOIN bar USING (barid) JOIN ",
+ ],
+)
+def test_suggest_after_join_with_two_tables(expression):
+ suggestions = suggest_type(expression, expression)
+ tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(tables, None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["SELECT * FROM foo JOIN ", "SELECT * FROM foo JOIN bar"]
+)
+def test_suggest_after_join_with_one_table(expression):
+ suggestions = suggest_type(expression, expression)
+ tables = ((None, "foo", None, False),)
+ assert set(suggestions) == {
+ FromClauseItem(schema=None, table_refs=tables),
+ Join(((None, "foo", None, False),), None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression", ["INSERT INTO sch.", "COPY sch.", "DESCRIBE sch."]
+)
+def test_suggest_qualified_tables_and_views(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
+
+
+@pytest.mark.parametrize("expression", ["UPDATE sch."])
+def test_suggest_qualified_aliasable_tables_and_views(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM sch.",
+ 'SELECT * FROM sch."',
+ 'SELECT * FROM sch."foo',
+ 'SELECT * FROM "sch".',
+ 'SELECT * FROM "sch"."',
+ ],
+)
+def test_suggest_qualified_tables_views_and_functions(expression):
+ suggestions = suggest_type(expression, expression)
+ assert set(suggestions) == {FromClauseItem(schema="sch")}
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
+def test_suggest_qualified_tables_views_functions_and_joins(expression):
+ suggestions = suggest_type(expression, expression)
+ tbls = tuple([(None, "foo", None, False)])
+ assert set(suggestions) == {
+ FromClauseItem(schema="sch", table_refs=tbls),
+ Join(tbls, "sch"),
+ }
+
+
+def test_truncate_suggests_tables_and_schemas():
+ suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
+ assert set(suggestions) == {Table(schema=None), Schema()}
+
+
+def test_truncate_suggests_qualified_tables():
+ suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
+ assert set(suggestions) == {Table(schema="sch")}
+
+
+@pytest.mark.parametrize(
+ "text", ["SELECT DISTINCT ", "INSERT INTO foo SELECT DISTINCT "]
+)
+def test_distinct_suggests_cols(text):
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == {
+ Column(table_refs=(), local_tables=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("DISTINCT"),
+ }
+
+
+@pytest.mark.parametrize(
+ "text, text_before, last_keyword",
+ [
+ ("SELECT DISTINCT FROM tbl x JOIN tbl1 y", "SELECT DISTINCT", "SELECT"),
+ (
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY ",
+ "ORDER BY",
+ ),
+ ],
+)
+def test_distinct_and_order_by_suggestions_with_aliases(
+ text, text_before, last_keyword
+):
+ suggestions = suggest_type(text, text_before)
+ assert set(suggestions) == {
+ Column(
+ table_refs=(
+ TableReference(None, "tbl", "x", False),
+ TableReference(None, "tbl1", "y", False),
+ ),
+ local_tables=(),
+ qualifiable=True,
+ ),
+ Function(schema=None),
+ Keyword(last_keyword),
+ }
+
+
+@pytest.mark.parametrize(
+ "text, text_before",
+ [
+ ("SELECT DISTINCT x. FROM tbl x JOIN tbl1 y", "SELECT DISTINCT x."),
+ (
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
+ "SELECT * FROM tbl x JOIN tbl1 y ORDER BY x.",
+ ),
+ ],
+)
+def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
+ suggestions = suggest_type(text, text_before)
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
+
+
+def test_function_arguments_with_alias_given():
+ suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
+
+ assert set(suggestions) == {
+ Column(
+ table_refs=(TableReference(None, "tbl", "x", False),),
+ local_tables=(),
+ qualifiable=False,
+ ),
+ Table(schema="x"),
+ View(schema="x"),
+ Function(schema="x"),
+ }
+
+
+def test_col_comma_suggests_cols():
+ suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+def test_table_comma_suggests_tables_and_schemas():
+ suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+def test_into_suggests_tables_and_schemas():
+ suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
+ assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "text", ["INSERT INTO abc (", "INSERT INTO abc () SELECT * FROM hij;"]
+)
+def test_insert_into_lparen_suggests_cols(text):
+ suggestions = suggest_type(text, "INSERT INTO abc (")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_insert_into_lparen_partial_text_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_insert_into_lparen_comma_suggests_cols():
+ suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,")
+ assert suggestions == (
+ Column(table_refs=((None, "abc", None, False),), context="insert"),
+ )
+
+
+def test_partially_typed_col_name_suggests_col_names():
+ suggestions = suggest_type(
+ "SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n"
+ )
+ assert set(suggestions) == cols_etc("tabl", last_keyword="WHERE")
+
+
+def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
+ suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", None, False),)),
+ Table(schema="tabl"),
+ View(schema="tabl"),
+ Function(schema="tabl"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT t1. FROM tabl1 t1",
+ "SELECT t1. FROM tabl1 t1, tabl2 t2",
+ 'SELECT t1. FROM "tabl1" t1',
+ 'SELECT t1. FROM "tabl1" t1, "tabl2" t2',
+ ],
+)
+def test_dot_suggests_cols_of_an_alias(sql):
+ suggestions = suggest_type(sql, "SELECT t1.")
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM tabl1 t1 WHERE t1.",
+ "SELECT * FROM tabl1 t1, tabl2 t2 WHERE t1.",
+ 'SELECT * FROM "tabl1" t1 WHERE t1.',
+ 'SELECT * FROM "tabl1" t1, tabl2 t2 WHERE t1.',
+ ],
+)
+def test_dot_suggests_cols_of_an_alias_where(sql):
+ suggestions = suggest_type(sql, sql)
+ assert set(suggestions) == {
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ }
+
+
+def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
+ suggestions = suggest_type(
+ "SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl2", "t2", False),)),
+ Table(schema="t2"),
+ View(schema="t2"),
+ Function(schema="t2"),
+ }
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (",
+ "SELECT * FROM foo WHERE EXISTS (",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (",
+ ],
+)
+def test_sub_select_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == (Keyword(),)
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM (S",
+ "SELECT * FROM foo WHERE EXISTS (S",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (S",
+ ],
+)
+def test_sub_select_partial_text_suggests_keyword(expression):
+ suggestion = suggest_type(expression, expression)
+ assert suggestion == (Keyword(),)
+
+
+def test_outer_table_reference_in_exists_subquery_suggests_columns():
+ q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
+ suggestions = suggest_type(q, q)
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", "f", False),)),
+ Table(schema="f"),
+ View(schema="f"),
+ Function(schema="f"),
+ }
+
+
+@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
+def test_sub_select_table_name_completion(expression):
+ suggestion = suggest_type(expression, expression)
+ assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
+
+
+@pytest.mark.parametrize(
+ "expression",
+ [
+ "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ",
+ "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ",
+ ],
+)
+def test_sub_select_table_name_completion_with_outer_table(expression):
+ suggestion = suggest_type(expression, expression)
+ tbls = tuple([(None, "foo", None, False)])
+ assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
+
+
+def test_sub_select_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+@pytest.mark.xfail
+def test_sub_select_multiple_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, "
+ )
+ assert set(suggestions) == cols_etc("abc")
+
+
+def test_sub_select_dot_col_name_completion():
+ suggestions = suggest_type(
+ "SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "tabl", "t", False),)),
+ Table(schema="t"),
+ View(schema="t"),
+ Function(schema="t"),
+ }
+
+
+@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
+@pytest.mark.parametrize("tbl_alias", ("", "foo"))
+def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
+ text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
+ suggestion = suggest_type(text, text)
+ tbls = tuple([(None, "abc", tbl_alias or None, False)])
+ assert set(suggestion) == {
+ FromClauseItem(schema=None, table_refs=tbls),
+ Schema(),
+ Join(tbls, None),
+ }
+
+
+def test_left_join_with_comma():
+ text = "select * from foo f left join bar b,"
+ suggestions = suggest_type(text, text)
+ # tbls should also include (None, 'bar', 'b', False)
+ # but there's a bug with commas
+ tbls = tuple([(None, "foo", "f", False)])
+ assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.",
+ ],
+)
+def test_join_alias_dot_suggests_cols1(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", "a", False), (None, "def", "d", False))
+ assert set(suggestions) == {
+ Column(table_refs=((None, "abc", "a", False),)),
+ Table(schema="a"),
+ View(schema="a"),
+ Function(schema="a"),
+ JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "SELECT * FROM abc a JOIN def d ON a.id = d.",
+ "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.",
+ ],
+)
+def test_join_alias_dot_suggests_cols2(sql):
+ suggestion = suggest_type(sql, sql)
+ assert set(suggestion) == {
+ Column(table_refs=((None, "def", "d", False),)),
+ Table(schema="d"),
+ View(schema="d"),
+ Function(schema="d"),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on ",
+ """select a.x, b.y
+from abc a
+join bcd b on
+""",
+ """select a.x, b.y
+from abc a
+join bcd b
+on """,
+ "select a.x, b.y from abc a join bcd b on a.id = b.id OR ",
+ ],
+)
+def test_on_suggests_aliases_and_join_conditions(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("a", "b")),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ",
+ "select abc.x, bcd.y from abc join bcd on ",
+ ],
+)
+def test_on_suggests_tables_and_join_conditions(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", None, False), (None, "bcd", None, False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select a.x, b.y from abc a join bcd b on a.id = ",
+ "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ",
+ ],
+)
+def test_on_suggests_aliases_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ assert suggestions == (Alias(aliases=("a", "b")),)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ",
+ "select abc.x, bcd.y from abc join bcd on ",
+ ],
+)
+def test_on_suggests_tables_and_join_conditions_right_side(sql):
+ suggestions = suggest_type(sql, sql)
+ tables = ((None, "abc", None, False), (None, "bcd", None, False))
+ assert set(suggestions) == {
+ JoinCondition(table_refs=tables, parent=None),
+ Alias(aliases=("abc", "bcd")),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ (
+ "select * from abc inner join def using (",
+ "select * from abc inner join def using (col1, ",
+ "insert into hij select * from abc inner join def using (",
+ """insert into hij(x, y, z)
+ select * from abc inner join def using (col1, """,
+ """insert into hij (a,b,c)
+ select * from abc inner join def using (col1, """,
+ ),
+)
+def test_join_using_suggests_common_columns(text):
+ tables = ((None, "abc", None, False), (None, "def", None, False))
+ assert set(suggest_type(text, text)) == {
+ Column(table_refs=tables, require_last_table=True)
+ }
+
+
+def test_suggest_columns_after_multiple_joins():
+ sql = """select * from t1
+ inner join t2 ON
+ t1.id = t2.t1_id
+ inner join t3 ON
+ t2.id = t3."""
+ suggestions = suggest_type(sql, sql)
+ assert Column(table_refs=((None, "t3", None, False),)) in set(suggestions)
+
+
+def test_2_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ", "select * from a; select * from "
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type(
+ "select * from a; select from b", "select * from a; select "
+ )
+ assert set(suggestions) == {
+ Column(table_refs=((None, "b", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+ # Should work even if first statement is invalid
+ suggestions = suggest_type(
+ "select * from; select * from ", "select * from; select * from "
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+
+def test_2_statements_1st_current():
+ suggestions = suggest_type("select * from ; select * from b", "select * from ")
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type("select from a; select * from b", "select ")
+ assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
+
+
+def test_3_statements_2nd_current():
+ suggestions = suggest_type(
+ "select * from a; select * from ; select * from c",
+ "select * from a; select * from ",
+ )
+ assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
+
+ suggestions = suggest_type(
+ "select * from a; select from b; select * from c", "select * from a; select "
+ )
+ assert set(suggestions) == cols_etc("b", last_keyword="SELECT")
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ """,
+ """create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ """,
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT 3 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 2 FROM bar;
+SELECT FROM foo;
+$func$
+ """,
+ """
+SELECT * FROM baz;
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $func$
+SELECT FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+$func$
+SELECT 3 FROM bar;
+SELECT FROM foo;
+$func$
+SELECT * FROM qux;
+ """,
+ ],
+)
+def test_statements_in_function_body(text):
+ suggestions = suggest_type(text, text[: text.find(" ") + 1])
+ assert set(suggestions) == {
+ Column(table_refs=((None, "foo", None, False),), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ }
+
+
+functions = [
+ """
+CREATE OR REPLACE FUNCTION func() RETURNS setof int AS $$
+SELECT 1 FROM foo;
+SELECT 2 FROM bar;
+$$ language sql;
+ """,
+ """
+create function func2(int, varchar)
+RETURNS text
+language sql AS
+'
+SELECT 2 FROM bar;
+SELECT 1 FROM foo;
+';
+ """,
+]
+
+
+@pytest.mark.parametrize("text", functions)
+def test_statements_with_cursor_after_function_body(text):
+ suggestions = suggest_type(text, text[: text.find("; ") + 1])
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+@pytest.mark.parametrize("text", functions)
+def test_statements_with_cursor_before_function_body(text):
+ suggestions = suggest_type(text, "")
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+def test_create_db_with_template():
+ suggestions = suggest_type(
+ "create database foo with template ", "create database foo with template "
+ )
+
+ assert set(suggestions) == {Database()}
+
+
+@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
+def test_specials_included_for_initial_completion(initial_text):
+ suggestions = suggest_type(initial_text, initial_text)
+
+ assert set(suggestions) == {Keyword(), Special()}
+
+
+def test_drop_schema_qualified_table_suggests_only_tables():
+ text = "DROP TABLE schema_name.table_name"
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Table(schema="schema_name"),)
+
+
+@pytest.mark.parametrize("text", (",", " ,", "sel ,"))
+def test_handle_pre_completion_comma_gracefully(text):
+ suggestions = suggest_type(text, text)
+
+ assert iter(suggestions)
+
+
+def test_drop_schema_suggests_schemas():
+ sql = "DROP SCHEMA "
+ assert suggest_type(sql, sql) == (Schema(),)
+
+
+@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
+def test_cast_operator_suggests_types(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
+)
+def test_cast_operator_suggests_schema_qualified_types(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema="bar"),
+ Table(schema="bar"),
+ }
+
+
+def test_alter_column_type_suggests_types():
+ q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
+ assert set(suggest_type(q, q)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ "CREATE TABLE foo (bar ",
+ "CREATE TABLE foo (bar DOU",
+ "CREATE TABLE foo (bar INT, baz ",
+ "CREATE TABLE foo (bar INT, baz TEXT, qux ",
+ "CREATE FUNCTION foo (bar ",
+ "CREATE FUNCTION foo (bar INT, baz ",
+ "SELECT * FROM foo() AS bar (baz ",
+ "SELECT * FROM foo() AS bar (baz INT, qux ",
+ # make sure this doesn't trigger special completion
+ "CREATE TABLE foo (dt d",
+ ],
+)
+def test_identifier_suggests_types_in_parentheses(text):
+ assert set(suggest_type(text, text)) == {
+ Datatype(schema=None),
+ Table(schema=None),
+ Schema(),
+ }
+
+
+@pytest.mark.parametrize(
+ "text",
+ [
+ "SELECT foo ",
+ "SELECT foo FROM bar ",
+ "SELECT foo AS bar ",
+ "SELECT foo bar ",
+ "SELECT * FROM foo AS bar ",
+ "SELECT * FROM foo bar ",
+ "SELECT foo FROM (SELECT bar ",
+ ],
+)
+def test_alias_suggests_keywords(text):
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Keyword(),)
+
+
+def test_invalid_sql():
+ # issue 317
+ text = "selt *"
+ suggestions = suggest_type(text, text)
+ assert suggestions == (Keyword(),)
+
+
+@pytest.mark.parametrize(
+ "text",
+ ["SELECT * FROM foo where created > now() - ", "select * from foo where bar "],
+)
+def test_suggest_where_keyword(text):
+ # https://github.com/dbcli/mycli/issues/135
+ suggestions = suggest_type(text, text)
+ assert set(suggestions) == cols_etc("foo", last_keyword="WHERE")
+
+
+@pytest.mark.parametrize(
+ "text, before, expected",
+ [
+ (
+ "\\ns abc SELECT ",
+ "SELECT ",
+ [
+ Column(table_refs=(), qualifiable=True),
+ Function(schema=None),
+ Keyword("SELECT"),
+ ],
+ ),
+ ("\\ns abc SELECT foo ", "SELECT foo ", (Keyword(),)),
+ (
+ "\\ns abc SELECT t1. FROM tabl1 t1",
+ "SELECT t1.",
+ [
+ Table(schema="t1"),
+ View(schema="t1"),
+ Column(table_refs=((None, "tabl1", "t1", False),)),
+ Function(schema="t1"),
+ ],
+ ),
+ ],
+)
+def test_named_query_completion(text, before, expected):
+ suggestions = suggest_type(text, before)
+ assert set(expected) == set(suggestions)
+
+
+def test_select_suggests_fields_from_function():
+ suggestions = suggest_type("SELECT FROM func()", "SELECT ")
+ assert set(suggestions) == cols_etc("func", is_function=True, last_keyword="SELECT")
+
+
+@pytest.mark.parametrize("sql", ["("])
+def test_leading_parenthesis(sql):
+ # No assertion for now; just make sure it doesn't crash
+ suggest_type(sql, sql)
+
+
+@pytest.mark.parametrize("sql", ['select * from "', 'select * from "foo'])
+def test_ignore_leading_double_quotes(sql):
+ suggestions = suggest_type(sql, sql)
+ assert FromClauseItem(schema=None) in set(suggestions)
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ "ALTER TABLE foo ALTER COLUMN ",
+ "ALTER TABLE foo ALTER COLUMN bar",
+ "ALTER TABLE foo DROP COLUMN ",
+ "ALTER TABLE foo DROP COLUMN bar",
+ ],
+)
+def test_column_keyword_suggests_columns(sql):
+ suggestions = suggest_type(sql, sql)
+ assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
+
+
+def test_handle_unrecognized_kw_generously():
+ sql = "SELECT * FROM sessions WHERE session = 1 AND "
+ suggestions = suggest_type(sql, sql)
+ expected = Column(table_refs=((None, "sessions", None, False),), qualifiable=True)
+
+ assert expected in set(suggestions)
+
+
+@pytest.mark.parametrize("sql", ["ALTER ", "ALTER TABLE foo ALTER "])
+def test_keyword_after_alter(sql):
+ assert Keyword("ALTER") in set(suggest_type(sql, sql))
diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py
new file mode 100644
index 0000000..ae865f4
--- /dev/null
+++ b/tests/test_ssh_tunnel.py
@@ -0,0 +1,188 @@
+import os
+from unittest.mock import patch, MagicMock, ANY
+
+import pytest
+from configobj import ConfigObj
+from click.testing import CliRunner
+from sshtunnel import SSHTunnelForwarder
+
+from pgcli.main import cli, PGCli
+from pgcli.pgexecute import PGExecute
+
+
+@pytest.fixture
+def mock_ssh_tunnel_forwarder() -> MagicMock:
+ mock_ssh_tunnel_forwarder = MagicMock(
+ SSHTunnelForwarder, local_bind_ports=[1111], autospec=True
+ )
+ with patch(
+ "pgcli.main.sshtunnel.SSHTunnelForwarder",
+ return_value=mock_ssh_tunnel_forwarder,
+ ) as mock:
+ yield mock
+
+
+@pytest.fixture
+def mock_pgexecute() -> MagicMock:
+ with patch.object(PGExecute, "__init__", return_value=None) as mock_pgexecute:
+ yield mock_pgexecute
+
+
+def test_ssh_tunnel(
+ mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ # Test with just a host
+ tunnel_url = "some.host"
+ db_params = {
+ "database": "dbname",
+ "host": "db.host",
+ "user": "db_user",
+ "passwd": "db_passwd",
+ }
+ expected_tunnel_params = {
+ "local_bind_address": ("127.0.0.1",),
+ "remote_bind_address": (db_params["host"], 5432),
+ "ssh_address_or_host": (tunnel_url, 22),
+ "logger": ANY,
+ }
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with a full url and with a specific db port
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "some.other.host"
+ tunnel_port = 1022
+ tunnel_url = f"ssh://{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+ db_params["port"] = 1234
+
+ expected_tunnel_params["remote_bind_address"] = (
+ db_params["host"],
+ db_params["port"],
+ )
+ expected_tunnel_params["ssh_address_or_host"] = (tunnel_host, tunnel_port)
+ expected_tunnel_params["ssh_username"] = tunnel_user
+ expected_tunnel_params["ssh_password"] = tunnel_passwd
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(**db_params)
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_ssh_tunnel_forwarder.return_value.start.assert_called_once()
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert call_args == (
+ db_params["database"],
+ db_params["user"],
+ db_params["passwd"],
+ "127.0.0.1",
+ pgcli.ssh_tunnel.local_bind_ports[0],
+ "",
+ )
+ mock_ssh_tunnel_forwarder.reset_mock()
+ mock_pgexecute.reset_mock()
+
+ # Test with DSN
+ dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host={db_params['host']} port={db_params['port']}"
+ )
+
+ pgcli = PGCli(ssh_tunnel_url=tunnel_url)
+ pgcli.connect(dsn=dsn)
+
+ expected_dsn = (
+ f"user={db_params['user']} password={db_params['passwd']} "
+ f"host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
+ )
+
+ mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
+ mock_pgexecute.assert_called_once()
+
+ call_args, call_kwargs = mock_pgexecute.call_args
+ assert expected_dsn in call_args
+
+
+def test_cli_with_tunnel() -> None:
+ runner = CliRunner()
+ tunnel_url = "mytunnel"
+ with patch.object(
+ PGCli, "__init__", autospec=True, return_value=None
+ ) as mock_pgcli:
+ runner.invoke(cli, ["--ssh-tunnel", tunnel_url])
+ mock_pgcli.assert_called_once()
+ call_args, call_kwargs = mock_pgcli.call_args
+ assert call_kwargs["ssh_tunnel_url"] == tunnel_url
+
+
+def test_config(
+ tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
+) -> None:
+ pgclirc = str(tmpdir.join("rcfile"))
+
+ tunnel_user = "tunnel_user"
+ tunnel_passwd = "tunnel_pass"
+ tunnel_host = "tunnel.host"
+ tunnel_port = 1022
+ tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
+
+ tunnel2_url = "tunnel2.host"
+
+ config = ConfigObj()
+ config.filename = pgclirc
+ config["ssh tunnels"] = {}
+ config["ssh tunnels"][r"\.com$"] = tunnel_url
+ config["ssh tunnels"][r"^hello-"] = tunnel2_url
+ config.write()
+
+ # Unmatched host
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="unmatched.host")
+ mock_ssh_tunnel_forwarder.assert_not_called()
+
+ # Host matching first tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="matched.host.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching second tunnel
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
+ mock_ssh_tunnel_forwarder.reset_mock()
+
+ # Host matching both tunnels (will use the first one matched)
+ pgcli = PGCli(pgclirc_file=pgclirc)
+ pgcli.connect(host="hello-i-am-matched.com")
+ mock_ssh_tunnel_forwarder.assert_called_once()
+
+ call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
+ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
+ assert call_kwargs["ssh_username"] == tunnel_user
+ assert call_kwargs["ssh_password"] == tunnel_passwd
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..67d769f
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,92 @@
+import pytest
+import psycopg
+from pgcli.main import format_output, OutputSettings
+from os import getenv
+
+POSTGRES_USER = getenv("PGUSER", "postgres")
+POSTGRES_HOST = getenv("PGHOST", "localhost")
+POSTGRES_PORT = getenv("PGPORT", 5432)
+POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")
+
+
+def db_connection(dbname=None):
+ conn = psycopg.connect(
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ port=POSTGRES_PORT,
+ dbname=dbname,
+ )
+ conn.autocommit = True
+ return conn
+
+
+try:
+ conn = db_connection()
+ CAN_CONNECT_TO_DB = True
+ SERVER_VERSION = conn.info.parameter_status("server_version")
+ JSON_AVAILABLE = True
+ JSONB_AVAILABLE = True
+except Exception as x:
+ CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
+ SERVER_VERSION = 0
+
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a postgres instance at localhost accessible by user 'postgres'",
+)
+
+
+requires_json = pytest.mark.skipif(
+ not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
+)
+
+
+requires_jsonb = pytest.mark.skipif(
+ not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
+)
+
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute("""CREATE DATABASE _test_db""")
+ except:
+ pass
+
+
+def drop_tables(conn):
+ with conn.cursor() as cur:
+ cur.execute(
+ """
+ DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;
+ DROP SCHEMA IF EXISTS schema1 CASCADE;
+ DROP SCHEMA IF EXISTS schema2 CASCADE"""
+ )
+
+
+def run(
+ executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
+):
+ "Return string output for the sql to be run"
+
+ results = executor.run(sql, pgspecial, exception_formatter)
+ formatted = []
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
+ )
+ for title, rows, headers, status, sql, success, is_special in results:
+ formatted.extend(format_output(title, rows, headers, status, settings))
+ if join:
+ formatted = "\n".join(formatted)
+
+ return formatted
+
+
+def completions_to_set(completions):
+ return {
+ (completion.display_text, completion.display_meta_text)
+ for completion in completions
+ }