summaryrefslogtreecommitdiffstats
path: root/tests/test_server_cursor.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_server_cursor.py525
1 files changed, 525 insertions, 0 deletions
diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py
new file mode 100644
index 0000000..f7b6c8e
--- /dev/null
+++ b/tests/test_server_cursor.py
@@ -0,0 +1,525 @@
+import pytest
+
+import psycopg
+from psycopg import rows, errors as e
+from psycopg.pq import Format
+
+pytestmark = pytest.mark.crdb_skip("server-side cursor")
+
+
+def test_init_row_factory(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.name == "foo"
+ assert cur.connection is conn
+ assert cur.row_factory is conn.row_factory
+
+ conn.row_factory = rows.dict_row
+
+ with psycopg.ServerCursor(conn, "bar") as cur:
+ assert cur.name == "bar"
+ assert cur.row_factory is rows.dict_row # type: ignore
+
+ with psycopg.ServerCursor(conn, "baz", row_factory=rows.namedtuple_row) as cur:
+ assert cur.name == "baz"
+ assert cur.row_factory is rows.namedtuple_row # type: ignore
+
+
+def test_init_params(conn):
+ with psycopg.ServerCursor(conn, "foo") as cur:
+ assert cur.scrollable is None
+ assert cur.withhold is False
+
+ with psycopg.ServerCursor(conn, "bar", withhold=True, scrollable=False) as cur:
+ assert cur.scrollable is False
+ assert cur.withhold is True
+
+
+@pytest.mark.crdb_skip("cursor invalid name")
+def test_funny_name(conn):
+ cur = conn.cursor("1-2-3")
+ cur.execute("select generate_series(1, 3) as bar")
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.name == "1-2-3"
+ cur.close()
+
+
+def test_repr(conn):
+ cur = conn.cursor("my-name")
+ assert "psycopg.ServerCursor" in str(cur)
+ assert "my-name" in repr(cur)
+ cur.close()
+
+
+def test_connection(conn):
+ cur = conn.cursor("foo")
+ assert cur.connection is conn
+ cur.close()
+
+
+def test_description(conn):
+ cur = conn.cursor("foo")
+ assert cur.name == "foo"
+ cur.execute("select generate_series(1, 10)::int4 as bar")
+ assert len(cur.description) == 1
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["int4"].oid
+ assert cur.pgresult.ntuples == 0
+ cur.close()
+
+
+def test_format(conn):
+ cur = conn.cursor("foo")
+ assert cur.format == Format.TEXT
+ cur.close()
+
+ cur = conn.cursor("foo", binary=True)
+ assert cur.format == Format.BINARY
+ cur.close()
+
+
+def test_query_params(conn):
+ with conn.cursor("foo") as cur:
+ assert cur._query is None
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur._query
+ assert b"declare" in cur._query.query.lower()
+ assert b"(1, $1)" in cur._query.query.lower()
+ assert cur._query.params == [bytes([0, 3])] # 3 as binary int2
+
+
+def test_binary_cursor_execute(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+ cur.close()
+
+
+def test_execute_binary(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 2)::int4", binary=True)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x02"
+
+ cur.execute("select generate_series(1, 1)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ cur.close()
+
+
+def test_binary_cursor_text_override(conn):
+ cur = conn.cursor("foo", binary=True)
+ cur.execute("select generate_series(1, 2)", binary=False)
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"1"
+ assert cur.fetchone() == (2,)
+ assert cur.pgresult.fformat(0) == 0
+ assert cur.pgresult.get_value(0, 0) == b"2"
+
+ cur.execute("select generate_series(1, 2)::int4")
+ assert cur.fetchone() == (1,)
+ assert cur.pgresult.fformat(0) == 1
+ assert cur.pgresult.get_value(0, 0) == b"\x00\x00\x00\x01"
+ cur.close()
+
+
+def test_close(conn, recwarn):
+ if conn.info.transaction_status == conn.TransactionStatus.INTRANS:
+ # connection dirty from previous failure
+ conn.execute("close foo")
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ cur.close()
+ assert cur.closed
+
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_idempotent(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ cur.fetchall()
+ cur.close()
+ cur.close()
+
+
+def test_close_broken_conn(conn):
+ cur = conn.cursor("foo")
+ conn.close()
+ cur.close()
+ assert cur.closed
+
+
+def test_cursor_close_fetchone(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ for _ in range(5):
+ cur.fetchone()
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchone()
+
+
+def test_cursor_close_fetchmany(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchmany(2)) == 2
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchmany(2)
+
+
+def test_cursor_close_fetchall(conn):
+ cur = conn.cursor("foo")
+ assert not cur.closed
+
+ query = "select * from generate_series(1, 10)"
+ cur.execute(query)
+ assert len(cur.fetchall()) == 10
+
+ cur.close()
+ assert cur.closed
+
+ with pytest.raises(e.InterfaceError):
+ cur.fetchall()
+
+
+def test_close_noop(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.close()
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_on_error(conn):
+ cur = conn.cursor("foo")
+ cur.execute("select 1")
+ with pytest.raises(e.ProgrammingError):
+ conn.execute("wat")
+ assert conn.info.transaction_status == conn.TransactionStatus.INERROR
+ cur.close()
+
+
+def test_pgresult(conn):
+ cur = conn.cursor()
+ cur.execute("select 1")
+ assert cur.pgresult
+ cur.close()
+ assert not cur.pgresult
+
+
+def test_context(conn, recwarn):
+ recwarn.clear()
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, 10) as bar")
+
+ assert cur.closed
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
+ del cur
+ assert not recwarn, [str(w.message) for w in recwarn.list]
+
+
+def test_close_no_clobber(conn):
+ with pytest.raises(e.DivisionByZero):
+ with conn.cursor("foo") as cur:
+ cur.execute("select 1 / %s", (0,))
+ cur.fetchall()
+
+
+def test_warn_close(conn, recwarn):
+ recwarn.clear()
+ cur = conn.cursor("foo")
+ cur.execute("select generate_series(1, 10) as bar")
+ del cur
+ assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+def test_execute_reuse(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as foo", (3,))
+ assert cur.fetchone() == (1,)
+
+ cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
+ assert cur.fetchone() == ("hello", "world")
+ assert cur.description[0].name == "bar"
+ assert cur.description[0].type_code == cur.adapters.types["text"].oid
+ assert cur.description[1].name == "baz"
+
+
+@pytest.mark.parametrize(
+ "stmt", ["", "wat", "create table ssc ()", "select 1; select 2"]
+)
+def test_execute_error(conn, stmt):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.ProgrammingError):
+ cur.execute(stmt)
+ cur.close()
+
+
+def test_executemany(conn):
+ cur = conn.cursor("foo")
+ with pytest.raises(e.NotSupportedError):
+ cur.executemany("select %s", [(1,), (2,)])
+ cur.close()
+
+
+def test_fetchone(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchone() == (2,)
+ assert cur.fetchone() is None
+
+
+def test_fetchmany(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert cur.fetchone() == (4,)
+ assert cur.fetchmany(3) == [(5,)]
+ assert cur.fetchmany(3) == []
+
+
+def test_fetchall(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.fetchall() == []
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchall() == [(2,), (3,)]
+ assert cur.fetchall() == []
+
+
+def test_nextset(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert not cur.nextset()
+
+
+def test_no_result(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar where false", (3,))
+ assert len(cur.description) == 1
+ assert cur.fetchall() == []
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_standard_row_factory(conn, row_factory):
+ if row_factory == "tuple_row":
+ getter = lambda r: r[0] # noqa: E731
+ elif row_factory == "dict_row":
+ getter = lambda r: r["bar"] # noqa: E731
+ elif row_factory == "namedtuple_row":
+ getter = lambda r: r.bar # noqa: E731
+ else:
+ assert False, row_factory
+
+ row_factory = getattr(rows, row_factory)
+ with conn.cursor("foo", row_factory=row_factory) as cur:
+ cur.execute("select generate_series(1, 5) as bar")
+ assert getter(cur.fetchone()) == 1
+ assert list(map(getter, cur.fetchmany(2))) == [2, 3]
+ assert list(map(getter, cur.fetchall())) == [4, 5]
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_row_factory(conn):
+ n = 0
+
+ def my_row_factory(cur):
+ nonlocal n
+ n += 1
+ return lambda values: [n] + [-v for v in values]
+
+ cur = conn.cursor("foo", row_factory=my_row_factory, scrollable=True)
+ cur.execute("select generate_series(1, 3) as x")
+ recs = cur.fetchall()
+ cur.scroll(0, "absolute")
+ while True:
+ rec = cur.fetchone()
+ if not rec:
+ break
+ recs.append(rec)
+ assert recs == [[1, -1], [1, -2], [1, -3]] * 2
+
+ cur.scroll(0, "absolute")
+ cur.row_factory = rows.dict_row
+ assert cur.fetchone() == {"x": 1}
+ cur.close()
+
+
+def test_rownumber(conn):
+ cur = conn.cursor("foo")
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ cur.fetchall()
+ assert cur.rownumber == 42
+ cur.close()
+
+
+def test_iter(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = list(cur)
+ assert recs == [(1,), (2,), (3,)]
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ recs = list(cur)
+ assert recs == [(2,), (3,)]
+
+
+def test_iter_rownumber(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ for row in cur:
+ assert cur.rownumber == row[0]
+
+
+def test_itersize(conn, commands):
+ with conn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ commands.popall() # flush begin and other noise
+
+ list(cur)
+ cmds = commands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
+
+
+def test_cant_scroll_by_default(conn):
+ cur = conn.cursor("tmp")
+ assert cur.scrollable is None
+ with pytest.raises(e.ProgrammingError):
+ cur.scroll(0)
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scroll(conn):
+ cur = conn.cursor("tmp", scrollable=True)
+ cur.execute("select generate_series(0,9)")
+ cur.scroll(2)
+ assert cur.fetchone() == (2,)
+ cur.scroll(2)
+ assert cur.fetchone() == (5,)
+ cur.scroll(2, mode="relative")
+ assert cur.fetchone() == (8,)
+ cur.scroll(9, mode="absolute")
+ assert cur.fetchone() == (9,)
+
+ with pytest.raises(ValueError):
+ cur.scroll(9, mode="wat")
+ cur.close()
+
+
+@pytest.mark.crdb_skip("scroll cursor")
+def test_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=True)
+ assert curs.scrollable is True
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ for i in range(4, -1, -1):
+ curs.scroll(-1)
+ assert i == curs.fetchone()[0]
+ curs.scroll(-1)
+ curs.close()
+
+
+def test_non_scrollable(conn):
+ curs = conn.cursor("foo", scrollable=False)
+ assert curs.scrollable is False
+ curs.execute("select generate_series(0, 5)")
+ curs.scroll(5)
+ with pytest.raises(e.OperationalError):
+ curs.scroll(-1)
+ curs.close()
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"withhold": False}])
+def test_no_hold(conn, kwargs):
+ with conn.cursor("foo", **kwargs) as curs:
+ assert curs.withhold is False
+ curs.execute("select generate_series(0, 2)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ with pytest.raises(e.InvalidCursorName):
+ curs.fetchone()
+
+
+@pytest.mark.crdb_skip("cursor with hold")
+def test_hold(conn):
+ with conn.cursor("foo", withhold=True) as curs:
+ assert curs.withhold is True
+ curs.execute("select generate_series(0, 5)")
+ assert curs.fetchone() == (0,)
+ conn.commit()
+ assert curs.fetchone() == (1,)
+
+
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+def test_steal_cursor(conn, row_factory):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6) as s")
+
+ cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory))
+ # can call fetch without execute
+ rec = cur2.fetchone()
+ assert rec == (1,)
+ if row_factory == "namedtuple_row":
+ assert rec.s == 1
+ assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
+ assert cur2.fetchall() == [(5,), (6,)]
+ cur2.close()
+
+
+def test_stolen_cursor_close(conn):
+ cur1 = conn.cursor()
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()
+
+ cur1.execute("declare test cursor for select generate_series(1, 6)")
+ cur2 = conn.cursor("test")
+ cur2.close()