diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-04 17:41:08 +0000 |
commit | 506ed8899b3a97e512be3fd6d44d5b11463bf9bf (patch) | |
tree | 808913770c5e6935d3714058c2a066c57b4632ec /tests/test_sql.py | |
parent | Initial commit. (diff) | |
download | psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.tar.xz psycopg3-506ed8899b3a97e512be3fd6d44d5b11463bf9bf.zip |
Adding upstream version 3.1.7.upstream/3.1.7upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_sql.py')
-rw-r--r-- | tests/test_sql.py | 604 |
1 files changed, 604 insertions, 0 deletions
diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100644 index 0000000..42b6c63 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,604 @@ +# test_sql.py - tests for the psycopg2.sql module + +# Copyright (C) 2020 The Psycopg Team + +import re +import datetime as dt + +import pytest + +from psycopg import pq, sql, ProgrammingError +from psycopg.adapt import PyFormat +from psycopg._encodings import py2pgenc +from psycopg.types import TypeInfo +from psycopg.types.string import StrDumper + +from .utils import eur +from .fix_crdb import crdb_encoding, crdb_scs_off + + +@pytest.mark.parametrize( + "obj, quoted", + [ + ("foo\\bar", " E'foo\\\\bar'"), + ("hello", "'hello'"), + (42, "42"), + (True, "true"), + (None, "NULL"), + ], +) +def test_quote(obj, quoted): + assert sql.quote(obj) == quoted + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_quote_roundtrip(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + + for i in range(1, 256): + want = chr(i) + quoted = sql.quote(want) + got = conn.execute(f"select {quoted}::text").fetchone()[0] + assert want == got + + # No "nonstandard use of \\ in a string literal" warning + assert not messages, f"error with {want!r}" + + +@pytest.mark.parametrize("dummy", [crdb_scs_off("off")]) +def test_quote_stable_despite_deranged_libpq(conn, dummy): + # Verify the libpq behaviour of PQescapeString using the last setting seen. + # Check that we are not affected by it. + good_str = " E'\\\\'" + good_bytes = " E'\\\\000'::bytea" + conn.execute("set standard_conforming_strings to on") + assert pq.Escaping().escape_string(b"\\") == b"\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\000" + assert sql.quote(b"\x00") == good_bytes + + conn.execute("set standard_conforming_strings to off") + assert pq.Escaping().escape_string(b"\\") == b"\\\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000" + assert sql.quote(b"\x00") == good_bytes + + # Verify that the good values are actually good + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute("set escape_string_warning to on") + for scs in ("on", "off"): + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.execute(f"select {good_str}, {good_bytes}::bytea") + assert cur.fetchone() == ("\\", b"\x00") + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + +class TestSqlFormat: + def test_pos(self, conn): + s = sql.SQL("select {} from {}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_pos_spec(self, conn): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier("table"), sql.Identifier("field") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_dict(self, conn): + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier("field"), t=sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_compose_literal(self, conn): + s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) + s1 = s.as_string(conn) + assert s1 == "select '2016-12-31'::date;" + + def test_compose_empty(self, conn): + s = sql.SQL("select foo;").format() + s1 = s.as_string(conn) + assert s1 == "select foo;" + + def test_percent_escape(self, conn): + s = sql.SQL("42 % {0}").format(sql.Literal(7)) + s1 = s.as_string(conn) + assert s1 == "42 % 7" + + def test_braces_escape(self, conn): + s = sql.SQL("{{{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{7}" + s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{1,7}" + + def test_compose_badnargs(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format() + + def test_compose_badnargs_auto(self): + with pytest.raises(IndexError): + sql.SQL("select {};").format() + with pytest.raises(ValueError): + sql.SQL("select {} {1};").format(10, 20) + with pytest.raises(ValueError): + sql.SQL("select {0} {};").format(10, 20) + + def test_compose_bad_args_type(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format(a=10) + with pytest.raises(KeyError): + sql.SQL("select {x};").format(10) + + def test_no_modifiers(self): + with pytest.raises(ValueError): + sql.SQL("select {a!r};").format(a=10) + with pytest.raises(ValueError): + sql.SQL("select {a:<};").format(a=10) + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + s = sql.SQL("select {0};").format(sql.Literal(Foo())) + with pytest.raises(ProgrammingError): + s.as_string(conn) + + def test_auto_literal(self, conn): + s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1)) + assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'::date" + + def test_execute(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + (10, "a", "b", "c"), + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c")] + + def test_executemany(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])), + (sql.Placeholder() * 3).join(", "), + ), + [(10, "a", "b", "c"), (20, "d", "e", "f")], + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")] + + @pytest.mark.crdb_skip("copy") + def test_copy(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + + with cur.copy( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ), + ) as copy: + copy.write_row((10, "a", "b", "c")) + copy.write_row((20, "d", "e", "f")) + + with cur.copy( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ) + ) as copy: + assert list(copy) == [b"c\n", b"f\n"] + + +class TestIdentifier: + def test_class(self): + assert issubclass(sql.Identifier, sql.Composable) + + def test_init(self): + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier) + with pytest.raises(TypeError): + sql.Identifier() + with pytest.raises(TypeError): + sql.Identifier(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.Identifier(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self): + obj = sql.Identifier("fo'o") + assert repr(obj) == 'Identifier("fo\'o")' + assert repr(obj) == str(obj) + + obj = sql.Identifier("fo'o", 'ba"r') + assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')" + assert repr(obj) == str(obj) + + def test_eq(self): + assert sql.Identifier("foo") == sql.Identifier("foo") + assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar") + assert sql.Identifier("foo") != sql.Identifier("bar") + assert sql.Identifier("foo") != "foo" + assert sql.Identifier("foo") != sql.SQL("foo") + + @pytest.mark.parametrize( + "args, want", + [ + (("foo",), '"foo"'), + (("foo", "bar"), '"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ], + ) + def test_as_string(self, conn, args, want): + assert sql.Identifier(*args).as_string(conn) == want + + @pytest.mark.parametrize( + "args, want, enc", + [ + crdb_encoding(("foo",), '"foo"', "ascii"), + crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), + crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), + ], + ) + def test_as_bytes(self, conn, args, want, enc): + want = want.encode(enc) + conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") + assert sql.Identifier(*args).as_bytes(conn) == want + + def test_join(self): + assert not hasattr(sql.Identifier("foo"), "join") + + +class TestLiteral: + def test_class(self): + assert issubclass(sql.Literal, sql.Composable) + + def test_init(self): + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal(b"foo"), sql.Literal) + assert isinstance(sql.Literal(42), sql.Literal) + assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal) + + def test_repr(self): + assert repr(sql.Literal("foo")) == "Literal('foo')" + assert str(sql.Literal("foo")) == "Literal('foo')" + + def test_as_string(self, conn): + assert sql.Literal(None).as_string(conn) == "NULL" + assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" + assert sql.Literal(42).as_string(conn) == "42" + assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" + + def test_as_bytes(self, conn): + assert sql.Literal(None).as_bytes(conn) == b"NULL" + assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" + assert sql.Literal(42).as_bytes(conn) == b"42" + assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + conn.execute(f"set client_encoding to {encoding}") + assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode(encoding) + + def test_eq(self): + assert sql.Literal("foo") == sql.Literal("foo") + assert sql.Literal("foo") != sql.Literal("bar") + assert sql.Literal("foo") != "foo" + assert sql.Literal("foo") != sql.SQL("foo") + + def test_must_be_adaptable(self, conn): + class Foo: + pass + + with pytest.raises(ProgrammingError): + sql.Literal(Foo()).as_string(conn) + + def test_array(self, conn): + assert ( + sql.Literal([dt.date(2000, 1, 1)]).as_string(conn) + == "'{2000-01-01}'::date[]" + ) + + def test_short_name_builtin(self, conn): + assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time" + assert ( + sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn) + == "'2000-01-01 00:00:00'::timestamp" + ) + assert ( + sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn) + == "'{\"2000-01-01 00:00:00\"}'::timestamp[]" + ) + + def test_text_literal(self, conn): + conn.adapters.register_dumper(str, StrDumper) + assert sql.Literal("foo").as_string(conn) == "'foo'" + + @pytest.mark.crdb_skip("composite") # create type, actually + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"]) + def test_invalid_name(self, conn, name): + conn.execute( + f""" + set client_encoding to utf8; + create type "{name}"; + create function invin(cstring) returns "{name}" + language internal immutable strict as 'textin'; + create function invout("{name}") returns cstring + language internal immutable strict as 'textout'; + create type "{name}" (input=invin, output=invout, like=text); + """ + ) + info = TypeInfo.fetch(conn, f'"{name}"') + + class InvDumper(StrDumper): + oid = info.oid + + def dump(self, obj): + rv = super().dump(obj) + return b"%s-inv" % rv + + info.register(conn) + conn.adapters.register_dumper(str, InvDumper) + + assert sql.Literal("hello").as_string(conn) == f"'hello-inv'::\"{name}\"" + cur = conn.execute(sql.SQL("select {}").format("hello")) + assert cur.fetchone()[0] == "hello-inv" + + assert ( + sql.Literal(["hello"]).as_string(conn) == f"'{{hello-inv}}'::\"{name}\"[]" + ) + cur = conn.execute(sql.SQL("select {}").format(["hello"])) + assert cur.fetchone()[0] == ["hello-inv"] + + +class TestSQL: + def test_class(self): + assert issubclass(sql.SQL, sql.Composable) + + def test_init(self): + assert isinstance(sql.SQL("foo"), sql.SQL) + assert isinstance(sql.SQL("foo"), sql.SQL) + with pytest.raises(TypeError): + sql.SQL(10) # type: ignore[arg-type] + with pytest.raises(TypeError): + sql.SQL(dt.date(2016, 12, 31)) # type: ignore[arg-type] + + def test_repr(self, conn): + assert repr(sql.SQL("foo")) == "SQL('foo')" + assert str(sql.SQL("foo")) == "SQL('foo')" + assert sql.SQL("foo").as_string(conn) == "foo" + + def test_eq(self): + assert sql.SQL("foo") == sql.SQL("foo") + assert sql.SQL("foo") != sql.SQL("bar") + assert sql.SQL("foo") != "foo" + assert sql.SQL("foo") != sql.Literal("foo") + + def test_sum(self, conn): + obj = sql.SQL("foo") + sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_sum_inplace(self, conn): + obj = sql.SQL("f") + sql.SQL("oo") + obj += sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_multiply(self, conn): + obj = sql.SQL("foo") * 3 + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foofoofoo" + + def test_join(self, conn): + obj = sql.SQL(", ").join( + [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join( + sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]) + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == '"foo", bar, 42' + + obj = sql.SQL(", ").join([]) + assert obj == sql.Composed([]) + + def test_as_string(self, conn): + assert sql.SQL("foo").as_string(conn) == "foo" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes(self, conn, encoding): + if encoding: + conn.execute(f"set client_encoding to {encoding}") + + assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding) + + +class TestComposed: + def test_class(self): + assert issubclass(sql.Composed, sql.Composable) + + def test_repr(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])""" + assert str(obj) == repr(obj) + + def test_eq(self): + L = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + assert sql.Composed(L) == sql.Composed(list(L)) + assert sql.Composed(L) != L + assert sql.Composed(L) != sql.Composed(l2) + + def test_join(self, conn): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'foo', \"b'ar\"" + + def test_auto_literal(self, conn): + obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "'fo''o', '2020-01-01'::date" + + def test_sum(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_sum_inplace(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + assert isinstance(obj, sql.Composed) + assert no_e(obj.as_string(conn)) == "foo 'bar'" + + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + it = iter(obj) + i = next(it) + assert i == sql.SQL("foo") + i = next(it) + assert i == sql.SQL("bar") + with pytest.raises(StopIteration): + next(it) + + def test_as_string(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_string(conn) == "foobar" + + def test_as_bytes(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_bytes(conn) == b"foobar" + + @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) + def test_as_bytes_encoding(self, conn, encoding): + obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)]) + conn.execute(f"set client_encoding to {encoding}") + assert obj.as_bytes(conn) == ("foo" + eur).encode(encoding) + + +class TestPlaceholder: + def test_class(self): + assert issubclass(sql.Placeholder, sql.Composable) + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_format(self, conn, format): + ph = sql.Placeholder(format=format) + add = f"format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder({add})" + + @pytest.mark.parametrize("format", PyFormat) + def test_repr_name_format(self, conn, format): + ph = sql.Placeholder("foo", format=format) + add = f", format={format.name}" if format != PyFormat.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder('foo'{add})" + + def test_bad_name(self): + with pytest.raises(ValueError): + sql.Placeholder(")") + + def test_eq(self): + assert sql.Placeholder("foo") == sql.Placeholder("foo") + assert sql.Placeholder("foo") != sql.Placeholder("bar") + assert sql.Placeholder("foo") != "foo" + assert sql.Placeholder() == sql.Placeholder() + assert sql.Placeholder("foo") != sql.Placeholder() + assert sql.Placeholder("foo") != sql.Literal("foo") + + @pytest.mark.parametrize("format", PyFormat) + def test_as_string(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_string(conn) == f"%{format.value}" + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_string(conn) == f"%(foo){format.value}" + + @pytest.mark.parametrize("format", PyFormat) + def test_as_bytes(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii") + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii") + + +class TestValues: + def test_null(self, conn): + assert isinstance(sql.NULL, sql.SQL) + assert sql.NULL.as_string(conn) == "NULL" + + def test_default(self, conn): + assert isinstance(sql.DEFAULT, sql.SQL) + assert sql.DEFAULT.as_string(conn) == "DEFAULT" + + +def no_e(s): + """Drop an eventual E from E'' quotes""" + if isinstance(s, memoryview): + s = bytes(s) + + if isinstance(s, str): + return re.sub(r"\bE'", "'", s) + elif isinstance(s, bytes): + return re.sub(rb"\bE'", b"'", s) + else: + raise TypeError(f"not dealing with {type(s).__name__}: {s}") |