diff options
Diffstat (limited to 'tests/test_typing.py')
-rw-r--r-- | tests/test_typing.py | 449 |
1 files changed, 449 insertions, 0 deletions
diff --git a/tests/test_typing.py b/tests/test_typing.py new file mode 100644 index 0000000..fff9cec --- /dev/null +++ b/tests/test_typing.py @@ -0,0 +1,449 @@ +import os + +import pytest + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +@pytest.mark.parametrize( + "filename", + ["adapters_example.py", "typing_example.py"], +) +def test_typing_example(mypy, filename): + cp = mypy.run_on_file(os.path.join(HERE, filename)) + errors = cp.stdout.decode("utf8", "replace").splitlines() + assert not errors + assert cp.returncode == 0 + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.tuple_row)", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.namedtuple_row)", + "psycopg.Connection[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "psycopg.Connection[Thing]", + ), + ( + "psycopg.Connection.connect()", + "psycopg.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg.Connection.connect(row_factory=rows.dict_row)", + "psycopg.Connection[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "conn.cursor()", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor()", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect(row_factory=thing_row)", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), + ( + "psycopg.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg.Cursor[Thing]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor()", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg.connect()", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn)", + "psycopg.Cursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.Cursor(conn, row_factory=rows.namedtuple_row)", + "psycopg.Cursor[NamedTuple]", + ), + # Async cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncCursor(conn)", + "psycopg.AsyncCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncCursor(conn, row_factory=thing_row)", + "psycopg.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg.connect()", + "psycopg.ServerCursor(conn, 'foo')", + "psycopg.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, name='foo')", + "psycopg.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg.connect(row_factory=rows.dict_row)", + "psycopg.ServerCursor(conn, 'foo', row_factory=rows.namedtuple_row)", + "psycopg.ServerCursor[NamedTuple]", + ), + # Async server-side cursors + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor(conn, name='foo')", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "await psycopg.AsyncConnection.connect()", + "psycopg.AsyncServerCursor(conn, name='foo', row_factory=rows.dict_row)", + "psycopg.AsyncServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type_init(conn, curs, type, mypy): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Optional[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Optional[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Optional[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchone_type(conn_class, server_side, curs, type, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.fetchone() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Tuple[Any, ...]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Dict[str, Any]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Thing", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_iter_type(conn_class, server_side, curs, type, mypy): + if "Async" in conn_class: + async_ = "async " + await_ = "await " + else: + async_ = await_ = "" + + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_}psycopg.{conn_class}.connect() +curs = {curs} +{async_}for obj in curs: + pass +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("method", ["fetchmany", "fetchall"]) +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "List[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "List[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "List[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.{method}() +""" + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_cur_subclass_execute(mypy, conn_class, server_side): + async_ = "async " if "Async" in conn_class else "" + await_ = "await" if "Async" in conn_class else "" + cur_base_class = "".join( + [ + "Async" if "Async" in conn_class else "", + "Server" if server_side else "", + "Cursor", + ] + ) + cur_name = "'foo'" if server_side else "" + + src = f"""\ +from typing import Any, cast +import psycopg +from psycopg.rows import Row, TupleRow + +class MyCursor(psycopg.{cur_base_class}[Row]): + pass + +{async_}def test() -> None: + conn = {await_} psycopg.{conn_class}.connect() + + cur: MyCursor[TupleRow] + reveal_type(cur) + + cur = cast(MyCursor[TupleRow], conn.cursor({cur_name})) + {async_}with cur as cur2: + reveal_type(cur2) + cur3 = {await_} cur2.execute("") + reveal_type(cur3) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 3 + types = [mypy.get_revealed(line) for line in out] + assert types[0] == types[1] + assert types[0] == types[2] + + +def _test_reveal(stmts, type, mypy): + ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]" + stmts = "\n".join(f" {line}" for line in stmts.splitlines()) + + src = f"""\ +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence +from typing import Tuple, Union +import psycopg +from psycopg import rows + +class Thing: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + +def thing_row( + cur: Union[psycopg.Cursor[Any], psycopg.AsyncCursor[Any]], +) -> Callable[[Sequence[Any]], Thing]: + assert cur.description + names = [d.name for d in cur.description] + + def make_row(t: Sequence[Any]) -> Thing: + return Thing(**dict(zip(names, t))) + + return make_row + +async def tmp() -> None: +{stmts} + reveal_type(obj) + +ref: {type} = None {ignore} +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want + + +@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308") +@pytest.mark.parametrize( + "conn, type", + [ + ( + "MyConnection.connect()", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.tuple_row)", + "MyConnection[Tuple[Any, ...]]", + ), + ( + "MyConnection.connect(row_factory=rows.dict_row)", + "MyConnection[Dict[str, Any]]", + ), + ], +) +def test_generic_connect(conn, type, mypy): + src = f""" +from typing import Any, Dict, Tuple +import psycopg +from psycopg import rows + +class MyConnection(psycopg.Connection[rows.Row]): + pass + +obj = {conn} +reveal_type(obj) + +ref: {type} = None # type: ignore[assignment] +reveal_type(ref) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [mypy.get_revealed(line) for line in out] + assert got == want |