summaryrefslogtreecommitdiffstats
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..2427c30
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,95 @@
+import pytest
+import psycopg2
+import psycopg2.extras
+from pgcli.main import format_output, OutputSettings
+from pgcli.pgexecute import register_json_typecasters
+from os import getenv
+
+POSTGRES_USER = getenv("PGUSER", "postgres")
+POSTGRES_HOST = getenv("PGHOST", "localhost")
+POSTGRES_PORT = getenv("PGPORT", 5432)
+POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
+
+
+def db_connection(dbname=None):
+ conn = psycopg2.connect(
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ port=POSTGRES_PORT,
+ database=dbname,
+ )
+ conn.autocommit = True
+ return conn
+
+
+try:
+ conn = db_connection()
+ CAN_CONNECT_TO_DB = True
+ SERVER_VERSION = conn.server_version
+ json_types = register_json_typecasters(conn, lambda x: x)
+ JSON_AVAILABLE = "json" in json_types
+ JSONB_AVAILABLE = "jsonb" in json_types
+except:
+ CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
+ SERVER_VERSION = 0
+
+
+dbtest = pytest.mark.skipif(
+ not CAN_CONNECT_TO_DB,
+ reason="Need a postgres instance at localhost accessible by user 'postgres'",
+)
+
+
+requires_json = pytest.mark.skipif(
+ not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
+)
+
+
+requires_jsonb = pytest.mark.skipif(
+ not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
+)
+
+
+def create_db(dbname):
+ with db_connection().cursor() as cur:
+ try:
+ cur.execute("""CREATE DATABASE _test_db""")
+ except:
+ pass
+
+
+def drop_tables(conn):
+ with conn.cursor() as cur:
+ cur.execute(
+ """
+ DROP SCHEMA public CASCADE;
+ CREATE SCHEMA public;
+ DROP SCHEMA IF EXISTS schema1 CASCADE;
+ DROP SCHEMA IF EXISTS schema2 CASCADE"""
+ )
+
+
+def run(
+ executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
+):
+ " Return string output for the sql to be run "
+
+ results = executor.run(sql, pgspecial, exception_formatter)
+ formatted = []
+ settings = OutputSettings(
+ table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
+ )
+ for title, rows, headers, status, sql, success, is_special in results:
+ formatted.extend(format_output(title, rows, headers, status, settings))
+ if join:
+ formatted = "\n".join(formatted)
+
+ return formatted
+
+
+def completions_to_set(completions):
+ return set(
+ (completion.display_text, completion.display_meta_text)
+ for completion in completions
+ )