diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /tests/test_server_cursor_async.py | |
parent | Initial commit. (diff) | |
download | psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_server_cursor_async.py')
-rw-r--r-- | tests/test_server_cursor_async.py | 543 |
1 files changed, 543 insertions, 0 deletions
diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py new file mode 100644 index 0000000..21b4345 --- /dev/null +++ b/tests/test_server_cursor_async.py @@ -0,0 +1,543 @@ +import pytest + +import psycopg +from psycopg import rows, errors as e +from psycopg.pq import Format + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.crdb_skip("server-side cursor"), +] + + +async def test_init_row_factory(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.name == "foo" + assert cur.connection is aconn + assert cur.row_factory is aconn.row_factory + + aconn.row_factory = rows.dict_row + + async with psycopg.AsyncServerCursor(aconn, "bar") as cur: + assert cur.name == "bar" + assert cur.row_factory is rows.dict_row # type: ignore + + async with psycopg.AsyncServerCursor( + aconn, "baz", row_factory=rows.namedtuple_row + ) as cur: + assert cur.name == "baz" + assert cur.row_factory is rows.namedtuple_row # type: ignore + + +async def test_init_params(aconn): + async with psycopg.AsyncServerCursor(aconn, "foo") as cur: + assert cur.scrollable is None + assert cur.withhold is False + + async with psycopg.AsyncServerCursor( + aconn, "bar", withhold=True, scrollable=False + ) as cur: + assert cur.scrollable is False + assert cur.withhold is True + + +@pytest.mark.crdb_skip("cursor invalid name") +async def test_funny_name(aconn): + cur = aconn.cursor("1-2-3") + await cur.execute("select generate_series(1, 3) as bar") + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert cur.name == "1-2-3" + await cur.close() + + +async def test_repr(aconn): + cur = aconn.cursor("my-name") + assert "psycopg.AsyncServerCursor" in str(cur) + assert "my-name" in repr(cur) + await cur.close() + + +async def test_connection(aconn): + cur = aconn.cursor("foo") + assert cur.connection is aconn + await cur.close() + + +async def test_description(aconn): + cur = aconn.cursor("foo") + assert cur.name == "foo" + await cur.execute("select generate_series(1, 10)::int4 as bar") + assert len(cur.description) == 1 + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["int4"].oid + assert cur.pgresult.ntuples == 0 + await cur.close() + + +async def test_format(aconn): + cur = aconn.cursor("foo") + assert cur.format == Format.TEXT + await cur.close() + + cur = aconn.cursor("foo", binary=True) + assert cur.format == Format.BINARY + await cur.close() + + +async def test_query_params(aconn): + async with aconn.cursor("foo") as cur: + assert cur._query is None + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert cur._query is not None + assert b"declare" in cur._query.query.lower() + assert b"(1, $1)" in cur._query.query.lower() + assert cur._query.params == [bytes([0, 3])] # 3 as binary int2 + + +async def test_binary_cursor_execute(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + await cur.close() + + +async def test_execute_binary(aconn): + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 2)::int4", binary=True) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02" + + await cur.execute("select generate_series(1, 1)") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + await cur.close() + + +async def test_binary_cursor_text_override(aconn): + cur = aconn.cursor("foo", binary=True) + await cur.execute("select generate_series(1, 2)", binary=False) + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"1" + assert (await cur.fetchone()) == (2,) + assert cur.pgresult.fformat(0) == 0 + assert cur.pgresult.get_value(0, 0) == b"2" + + await cur.execute("select generate_series(1, 2)::int4") + assert (await cur.fetchone()) == (1,) + assert cur.pgresult.fformat(0) == 1 + assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01" + await cur.close() + + +async def test_close(aconn, recwarn): + if aconn.info.transaction_status == aconn.TransactionStatus.INTRANS: + # connection dirty from previous failure + await aconn.execute("close foo") + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + await cur.close() + assert cur.closed + + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_idempotent(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + await cur.fetchall() + await cur.close() + await cur.close() + + +async def test_close_broken_conn(aconn): + cur = aconn.cursor("foo") + await aconn.close() + await cur.close() + assert cur.closed + + +async def test_cursor_close_fetchone(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + for _ in range(5): + await cur.fetchone() + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchone() + + +async def test_cursor_close_fetchmany(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchmany(2)) == 2 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchmany(2) + + +async def test_cursor_close_fetchall(aconn): + cur = aconn.cursor("foo") + assert not cur.closed + + query = "select * from generate_series(1, 10)" + await cur.execute(query) + assert len(await cur.fetchall()) == 10 + + await cur.close() + assert cur.closed + + with pytest.raises(e.InterfaceError): + await cur.fetchall() + + +async def test_close_noop(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.close() + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_on_error(aconn): + cur = aconn.cursor("foo") + await cur.execute("select 1") + with pytest.raises(e.ProgrammingError): + await aconn.execute("wat") + assert aconn.info.transaction_status == aconn.TransactionStatus.INERROR + await cur.close() + + +async def test_pgresult(aconn): + cur = aconn.cursor() + await cur.execute("select 1") + assert cur.pgresult + await cur.close() + assert not cur.pgresult + + +async def test_context(aconn, recwarn): + recwarn.clear() + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, 10) as bar") + + assert cur.closed + assert not await ( + await aconn.execute("select * from pg_cursors where name = 'foo'") + ).fetchone() + del cur + assert not recwarn, [str(w.message) for w in recwarn.list] + + +async def test_close_no_clobber(aconn): + with pytest.raises(e.DivisionByZero): + async with aconn.cursor("foo") as cur: + await cur.execute("select 1 / %s", (0,)) + await cur.fetchall() + + +async def test_warn_close(aconn, recwarn): + recwarn.clear() + cur = aconn.cursor("foo") + await cur.execute("select generate_series(1, 10) as bar") + del cur + assert ".close()" in str(recwarn.pop(ResourceWarning).message) + + +async def test_execute_reuse(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as foo", (3,)) + assert await cur.fetchone() == (1,) + + await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world")) + assert await cur.fetchone() == ("hello", "world") + assert cur.description[0].name == "bar" + assert cur.description[0].type_code == cur.adapters.types["text"].oid + assert cur.description[1].name == "baz" + + +@pytest.mark.parametrize( + "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"] +) +async def test_execute_error(aconn, stmt): + cur = aconn.cursor("foo") + with pytest.raises(e.ProgrammingError): + await cur.execute(stmt) + await cur.close() + + +async def test_executemany(aconn): + cur = aconn.cursor("foo") + with pytest.raises(e.NotSupportedError): + await cur.executemany("select %s", [(1,), (2,)]) + await cur.close() + + +async def test_fetchone(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (2,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchone() == (2,) + assert await cur.fetchone() is None + + +async def test_fetchmany(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (5,)) + assert await cur.fetchmany(3) == [(1,), (2,), (3,)] + assert await cur.fetchone() == (4,) + assert await cur.fetchmany(3) == [(5,)] + assert await cur.fetchmany(3) == [] + + +async def test_fetchall(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchall() == [(1,), (2,), (3,)] + assert await cur.fetchall() == [] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + assert await cur.fetchall() == [(2,), (3,)] + assert await cur.fetchall() == [] + + +async def test_nextset(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert not cur.nextset() + + +async def test_no_result(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar where false", (3,)) + assert len(cur.description) == 1 + assert (await cur.fetchall()) == [] + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_standard_row_factory(aconn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + async with aconn.cursor("foo", row_factory=row_factory) as cur: + await cur.execute("select generate_series(1, 5) as bar") + assert getter(await cur.fetchone()) == 1 + assert list(map(getter, await cur.fetchmany(2))) == [2, 3] + assert list(map(getter, await cur.fetchall())) == [4, 5] + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_row_factory(aconn): + n = 0 + + def my_row_factory(cur): + nonlocal n + n += 1 + return lambda values: [n] + [-v for v in values] + + cur = aconn.cursor("foo", row_factory=my_row_factory, scrollable=True) + await cur.execute("select generate_series(1, 3) as x") + recs = await cur.fetchall() + await cur.scroll(0, "absolute") + while True: + rec = await cur.fetchone() + if not rec: + break + recs.append(rec) + assert recs == [[1, -1], [1, -2], [1, -3]] * 2 + + await cur.scroll(0, "absolute") + cur.row_factory = rows.dict_row + assert await cur.fetchone() == {"x": 1} + await cur.close() + + +async def test_rownumber(aconn): + cur = aconn.cursor("foo") + assert cur.rownumber is None + + await cur.execute("select 1 from generate_series(1, 42)") + assert cur.rownumber == 0 + + await cur.fetchone() + assert cur.rownumber == 1 + await cur.fetchone() + assert cur.rownumber == 2 + await cur.fetchmany(10) + assert cur.rownumber == 12 + await cur.fetchall() + assert cur.rownumber == 42 + await cur.close() + + +async def test_iter(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(1,), (2,), (3,)] + + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + assert await cur.fetchone() == (1,) + recs = [] + async for rec in cur: + recs.append(rec) + assert recs == [(2,), (3,)] + + +async def test_iter_rownumber(aconn): + async with aconn.cursor("foo") as cur: + await cur.execute("select generate_series(1, %s) as bar", (3,)) + async for row in cur: + assert cur.rownumber == row[0] + + +async def test_itersize(aconn, acommands): + async with aconn.cursor("foo") as cur: + assert cur.itersize == 100 + cur.itersize = 2 + await cur.execute("select generate_series(1, %s) as bar", (3,)) + acommands.popall() # flush begin and other noise + + async for rec in cur: + pass + cmds = acommands.popall() + assert len(cmds) == 2 + for cmd in cmds: + assert "fetch forward 2" in cmd.lower() + + +async def test_cant_scroll_by_default(aconn): + cur = aconn.cursor("tmp") + assert cur.scrollable is None + with pytest.raises(e.ProgrammingError): + await cur.scroll(0) + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scroll(aconn): + cur = aconn.cursor("tmp", scrollable=True) + await cur.execute("select generate_series(0,9)") + await cur.scroll(2) + assert await cur.fetchone() == (2,) + await cur.scroll(2) + assert await cur.fetchone() == (5,) + await cur.scroll(2, mode="relative") + assert await cur.fetchone() == (8,) + await cur.scroll(9, mode="absolute") + assert await cur.fetchone() == (9,) + + with pytest.raises(ValueError): + await cur.scroll(9, mode="wat") + await cur.close() + + +@pytest.mark.crdb_skip("scroll cursor") +async def test_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=True) + assert curs.scrollable is True + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + for i in range(4, -1, -1): + await curs.scroll(-1) + assert i == (await curs.fetchone())[0] + await curs.scroll(-1) + await curs.close() + + +async def test_non_scrollable(aconn): + curs = aconn.cursor("foo", scrollable=False) + assert curs.scrollable is False + await curs.execute("select generate_series(0, 5)") + await curs.scroll(5) + with pytest.raises(e.OperationalError): + await curs.scroll(-1) + await curs.close() + + +@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}]) +async def test_no_hold(aconn, kwargs): + async with aconn.cursor("foo", **kwargs) as curs: + assert curs.withhold is False + await curs.execute("select generate_series(0, 2)") + assert await curs.fetchone() == (0,) + await aconn.commit() + with pytest.raises(e.InvalidCursorName): + await curs.fetchone() + + +@pytest.mark.crdb_skip("cursor with hold") +async def test_hold(aconn): + async with aconn.cursor("foo", withhold=True) as curs: + assert curs.withhold is True + await curs.execute("select generate_series(0, 5)") + assert await curs.fetchone() == (0,) + await aconn.commit() + assert await curs.fetchone() == (1,) + + +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +async def test_steal_cursor(aconn, row_factory): + cur1 = aconn.cursor() + await cur1.execute( + "declare test cursor without hold for select generate_series(1, 6) as s" + ) + + cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory)) + # can call fetch without execute + rec = await cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 + assert await cur2.fetchmany(3) == [(2,), (3,), (4,)] + assert await cur2.fetchall() == [(5,), (6,)] + await cur2.close() + + +async def test_stolen_cursor_close(aconn): + cur1 = aconn.cursor() + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() + + await cur1.execute("declare test cursor for select generate_series(1, 6)") + cur2 = aconn.cursor("test") + await cur2.close() |