1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
import pytest
import psycopg2
import psycopg2.extras
from pgcli.main import format_output, OutputSettings
from pgcli.pgexecute import register_json_typecasters
from os import getenv
# Connection settings for the test PostgreSQL server. Each value is read from
# the environment variable named in the call, falling back to the literal
# default when that variable is unset.
POSTGRES_USER = getenv("PGUSER", "postgres")
POSTGRES_HOST = getenv("PGHOST", "localhost")
# NOTE(review): the fallback is the int 5432, while a value taken from the
# environment is a str -- so the type of POSTGRES_PORT depends on whether
# PGPORT is set. Confirm every consumer accepts both.
POSTGRES_PORT = getenv("PGPORT", 5432)
POSTGRES_PASSWORD = getenv("PGPASSWORD", "")
def db_connection(dbname=None):
    """Open a psycopg2 connection using the module-level POSTGRES_* settings.

    Args:
        dbname: Name of the database to connect to; passed to
            ``psycopg2.connect`` as ``database`` (``None`` by default).

    Returns:
        The new connection, with ``autocommit`` switched on.
    """
    connect_kwargs = {
        "user": POSTGRES_USER,
        "host": POSTGRES_HOST,
        "password": POSTGRES_PASSWORD,
        "port": POSTGRES_PORT,
        "database": dbname,
    }
    connection = psycopg2.connect(**connect_kwargs)
    # Autocommit so that every statement issued by the helpers takes effect
    # immediately, without an explicit commit.
    connection.autocommit = True
    return connection
# Probe the test server once at import time and record what is available.
# The resulting flags drive the skip markers defined in this module.
try:
    conn = db_connection()
    CAN_CONNECT_TO_DB = True
    SERVER_VERSION = conn.server_version
    json_types = register_json_typecasters(conn, lambda x: x)
    JSON_AVAILABLE = "json" in json_types
    JSONB_AVAILABLE = "jsonb" in json_types
except Exception:
    # Any failure (no server, bad credentials, typecaster registration error)
    # means database-backed tests cannot run. Catching Exception rather than
    # using a bare ``except`` lets KeyboardInterrupt/SystemExit propagate.
    CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
    SERVER_VERSION = 0
# Skip marker for tests that need a live database connection. The reason text
# names the host and user actually configured (they may come from PGHOST /
# PGUSER), so the skip message is accurate when the defaults are overridden.
dbtest = pytest.mark.skipif(
    not CAN_CONNECT_TO_DB,
    reason="Need a postgres instance at {} accessible by user '{}'".format(
        POSTGRES_HOST, POSTGRES_USER
    ),
)

# Skip marker for tests that need the ``json`` type to be registered.
requires_json = pytest.mark.skipif(
    not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
)

# Skip marker for tests that need the ``jsonb`` type to be registered.
requires_jsonb = pytest.mark.skipif(
    not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
)
def create_db(dbname):
    """Create the database ``dbname`` on the test server, best-effort.

    A database error raised by the ``CREATE DATABASE`` statement (for
    example because the database already exists) is deliberately ignored.
    Failure to connect to the server is not swallowed.

    Args:
        dbname: Name of the database to create. It is quoted as an SQL
            identifier before being placed in the statement.
    """
    # Function-scope import: psycopg2.sql is a submodule that the file's
    # top-level imports do not load explicitly.
    from psycopg2 import sql

    connection = db_connection()
    try:
        with connection.cursor() as cur:
            statement = sql.SQL("CREATE DATABASE {}").format(sql.Identifier(dbname))
            try:
                cur.execute(statement)
            except psycopg2.Error:
                # Best-effort: keep going if the database cannot be created,
                # but only swallow database errors, not every exception.
                pass
    finally:
        # Do not leak the short-lived administrative connection.
        connection.close()
def drop_tables(conn):
    """Reset the schemas used by the tests on the given connection.

    Sends one multi-statement command that drops and recreates ``public``
    and drops ``schema1`` and ``schema2`` if they exist.

    Args:
        conn: Connection object providing ``cursor()`` as a context manager.
    """
    reset_sql = (
        "\n"
        "DROP SCHEMA public CASCADE;\n"
        "CREATE SCHEMA public;\n"
        "DROP SCHEMA IF EXISTS schema1 CASCADE;\n"
        "DROP SCHEMA IF EXISTS schema2 CASCADE"
    )
    with conn.cursor() as cursor:
        cursor.execute(reset_sql)
def run(
    executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
):
    """Run ``sql`` through ``executor`` and return the formatted output.

    Args:
        executor: Object whose ``run(sql, pgspecial, exception_formatter)``
            yields 7-item result tuples.
        sql: The SQL text to execute.
        join: When true, return the output joined with newlines as one string.
        expanded: Passed to ``OutputSettings`` as its ``expanded`` option.
        pgspecial: Forwarded unchanged to ``executor.run``.
        exception_formatter: Forwarded unchanged to ``executor.run``.

    Returns:
        A list of formatted output pieces, or a single newline-joined string
        when ``join`` is true.
    """
    results = executor.run(sql, pgspecial, exception_formatter)
    output = []
    settings = OutputSettings(
        table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
    )
    # Only the first four fields of each result are needed for formatting;
    # the remaining three are unpacked so the tuple shape is still enforced.
    for title, rows, headers, status, _query, _success, _is_special in results:
        output += format_output(title, rows, headers, status, settings)
    return "\n".join(output) if join else output
def completions_to_set(completions):
    """Collapse completions into a set of (display_text, display_meta_text) pairs.

    Args:
        completions: Iterable of objects exposing ``display_text`` and
            ``display_meta_text`` attributes.

    Returns:
        A set of 2-tuples, one per distinct pair found in ``completions``.
    """
    return {(item.display_text, item.display_meta_text) for item in completions}
|