summaryrefslogtreecommitdiffstats
path: root/tests/utils.py
blob: 67d769fd497edaaf68bf1a8b5bc7890232191e80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import pytest
import psycopg
from pgcli.main import format_output, OutputSettings
from os import getenv

# Connection settings for the test server, overridable via the standard libpq
# environment variables. All values are strings: getenv returns str when the
# variable is set, so the port default is a string too to keep the type stable.
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
POSTGRES_PORT = getenv("PGPORT", "5432")
POSTGRES_PASSWORD = getenv("PGPASSWORD", "postgres")


def db_connection(dbname=None):
    """Open an autocommit psycopg connection to the test server.

    User, host, password and port come from the module-level POSTGRES_*
    settings.

    :param dbname: database to connect to; ``None`` is forwarded as-is
        (presumably leaving the choice to libpq defaults — confirm).
    :return: the psycopg connection, with ``autocommit`` switched on.
    """
    params = {
        "user": POSTGRES_USER,
        "host": POSTGRES_HOST,
        "password": POSTGRES_PASSWORD,
        "port": POSTGRES_PORT,
        "dbname": dbname,
    }
    connection = psycopg.connect(**params)
    # Autocommit lets callers run statements such as CREATE DATABASE, which
    # PostgreSQL refuses inside a transaction block.
    connection.autocommit = True
    return connection


# Probe the test server once at import time so DB-dependent tests can be
# skipped when it is unreachable. The probe connection is closed right away
# instead of being held open for the whole test session.
try:
    with db_connection() as conn:
        SERVER_VERSION = conn.info.parameter_status("server_version")
    CAN_CONNECT_TO_DB = True
    # NOTE(review): json/jsonb support is assumed for any reachable server;
    # nothing here actually checks that the types exist.
    JSON_AVAILABLE = True
    JSONB_AVAILABLE = True
except Exception:
    # Deliberately broad: any failure (server down, bad credentials, ...)
    # just disables the DB-backed tests.
    CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
    SERVER_VERSION = 0


# Skip marker for tests that need a live server. The reason names the actual
# connection target (which may come from PGHOST/PGPORT/PGUSER), so a skipped
# run points at the right place to look.
dbtest = pytest.mark.skipif(
    not CAN_CONNECT_TO_DB,
    reason=(
        f"Need a postgres instance at {POSTGRES_HOST}:{POSTGRES_PORT} "
        f"accessible by user '{POSTGRES_USER}'"
    ),
)


# Skip markers for tests that exercise the json / jsonb column types. Both
# flags are False whenever the import-time server probe failed.
requires_json = pytest.mark.skipif(
    not JSON_AVAILABLE,
    reason="Postgres server unavailable or json type not defined",
)


requires_jsonb = pytest.mark.skipif(
    not JSONB_AVAILABLE,
    reason="Postgres server unavailable or jsonb type not defined",
)


def create_db(dbname):
    """Create the database *dbname* on the test server, if possible.

    Database errors from CREATE DATABASE (most commonly: the database already
    exists) are ignored, so the call is safe to repeat. The temporary
    connection used for the statement is always closed.

    :param dbname: name of the database to create; quoted as an identifier.
    """
    from psycopg import sql

    with db_connection() as conn, conn.cursor() as cur:
        try:
            cur.execute(
                sql.SQL("CREATE DATABASE {}").format(sql.Identifier(dbname))
            )
        except psycopg.Error:
            # Best effort: typically the database already exists. Only DB
            # errors are swallowed; interrupts and programming errors propagate.
            pass


def drop_tables(conn):
    """Reset the test database's schemas.

    Recreates ``public`` empty (dropping everything in it) and removes
    ``schema1`` and ``schema2`` if they exist, all in a single execute call.

    :param conn: connection whose ``cursor()`` works as a context manager,
        e.g. one returned by ``db_connection``.
    """
    reset_sql = """
            DROP SCHEMA public CASCADE;
            CREATE SCHEMA public;
            DROP SCHEMA IF EXISTS schema1 CASCADE;
            DROP SCHEMA IF EXISTS schema2 CASCADE"""
    with conn.cursor() as cursor:
        cursor.execute(reset_sql)


def run(
    executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
    """Return string output for the sql to be run.

    Executes *sql* through ``executor.run`` and renders every result with
    ``format_output`` using psql-style table settings.

    :param executor: object whose ``run(sql, pgspecial, exception_formatter)``
        yields 7-tuples ``(title, rows, headers, status, sql, success,
        is_special)``.
    :param sql: SQL text to execute.
    :param join: if true, return one newline-joined string instead of a list.
    :param expanded: forwarded to ``OutputSettings`` as ``expanded``.
    :param pgspecial: passed through to ``executor.run``.
    :param exception_formatter: passed through to ``executor.run``.
    :return: list of formatted output strings, or a single string when
        *join* is true.
    """
    results = executor.run(sql, pgspecial, exception_formatter)
    settings = OutputSettings(
        table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
    )
    formatted = []
    # Only the first four fields are rendered. The underscored names keep the
    # loop from shadowing the ``sql`` parameter.
    for title, rows, headers, status, _sql, _success, _is_special in results:
        formatted.extend(format_output(title, rows, headers, status, settings))
    if join:
        return "\n".join(formatted)
    return formatted


def completions_to_set(completions):
    """Collapse completions into a set of ``(display_text, display_meta_text)`` pairs.

    Duplicates collapse and ordering is discarded, so tests can compare
    completion results without caring about their order.
    """
    pairs = set()
    for item in completions:
        pairs.add((item.display_text, item.display_meta_text))
    return pairs