summaryrefslogtreecommitdiffstats
path: root/tests/test_server_cursor_async.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 17:41:08 +0000
commit506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch)
tree808913770c5e6935d3714058c2a066c57b4632ec /tests/test_server_cursor_async.py
parentInitial commit. (diff)
downloadpsycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz
psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_server_cursor_async.py')
-rw-r--r--tests/test_server_cursor_async.py543
1 files changed, 543 insertions, 0 deletions
diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py
new file mode 100644
index 0000000..21b4345
--- /dev/null
+++ b/tests/test_server_cursor_async.py
@@ -0,0 +1,543 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = [
+ pytest.mark.asyncio,
+ pytest.mark.crdb_skip("server-side cursor"),
+]
+
+
+async def test_init_row_factory(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is aconn
+ assert cur.row_factory is aconn.row_factory
+
+ aconn.row_factory = rows.dict_row
+
+ async with psycopg.AsyncServerCursor(aconn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "baz", row_factory=rows.namedtuple_row
+ ) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+async def test_init_params(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ async with psycopg.AsyncServerCursor(
+ aconn, "bar", withhold=True, scrollable=False
+ ) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+async def test_funny_name(aconn):
+ cur = aconn.cursor("1-2-3")
+ await cur.execute("select generate_series(1, 3) as bar")
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ await cur.close()
+
+
+async def test_repr(aconn):
+ cur = aconn.cursor("my-name")
+ assert "psycopg.AsyncServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ await cur.close()
+
+
+async def test_connection(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.connection is aconn
+ await cur.close()
+
+
+async def test_description(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.name == "foo"
+ await cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ await cur.close()
+
+
+async def test_format(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.format == Format.TEXT
+ await cur.close()
+
+ cur = aconn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ await cur.close()
+
+
+async def test_query_params(aconn):
+ async with aconn.cursor("foo") as cur:
+ assert cur._query is None
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query is not None
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+async def test_binary_cursor_execute(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ await cur.close()
+
+
+async def test_execute_binary(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ await cur.execute("select generate_series(1, 1)")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ await cur.close()
+
+
+async def test_binary_cursor_text_override(aconn):
+ cur = aconn.cursor("foo", binary=True)
+ await cur.execute("select generate_series(1, 2)", binary=False)
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert (await cur.fetchone()) == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ await cur.execute("select generate_series(1, 2)::int4")
+ assert (await cur.fetchone()) == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ await cur.close()
+
+
+async def test_close(aconn, recwarn):
+ if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ await aconn.execute("close foo")
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ await cur.close()
+ assert cur.closed
+
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_idempotent(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ await cur.fetchall()
+ await cur.close()
+ await cur.close()
+
+
+async def test_close_broken_conn(aconn):
+ cur = aconn.cursor("foo")
+ await aconn.close()
+ await cur.close()
+ assert cur.closed
+
+
+async def test_cursor_close_fetchone(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ for _ in range(5):
+ await cur.fetchone()
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchone()
+
+
+async def test_cursor_close_fetchmany(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchmany(2)) == 2
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchmany(2)
+
+
+async def test_cursor_close_fetchall(aconn):
+ cur = aconn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ await cur.execute(query)
+ assert len(await cur.fetchall()) == 10
+
+ await cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ await cur.fetchall()
+
+
+async def test_close_noop(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_on_error(aconn):
+ cur = aconn.cursor("foo")
+ await cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ await aconn.execute("wat")
+ assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR
+ await cur.close()
+
+
+async def test_pgresult(aconn):
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert cur.pgresult
+ await cur.close()
+ assert not cur.pgresult
+
+
+async def test_context(aconn, recwarn):
+ recwarn.clear()
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not await (
+ await aconn.execute("select * from pg_cursors where name = 'foo'")
+ ).fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+async def test_close_no_clobber(aconn):
+ with pytest.raises(e.DivisionByZero):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select 1 / %s", (0,))
+ await cur.fetchall()
+
+
+async def test_warn_close(aconn, recwarn):
+ recwarn.clear()
+ cur = aconn.cursor("foo")
+ await cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+async def test_execute_reuse(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert await cur.fetchone() == (1,)
+
+ await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert await cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+async def test_execute_error(aconn, stmt):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ await cur.execute(stmt)
+ await cur.close()
+
+
+async def test_executemany(aconn):
+ cur = aconn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ await cur.executemany("select %s", [(1,), (2,)])
+ await cur.close()
+
+
+async def test_fetchone(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchone() == (2,)
+ assert await cur.fetchone() is None
+
+
+async def test_fetchmany(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert await cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert await cur.fetchone() == (4,)
+ assert await cur.fetchmany(3) == [(5,)]
+ assert await cur.fetchmany(3) == []
+
+
+async def test_fetchall(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert await cur.fetchall() == []
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchall() == [(2,), (3,)]
+ assert await cur.fetchall() == []
+
+
+async def test_nextset(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+async def test_no_result(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert (await cur.fetchall()) == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_standard_row_factory(aconn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ async with aconn.cursor("foo", row_factory=row_factory) as cur:
+ await cur.execute("select generate_series(1, 5) as bar")
+ assert getter(await cur.fetchone()) == 1
+ assert list(map(getter, await cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, await cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_row_factory(aconn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ await cur.execute("select generate_series(1, 3) as x")
+ recs = await cur.fetchall()
+ await cur.scroll(0, "absolute")
+ while True:
+ rec = await cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ await cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert await cur.fetchone() == {"x": 1}
+ await cur.close()
+
+
+async def test_rownumber(aconn):
+ cur = aconn.cursor("foo")
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ await cur.fetchall()
+ assert cur.rownumber == 42
+ await cur.close()
+
+
+async def test_iter(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(1,), (2,), (3,)]
+
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(2,), (3,)]
+
+
+async def test_iter_rownumber(aconn):
+ async with aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ async for row in cur:
+ assert cur.rownumber == row[0]
+
+
+async def test_itersize(aconn, acommands):
+ async with aconn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ acommands.popall() # flush begin and other noise
+
+ async for rec in cur:
+ pass
+ cmds = acommands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+async def test_cant_scroll_by_default(aconn):
+ cur = aconn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ await cur.scroll(0)
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scroll(aconn):
+ cur = aconn.cursor("tmp", scrollable=True)
+ await cur.execute("select generate_series(0,9)")
+ await cur.scroll(2)
+ assert await cur.fetchone() == (2,)
+ await cur.scroll(2)
+ assert await cur.fetchone() == (5,)
+ await cur.scroll(2, mode="relative")
+ assert await cur.fetchone() == (8,)
+ await cur.scroll(9, mode="absolute")
+ assert await cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ await cur.scroll(9, mode="wat")
+ await cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+async def test_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ for i in range(4, -1, -1):
+ await curs.scroll(-1)
+ assert i == (await curs.fetchone())[0]
+ await curs.scroll(-1)
+ await curs.close()
+
+
+async def test_non_scrollable(aconn):
+ curs = aconn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ await curs.execute("select generate_series(0, 5)")
+ await curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ await curs.scroll(-1)
+ await curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+async def test_no_hold(aconn, kwargs):
+ async with aconn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ await curs.execute("select generate_series(0, 2)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ await curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+async def test_hold(aconn):
+ async with aconn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ await curs.execute("select generate_series(0, 5)")
+ assert await curs.fetchone() == (0,)
+ await aconn.commit()
+ assert await curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+async def test_steal_cursor(aconn, row_factory):
+ cur1 = aconn.cursor()
+ await cur1.execute(
+ "declare test cursor without hold for select generate_series(1, 6) as s"
+ )
+
+ cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = await cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert await cur2.fetchall() == [(5,), (6,)]
+ await cur2.close()
+
+
+async def test_stolen_cursor_close(aconn):
+ cur1 = aconn.cursor()
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()
+
+ await cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = aconn.cursor("test")
+ await cur2.close()