summaryrefslogtreecommitdiffstats
path: root/tests/test_adapt.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_adapt.py')
-rw-r--r--tests/test_adapt.py530
1 files changed, 530 insertions, 0 deletions
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
new file mode 100644
index 0000000..2190a84
--- /dev/null
+++ b/tests/test_adapt.py
@@ -0,0 +1,530 @@
+import datetime as dt
+from types import ModuleType
+from typing import Any, List
+
+import pytest
+
+import psycopg
+from psycopg import pq, sql, postgres
+from psycopg import errors as e
+from psycopg.adapt import Transformer, PyFormat, Dumper, Loader
+from psycopg._cmodule import _psycopg
+from psycopg.postgres import types as builtins, TEXT_OID
+from psycopg.types.array import ListDumper, ListBinaryDumper
+
+
+@pytest.mark.parametrize(
+ "data, format, result, type",
+ [
+ (1, PyFormat.TEXT, b"1", "int2"),
+ ("hello", PyFormat.TEXT, b"hello", "text"),
+ ("hello", PyFormat.BINARY, b"hello", "text"),
+ ],
+)
+def test_dump(data, format, result, type):
+ t = Transformer()
+ dumper = t.get_dumper(data, format)
+ assert dumper.dump(data) == result
+ if type == "text" and format != PyFormat.BINARY:
+ assert dumper.oid == 0
+ else:
+ assert dumper.oid == builtins[type].oid
+
+
+@pytest.mark.parametrize(
+ "data, result",
+ [
+ (1, b"1"),
+ ("hello", b"'hello'"),
+ ("he'llo", b"'he''llo'"),
+ (True, b"true"),
+ (None, b"NULL"),
+ ],
+)
+def test_quote(data, result):
+ t = Transformer()
+ dumper = t.get_dumper(data, PyFormat.TEXT)
+ assert dumper.quote(data) == result
+
+
+def test_register_dumper_by_class(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(MyStr, dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+def test_register_dumper_by_class_name(conn):
+ dumper = make_dumper("x")
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
+ conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper)
+ assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_dump_global_ctx(conn_cls, dsn, global_adapters, pgconn):
+ psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+ psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogb",)
+ cur = conn.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellogt",)
+
+
+def test_dump_connection_ctx(conn):
+ conn.adapters.register_dumper(MyStr, make_bin_dumper("b"))
+ conn.adapters.register_dumper(MyStr, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_cursor_ctx(conn):
+ conn.adapters.register_dumper(str, make_bin_dumper("b"))
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, make_bin_dumper("bc"))
+ cur.adapters.register_dumper(str, make_dumper("tc"))
+
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellotc",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ cur.execute("select %s", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %t", [MyStr("hello")])
+ assert cur.fetchone() == ("hellot",)
+ cur.execute("select %b", [MyStr("hello")])
+ assert cur.fetchone() == ("hellob",)
+
+
+def test_dump_subclass(conn):
+ class MyString(str):
+ pass
+
+ cur = conn.cursor()
+ cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")])
+ assert cur.fetchone() == ("hello", "world")
+
+
+def test_subclass_dumper(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import StrDumper
+
+ class MyStrDumper(StrDumper):
+ def dump(self, obj):
+ return (obj * 2).encode()
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ assert conn.execute("select %t", ["hello"]).fetchone()[0] == "hellohello"
+
+
+def test_dumper_protocol(conn):
+
+ # This class doesn't inherit from adapt.Dumper but passes a mypy check
+ from .adapters_example import MyStrDumper
+
+ conn.adapters.register_dumper(str, MyStrDumper)
+ cur = conn.execute("select %s", ["hello"])
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select %s", [["hi", "ha"]])
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+ assert sql.Literal("hello").as_string(conn) == "'qelloqello'"
+
+
+def test_loader_protocol(conn):
+
+ # This class doesn't inherit from adapt.Loader but passes a mypy check
+ from .adapters_example import MyTextLoader
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ cur = conn.execute("select 'hello'::text")
+ assert cur.fetchone()[0] == "hellohello"
+ cur = conn.execute("select '{hi,ha}'::text[]")
+ assert cur.fetchone()[0] == ["hihi", "haha"]
+
+
+def test_subclass_loader(conn):
+ # This might be a C fast object: make sure that the Python code is called
+ from psycopg.types.string import TextLoader
+
+ class MyTextLoader(TextLoader):
+ def load(self, data):
+ return (bytes(data) * 2).decode()
+
+ conn.adapters.register_loader("text", MyTextLoader)
+ assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello"
+
+
+@pytest.mark.parametrize(
+ "data, format, type, result",
+ [
+ (b"1", pq.Format.TEXT, "int4", 1),
+ (b"hello", pq.Format.TEXT, "text", "hello"),
+ (b"hello", pq.Format.BINARY, "text", "hello"),
+ ],
+)
+def test_cast(data, format, type, result):
+ t = Transformer()
+ rv = t.get_loader(builtins[type].oid, format).load(data)
+ assert rv == result
+
+
+def test_register_loader_by_oid(conn):
+ assert TEXT_OID == 25
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader(TEXT_OID, loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+def test_register_loader_by_type_name(conn):
+ loader = make_loader("x")
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is not loader
+ conn.adapters.register_loader("text", loader)
+ assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
+
+
+@pytest.mark.crdb("skip", reason="global adapters don't affect crdb")
+def test_load_global_ctx(conn_cls, dsn, global_adapters):
+ psycopg.adapters.register_loader("text", make_loader("gt"))
+ psycopg.adapters.register_loader("text", make_bin_loader("gb"))
+ with conn_cls.connect(dsn) as conn:
+ cur = conn.cursor(binary=False).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogt",)
+ cur = conn.cursor(binary=True).execute("select 'hello'::text")
+ assert cur.fetchone() == ("hellogb",)
+
+
+def test_load_connection_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ r = conn.cursor(binary=False).execute("select 'hello'::text").fetchone()
+ assert r == ("hellot",)
+ r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone()
+ assert r == ("hellob",)
+
+
+def test_load_cursor_ctx(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+ conn.adapters.register_loader("text", make_bin_loader("b"))
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", make_loader("tc"))
+ cur.adapters.register_loader("text", make_bin_loader("bc"))
+
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellotc",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellobc",)
+
+ cur = conn.cursor()
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellot",)
+ cur.format = pq.Format.BINARY
+ assert cur.execute("select 'hello'::text").fetchone() == ("hellob",)
+
+
+def test_cow_dumpers(conn):
+ conn.adapters.register_dumper(str, make_dumper("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_dumper(str, make_dumper("c2"))
+
+ r = cur1.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 1", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+ conn.adapters.register_dumper(str, make_dumper("t1"))
+ r = cur1.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("hellot",)
+ r = cur2.execute("select %s::text -- 2", ["hello"]).fetchone()
+ assert r == ("helloc2",)
+
+
+def test_cow_loaders(conn):
+ conn.adapters.register_loader("text", make_loader("t"))
+
+ cur1 = conn.cursor()
+ cur2 = conn.cursor()
+ cur2.adapters.register_loader("text", make_loader("c2"))
+
+ assert cur1.execute("select 'hello'::text").fetchone() == ("hellot",)
+ assert cur2.execute("select 'hello'::text").fetchone() == ("helloc2",)
+
+ conn.adapters.register_loader("text", make_loader("t1"))
+ assert cur1.execute("select 'hello2'::text").fetchone() == ("hello2t",)
+ assert cur2.execute("select 'hello2'::text").fetchone() == ("hello2c2",)
+
+
+@pytest.mark.parametrize(
+ "sql, obj",
+ [("'{hello}'::text[]", ["helloc"]), ("row('hello'::text)", ("helloc",))],
+)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
+ cur = conn.cursor(binary=fmt_out == pq.Format.BINARY)
+ if fmt_out == pq.Format.TEXT:
+ cur.adapters.register_loader("text", make_loader("c"))
+ else:
+ cur.adapters.register_loader("text", make_bin_loader("c"))
+
+ cur.execute(f"select {sql}")
+ res = cur.fetchone()[0]
+ assert res == obj
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_list_dumper(conn, fmt_out):
+ t = Transformer(conn)
+ fmt_in = PyFormat.from_pq(fmt_out)
+ dint = t.get_dumper([0], fmt_in)
+ assert isinstance(dint, (ListDumper, ListBinaryDumper))
+ assert dint.oid == builtins["int2"].array_oid
+ assert dint.sub_dumper and dint.sub_dumper.oid == builtins["int2"].oid
+
+ dstr = t.get_dumper([""], fmt_in)
+ assert dstr is not dint
+
+ assert t.get_dumper([1], fmt_in) is dint
+ assert t.get_dumper([None, [1]], fmt_in) is dint
+
+ dempty = t.get_dumper([], fmt_in)
+ assert t.get_dumper([None, [None]], fmt_in) is dempty
+ assert dempty.oid == 0
+ assert dempty.dump([]) == b"{}"
+
+ L: List[List[Any]] = []
+ L.append(L)
+ with pytest.raises(psycopg.DataError):
+ assert t.get_dumper(L, fmt_in)
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+def test_str_list_dumper_text(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.TEXT)
+ assert isinstance(dstr, ListDumper)
+ assert dstr.oid == 0
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == 0
+
+
+def test_str_list_dumper_binary(conn):
+ t = Transformer(conn)
+ dstr = t.get_dumper([""], PyFormat.BINARY)
+ assert isinstance(dstr, ListBinaryDumper)
+ assert dstr.oid == builtins["text"].array_oid
+ assert dstr.sub_dumper and dstr.sub_dumper.oid == builtins["text"].oid
+
+
+def test_last_dumper_registered_ctx(conn):
+ cur = conn.cursor()
+
+ bd = make_bin_dumper("b")
+ cur.adapters.register_dumper(str, bd)
+ td = make_dumper("t")
+ cur.adapters.register_dumper(str, td)
+
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %t", ["hello"]).fetchone()[0] == "hellot"
+ assert cur.execute("select %b", ["hello"]).fetchone()[0] == "hellob"
+
+ cur.adapters.register_dumper(str, bd)
+ assert cur.execute("select %s", ["hello"]).fetchone()[0] == "hellob"
+
+
+@pytest.mark.parametrize("fmt_in", [PyFormat.TEXT, PyFormat.BINARY])
+def test_none_type_argument(conn, fmt_in):
+ cur = conn.cursor()
+ cur.execute("create table none_args (id serial primary key, num integer)")
+ cur.execute("insert into none_args (num) values (%s) returning id", (None,))
+ assert cur.fetchone()[0]
+
+
+@pytest.mark.crdb("skip", reason="test in crdb test suite")
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_return_untyped(conn, fmt_in):
+ # Analyze and check for changes using strings in untyped/typed contexts
+ cur = conn.cursor()
+ # Currently string are passed as unknown oid to libpq. This is because
+ # unknown is more easily cast by postgres to different types (see jsonb
+ # later).
+ cur.execute(f"select %{fmt_in.value}, %{fmt_in.value}", ["hello", 10])
+ assert cur.fetchone() == ("hello", 10)
+
+ cur.execute("create table testjson(data jsonb)")
+ if fmt_in != PyFormat.BINARY:
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+ assert cur.execute("select data from testjson").fetchone() == ({},)
+ else:
+ # Binary types cannot be passed as unknown oids.
+ with pytest.raises(e.DatatypeMismatch):
+ cur.execute(f"insert into testjson (data) values (%{fmt_in.value})", ["{}"])
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_no_cast_needed(conn, fmt_in):
+ # Verify that there is no need of cast in certain common scenario
+ cur = conn.execute(f"select '2021-01-01'::date + %{fmt_in.value}", [3])
+ assert cur.fetchone()[0] == dt.date(2021, 1, 4)
+
+ cur = conn.execute(f"select '[10, 20, 30]'::jsonb -> %{fmt_in.value}", [1])
+ assert cur.fetchone()[0] == 20
+
+
+@pytest.mark.slow
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
+def test_optimised_adapters():
+
+ # All the optimised adapters available
+ c_adapters = {}
+ for n in dir(_psycopg):
+ if n.startswith("_") or n in ("CDumper", "CLoader"):
+ continue
+ obj = getattr(_psycopg, n)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(
+ obj,
+ (_psycopg.CDumper, _psycopg.CLoader), # type: ignore[attr-defined]
+ ):
+ continue
+ c_adapters[n] = obj
+
+ # All the registered adapters
+ reg_adapters = set()
+ adapters = list(postgres.adapters._dumpers.values()) + postgres.adapters._loaders
+ assert len(adapters) == 5
+ for m in adapters:
+ reg_adapters |= set(m.values())
+
+ # Check that the registered adapters are the optimised one
+ i = 0
+ for cls in reg_adapters:
+ if cls.__name__ in c_adapters:
+ assert cls is c_adapters[cls.__name__]
+ i += 1
+
+ assert i >= 10
+
+ # Check that every optimised adapter is the optimised version of a Py one
+ for n in dir(psycopg.types):
+ mod = getattr(psycopg.types, n)
+ if not isinstance(mod, ModuleType):
+ continue
+ for n1 in dir(mod):
+ obj = getattr(mod, n1)
+ if not isinstance(obj, type):
+ continue
+ if not issubclass(obj, (Dumper, Loader)):
+ continue
+ c_adapters.pop(obj.__name__, None)
+
+ assert not c_adapters
+
+
+def test_dumper_init_error(conn):
+ class BadDumper(Dumper):
+ def __init__(self, cls, context):
+ super().__init__(cls, context)
+ 1 / 0
+
+ def dump(self, obj):
+ return obj.encode()
+
+ cur = conn.cursor()
+ cur.adapters.register_dumper(str, BadDumper)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select %s::text", ["hi"])
+
+
+def test_loader_init_error(conn):
+ class BadLoader(Loader):
+ def __init__(self, oid, context):
+ super().__init__(oid, context)
+ 1 / 0
+
+ def load(self, data):
+ return data.decode()
+
+ cur = conn.cursor()
+ cur.adapters.register_loader("text", BadLoader)
+ with pytest.raises(ZeroDivisionError):
+ cur.execute("select 'hi'::text")
+ assert cur.fetchone() == ("hi",)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("fmt", PyFormat)
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_random(conn, faker, fmt, fmt_out):
+ faker.format = fmt
+ faker.choose_schema(ncols=20)
+ faker.make_records(50)
+
+ with conn.cursor(binary=fmt_out) as cur:
+ cur.execute(faker.drop_stmt)
+ cur.execute(faker.create_stmt)
+ with faker.find_insert_problem(conn):
+ cur.executemany(faker.insert_stmt, faker.records)
+
+ cur.execute(faker.select_stmt)
+ recs = cur.fetchall()
+
+ for got, want in zip(recs, faker.records):
+ faker.assert_record(got, want)
+
+
+class MyStr(str):
+ pass
+
+
+def make_dumper(suffix):
+ """Create a test dumper appending a suffix to the bytes representation."""
+
+ class TestDumper(Dumper):
+ oid = TEXT_OID
+ format = pq.Format.TEXT
+
+ def dump(self, s):
+ return (s + suffix).encode("ascii")
+
+ return TestDumper
+
+
+def make_bin_dumper(suffix):
+ cls = make_dumper(suffix)
+ cls.format = pq.Format.BINARY
+ return cls
+
+
+def make_loader(suffix):
+ """Create a test loader appending a suffix to the data returned."""
+
+ class TestLoader(Loader):
+ format = pq.Format.TEXT
+
+ def load(self, b):
+ return bytes(b).decode("ascii") + suffix
+
+ return TestLoader
+
+
+def make_bin_loader(suffix):
+ cls = make_loader(suffix)
+ cls.format = pq.Format.BINARY
+ return cls