diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /tests/test_prepared_async.py | |
parent | Initial commit. (diff) | |
download | psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_prepared_async.py')
-rw-r--r-- | tests/test_prepared_async.py | 207 |
1 files changed, 207 insertions, 0 deletions
diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py new file mode 100644 index 0000000..84d948f --- /dev/null +++ b/tests/test_prepared_async.py @@ -0,0 +1,207 @@ +""" +Prepared statements tests on async connections +""" + +import datetime as dt +from decimal import Decimal + +import pytest + +from psycopg.rows import namedtuple_row + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.parametrize("value", [None, 0, 3]) +async def test_prepare_threshold_init(aconn_cls, dsn, value): + async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn: + assert conn.prepare_threshold == value + + +async def test_dont_prepare(aconn): + cur = aconn.cursor() + for i in range(10): + await cur.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare(aconn): + cur = aconn.cursor() + await cur.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s::int", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_dont_prepare_conn(aconn): + for i in range(10): + await aconn.execute("select %s::int", [i], prepare=False) + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +async def test_do_prepare_conn(aconn): + await aconn.execute("select %s::int", [10], prepare=True) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_auto_prepare_conn(aconn): + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 5 + [1] * 5 + + +async def test_prepare_disable(aconn): + aconn.prepare_threshold = None + res = [] + for i in range(10): + await aconn.execute("select %s", [0]) + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + assert not aconn._prepared._names + assert not aconn._prepared._counts + + +async def test_no_prepare_multi(aconn): + res = [] + for i in range(10): + await aconn.execute("select 1; select 2") + stmts = await get_prepared_statements(aconn) + res.append(len(stmts)) + + assert res == [0] * 10 + + +async def test_no_prepare_error(aconn): + await aconn.set_autocommit(True) + for i in range(10): + with pytest.raises(aconn.ProgrammingError): + await aconn.execute("select wat") + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 0 + + +@pytest.mark.parametrize( + "query", + [ + "create table test_no_prepare ()", + pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")), + "set timezone = utc", + "select num from prepared_test", + "insert into prepared_test (num) values (1)", + "update prepared_test set num = num * 2", + "delete from prepared_test where num > 10", + ], +) +async def test_misc_statement(aconn, query): + await aconn.execute("create table prepared_test (num int)", prepare=False) + aconn.prepare_threshold = 0 + await aconn.execute(query) + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + + +async def test_params_types(aconn): + await aconn.execute( + "select %s, %s, %s", + [dt.date(2020, 12, 10), 42, Decimal(42)], + prepare=True, + ) + stmts = await get_prepared_statements(aconn) + want = [stmt.parameter_types for stmt in stmts] + assert want == [["date", "smallint", "numeric"]] + + +async def test_evict_lru(aconn): + aconn.prepared_max = 5 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 1 + assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0" + for i in [9, 8, 7, 6]: + assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1 + + stmts = await get_prepared_statements(aconn) + assert len(stmts) == 1 + assert stmts[0].statement == "select 'a'" + + +async def test_evict_lru_deallocate(aconn): + aconn.prepared_max = 5 + aconn.prepare_threshold = 0 + for i in range(10): + await aconn.execute("select 'a'") + await aconn.execute(f"select {i}") + + assert len(aconn._prepared._names) == 5 + for j in [9, 8, 7, 6, "'a'"]: + name = aconn._prepared._names[f"select {j}".encode(), ()] + assert name.startswith(b"_pg3_") + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.statement for stmt in stmts] + assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]] + + +async def test_different_types(aconn): + aconn.prepare_threshold = 0 + await aconn.execute("select %s", [None]) + await aconn.execute("select %s", [dt.date(2000, 1, 1)]) + await aconn.execute("select %s", [42]) + await aconn.execute("select %s", [41]) + await aconn.execute("select %s", [dt.date(2000, 1, 2)]) + + stmts = await get_prepared_statements(aconn) + stmts.sort(key=lambda rec: rec.prepare_time) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["text"], ["date"], ["smallint"]] + + +async def test_untyped_json(aconn): + aconn.prepare_threshold = 1 + await aconn.execute("create table testjson(data jsonb)") + for i in range(2): + await aconn.execute("insert into testjson (data) values (%s)", ["{}"]) + + stmts = await get_prepared_statements(aconn) + got = [stmt.parameter_types for stmt in stmts] + assert got == [["jsonb"]] + + +async def get_prepared_statements(aconn): + cur = aconn.cursor(row_factory=namedtuple_row) + await cur.execute( + r""" +select name, + regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement, + prepare_time, + parameter_types +from pg_prepared_statements +where name != '' + """, + prepare=False, + ) + return await cur.fetchall() |