summaryrefslogtreecommitdiffstats
path: root/tests/test_prepared_async.py
blob: 84d948f6553d7f13b7642577fa78edfe77727bb3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""
Prepared statements tests on async connections
"""

import datetime as dt
from decimal import Decimal

import pytest

from psycopg.rows import namedtuple_row

# pytest applies a module-level ``pytestmark`` to every test collected from
# this module, so all the ``async def`` tests below run under the asyncio mark.
pytestmark = pytest.mark.asyncio


@pytest.mark.parametrize(
    "value",
    [None, 0, 3],
)
async def test_prepare_threshold_init(aconn_cls, dsn, value):
    """Check that a ``prepare_threshold`` passed to ``connect()`` is kept as-is."""
    connecting = aconn_cls.connect(dsn, prepare_threshold=value)
    async with await connecting as conn:
        assert conn.prepare_threshold == value


async def test_dont_prepare(aconn):
    """Queries executed on a cursor with ``prepare=False`` are never prepared."""
    cursor = aconn.cursor()
    for n in range(10):
        await cursor.execute("select %s::int", [n], prepare=False)

    assert len(await get_prepared_statements(aconn)) == 0


async def test_do_prepare(aconn):
    """A single cursor execution with ``prepare=True`` is prepared right away."""
    cursor = aconn.cursor()
    await cursor.execute("select %s::int", [10], prepare=True)
    prepared = await get_prepared_statements(aconn)
    assert len(prepared) == 1


async def test_auto_prepare(aconn):
    """Repeating the same query eventually prepares it automatically.

    The threshold is not set here, so the expected counts rely on the
    connection's default ``prepare_threshold``. The statement must show up
    server-side starting from the sixth execution.
    """
    counts = []
    for _ in range(10):
        await aconn.execute("select %s::int", [0])
        counts.append(len(await get_prepared_statements(aconn)))

    assert counts == [0] * 5 + [1] * 5


async def test_dont_prepare_conn(aconn):
    """``execute(..., prepare=False)`` on the connection never prepares."""
    for k in range(10):
        await aconn.execute("select %s::int", [k], prepare=False)

    prepared = await get_prepared_statements(aconn)
    assert len(prepared) == 0


async def test_do_prepare_conn(aconn):
    """``execute(..., prepare=True)`` on the connection prepares on first use."""
    await aconn.execute("select %s::int", [10], prepare=True)
    assert len(await get_prepared_statements(aconn)) == 1


async def test_auto_prepare_conn(aconn):
    """An untyped-placeholder query repeated on the connection gets auto-prepared.

    As with the default ``prepare_threshold``, nothing is prepared for the
    first five executions and exactly one statement exists afterwards.
    """
    counts = []
    for _ in range(10):
        await aconn.execute("select %s", [0])
        counts.append(len(await get_prepared_statements(aconn)))

    assert counts == [0] * 5 + [1] * 5


async def test_prepare_disable(aconn):
    """A ``None`` threshold disables automatic preparation entirely.

    Nothing may appear in ``pg_prepared_statements``. The client-side
    bookkeeping (``_prepared._names`` and ``_prepared._counts``) must also
    remain empty.
    """
    aconn.prepare_threshold = None
    counts = []
    for _ in range(10):
        await aconn.execute("select %s", [0])
        counts.append(len(await get_prepared_statements(aconn)))

    assert counts == [0] * 10
    assert not aconn._prepared._names
    assert not aconn._prepared._counts


async def test_no_prepare_multi(aconn):
    """A query string holding several statements is never prepared."""
    counts = []
    for _ in range(10):
        await aconn.execute("select 1; select 2")
        counts.append(len(await get_prepared_statements(aconn)))

    assert counts == [0] * 10


async def test_no_prepare_error(aconn):
    """A query that always fails is never prepared, however often it runs.

    Autocommit is switched on, presumably so that each failure does not leave
    the connection in an aborted transaction that would break the following
    iterations and the final inspection query.
    """
    await aconn.set_autocommit(True)
    for _ in range(10):
        with pytest.raises(aconn.ProgrammingError):
            await aconn.execute("select wat")

    assert len(await get_prepared_statements(aconn)) == 0


@pytest.mark.parametrize(
    "query",
    [
        "create table test_no_prepare ()",
        pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
        "set timezone = utc",
        "select num from prepared_test",
        "insert into prepared_test (num) values (1)",
        "update prepared_test set num = num * 2",
        "delete from prepared_test where num > 10",
    ],
)
async def test_misc_statement(aconn, query):
    """Statements of every kind (DDL, NOTIFY, SET, DML) can be prepared.

    ``prepared_test`` is first created without preparing, so the DML cases
    have a table to act on. A zero threshold then makes the single
    execution of ``query`` produce exactly one prepared statement.
    """
    await aconn.execute("create table prepared_test (num int)", prepare=False)
    aconn.prepare_threshold = 0
    await aconn.execute(query)
    prepared = await get_prepared_statements(aconn)
    assert len(prepared) == 1


async def test_params_types(aconn):
    """A prepared statement records the server types of its Python parameters.

    The values reported by ``pg_prepared_statements.parameter_types`` must be
    ``date`` for a ``datetime.date``, ``smallint`` for a small ``int`` and
    ``numeric`` for a ``Decimal``.
    """
    await aconn.execute(
        "select %s, %s, %s",
        [dt.date(2020, 12, 10), 42, Decimal(42)],
        prepare=True,
    )
    stmts = await get_prepared_statements(aconn)
    # Values observed on the server: the actual result, compared to the
    # expected literal below.
    got = [stmt.parameter_types for stmt in stmts]
    assert got == [["date", "smallint", "numeric"]]


async def test_evict_lru(aconn):
    """With a small ``prepared_max`` only frequently used queries get prepared.

    ``select 'a'`` runs on every iteration, so it ends up as the only prepared
    statement (named ``_pg3_0``). Each ``select {n}`` runs once, so it only
    gets a usage count; the four most recent ones must still show a count
    of 1.
    """
    aconn.prepared_max = 5
    for n in range(10):
        await aconn.execute("select 'a'")
        await aconn.execute(f"select {n}")

    assert len(aconn._prepared._names) == 1
    assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
    for n in (9, 8, 7, 6):
        key = f"select {n}".encode()
        assert aconn._prepared._counts[key, ()] == 1

    stmts = await get_prepared_statements(aconn)
    assert len(stmts) == 1
    assert stmts[0].statement == "select 'a'"


async def test_evict_lru_deallocate(aconn):
    """Statements evicted from the client cache also disappear server-side.

    With a zero threshold every query is prepared at once. Only the
    ``prepared_max`` most recently used ones may survive, both in the
    client-side name map and in ``pg_prepared_statements``.
    """
    aconn.prepared_max = 5
    aconn.prepare_threshold = 0
    for n in range(10):
        await aconn.execute("select 'a'")
        await aconn.execute(f"select {n}")

    assert len(aconn._prepared._names) == 5
    for suffix in [9, 8, 7, 6, "'a'"]:
        name = aconn._prepared._names[f"select {suffix}".encode(), ()]
        assert name.startswith(b"_pg3_")

    stmts = sorted(
        await get_prepared_statements(aconn), key=lambda rec: rec.prepare_time
    )
    got = [stmt.statement for stmt in stmts]
    assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]


async def test_different_types(aconn):
    """The same query text is prepared once per distinct set of parameter types.

    Five executions with ``None``, two dates and two small ints must yield
    three prepared statements. In preparation order they have parameter types
    ``text``, ``date`` and ``smallint``.
    """
    aconn.prepare_threshold = 0
    for param in (None, dt.date(2000, 1, 1), 42, 41, dt.date(2000, 1, 2)):
        await aconn.execute("select %s", [param])

    stmts = await get_prepared_statements(aconn)
    stmts.sort(key=lambda rec: rec.prepare_time)
    got = [stmt.parameter_types for stmt in stmts]
    assert got == [["text"], ["date"], ["smallint"]]


async def test_untyped_json(aconn):
    """A JSON string parameter is prepared with the target column's type.

    With ``prepare_threshold = 1`` the repeated insert becomes prepared, and
    the server must report its parameter as ``jsonb``. The Python ``str``
    ``"{}"`` is presumably sent without a specific type, letting the server
    infer it from the column.
    """
    aconn.prepare_threshold = 1
    await aconn.execute("create table testjson(data jsonb)")
    for _ in range(2):
        await aconn.execute("insert into testjson (data) values (%s)", ["{}"])

    got = [stmt.parameter_types for stmt in await get_prepared_statements(aconn)]
    assert got == [["jsonb"]]


async def get_prepared_statements(aconn):
    """Return the named prepared statements currently held by ``aconn``'s session.

    Rows come from ``pg_prepared_statements`` as named tuples with the fields
    ``name``, ``statement``, ``prepare_time`` and ``parameter_types``. The
    unnamed statement (``name = ''``) is excluded. Any leading
    ``prepare _pg3_N as`` text (case-insensitive) is stripped from
    ``statement``, so callers can compare it with the bare query.

    The inspection query runs with ``prepare=False`` so that looking at the
    prepared statements never adds one of its own. The cursor is closed as
    soon as the rows are fetched rather than left open. This helper is called
    repeatedly inside test loops, and would otherwise leak one cursor per call.
    """
    async with aconn.cursor(row_factory=namedtuple_row) as cur:
        await cur.execute(
            r"""
select name,
    regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
    prepare_time,
    parameter_types
from pg_prepared_statements
where name != ''
        """,
            prepare=False,
        )
        # Fetch before leaving the block: the rows outlive the closed cursor.
        return await cur.fetchall()