summaryrefslogtreecommitdiffstats
path: root/tests/test_prepared_async.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_prepared_async.py')
-rw-r--r--tests/test_prepared_async.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/tests/test_prepared_async.py b/tests/test_prepared_async.py
new file mode 100644
index 0000000..84d948f
--- /dev/null
+++ b/tests/test_prepared_async.py
@@ -0,0 +1,207 @@
+"""
+Prepared statements tests on async connections
+"""
+
+import datetime as dt
+from decimal import Decimal
+
+import pytest
+
+from psycopg.rows import namedtuple_row
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.mark.parametrize("value", [None, 0, 3])
+async def test_prepare_threshold_init(aconn_cls, dsn, value):
+ async with await aconn_cls.connect(dsn, prepare_threshold=value) as conn:
+ assert conn.prepare_threshold == value
+
+
+async def test_dont_prepare(aconn):
+ cur = aconn.cursor()
+ for i in range(10):
+ await cur.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s::int", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_dont_prepare_conn(aconn):
+ for i in range(10):
+ await aconn.execute("select %s::int", [i], prepare=False)
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+async def test_do_prepare_conn(aconn):
+ await aconn.execute("select %s::int", [10], prepare=True)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_auto_prepare_conn(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 5 + [1] * 5
+
+
+async def test_prepare_disable(aconn):
+ aconn.prepare_threshold = None
+ res = []
+ for i in range(10):
+ await aconn.execute("select %s", [0])
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+ assert not aconn._prepared._names
+ assert not aconn._prepared._counts
+
+
+async def test_no_prepare_multi(aconn):
+ res = []
+ for i in range(10):
+ await aconn.execute("select 1; select 2")
+ stmts = await get_prepared_statements(aconn)
+ res.append(len(stmts))
+
+ assert res == [0] * 10
+
+
+async def test_no_prepare_error(aconn):
+ await aconn.set_autocommit(True)
+ for i in range(10):
+ with pytest.raises(aconn.ProgrammingError):
+ await aconn.execute("select wat")
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 0
+
+
+@pytest.mark.parametrize(
+ "query",
+ [
+ "create table test_no_prepare ()",
+ pytest.param("notify foo, 'bar'", marks=pytest.mark.crdb_skip("notify")),
+ "set timezone = utc",
+ "select num from prepared_test",
+ "insert into prepared_test (num) values (1)",
+ "update prepared_test set num = num * 2",
+ "delete from prepared_test where num > 10",
+ ],
+)
+async def test_misc_statement(aconn, query):
+ await aconn.execute("create table prepared_test (num int)", prepare=False)
+ aconn.prepare_threshold = 0
+ await aconn.execute(query)
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+
+
+async def test_params_types(aconn):
+ await aconn.execute(
+ "select %s, %s, %s",
+ [dt.date(2020, 12, 10), 42, Decimal(42)],
+ prepare=True,
+ )
+ stmts = await get_prepared_statements(aconn)
+ want = [stmt.parameter_types for stmt in stmts]
+ assert want == [["date", "smallint", "numeric"]]
+
+
+async def test_evict_lru(aconn):
+ aconn.prepared_max = 5
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 1
+ assert aconn._prepared._names[b"select 'a'", ()] == b"_pg3_0"
+ for i in [9, 8, 7, 6]:
+ assert aconn._prepared._counts[f"select {i}".encode(), ()] == 1
+
+ stmts = await get_prepared_statements(aconn)
+ assert len(stmts) == 1
+ assert stmts[0].statement == "select 'a'"
+
+
+async def test_evict_lru_deallocate(aconn):
+ aconn.prepared_max = 5
+ aconn.prepare_threshold = 0
+ for i in range(10):
+ await aconn.execute("select 'a'")
+ await aconn.execute(f"select {i}")
+
+ assert len(aconn._prepared._names) == 5
+ for j in [9, 8, 7, 6, "'a'"]:
+ name = aconn._prepared._names[f"select {j}".encode(), ()]
+ assert name.startswith(b"_pg3_")
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.statement for stmt in stmts]
+ assert got == [f"select {i}" for i in ["'a'", 6, 7, 8, 9]]
+
+
+async def test_different_types(aconn):
+ aconn.prepare_threshold = 0
+ await aconn.execute("select %s", [None])
+ await aconn.execute("select %s", [dt.date(2000, 1, 1)])
+ await aconn.execute("select %s", [42])
+ await aconn.execute("select %s", [41])
+ await aconn.execute("select %s", [dt.date(2000, 1, 2)])
+
+ stmts = await get_prepared_statements(aconn)
+ stmts.sort(key=lambda rec: rec.prepare_time)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["text"], ["date"], ["smallint"]]
+
+
+async def test_untyped_json(aconn):
+ aconn.prepare_threshold = 1
+ await aconn.execute("create table testjson(data jsonb)")
+ for i in range(2):
+ await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
+
+ stmts = await get_prepared_statements(aconn)
+ got = [stmt.parameter_types for stmt in stmts]
+ assert got == [["jsonb"]]
+
+
+async def get_prepared_statements(aconn):
+ cur = aconn.cursor(row_factory=namedtuple_row)
+ await cur.execute(
+ r"""
+select name,
+ regexp_replace(statement, 'prepare _pg3_\d+ as ', '', 'i') as statement,
+ prepare_time,
+ parameter_types
+from pg_prepared_statements
+where name != ''
+ """,
+ prepare=False,
+ )
+ return await cur.fetchall()