diff options
Diffstat (limited to 'tests/pq/test_escaping.py')
-rw-r--r-- | tests/pq/test_escaping.py | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py new file mode 100644 index 0000000..ad88d8a --- /dev/null +++ b/tests/pq/test_escaping.py @@ -0,0 +1,188 @@ +import pytest + +import psycopg +from psycopg import pq + +from ..fix_crdb import crdb_scs_off + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b"''"), + (b"hello", b"'hello'"), + (b"foo'bar", b"'foo''bar'"), + (b"foo\\bar", b" E'foo\\\\bar'"), + ], +) +def test_escape_literal(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_literal(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_literal_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b"'": b"''''", b"\\": b" E'\\\\'"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_literal(data) + exp = special.get(data) or b"'%s'" % data + assert rv == exp + + +def test_escape_literal_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_literal(b"hi") + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b'""'), + (b"hello", b'"hello"'), + (b'foo"bar', b'"foo""bar"'), + (b"foo\\bar", b'"foo\\bar"'), + ], +) +def test_escape_identifier(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_identifier(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_identifier_1char(pgconn, scs): + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + esc = pq.Escaping(pgconn) + special = {b'"': b'""""', b"\\": b'"\\"'} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_identifier(data) + exp = special.get(data) or b'"%s"' % data + assert rv == exp + + +def test_escape_identifier_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_identifier(b"hi") + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + (b"foo\\bar", b"foo\\bar"), + ], +) +def test_escape_string(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_string(data) + assert out == want + + +@pytest.mark.parametrize("scs", ["on", crdb_scs_off("off")]) +def test_escape_string_1char(pgconn, scs): + esc = pq.Escaping(pgconn) + res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii")) + assert res.status == pq.ExecStatus.COMMAND_OK + special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_string(data) + exp = special.get(data) or b"%s" % data + assert rv == exp + + +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + # This libpq function behaves unpredictably when not passed a conn + (b"foo\\bar", (b"foo\\\\bar", b"foo\\bar")), + ], +) +def test_escape_string_noconn(data, want): + esc = pq.Escaping() + out = esc.escape_string(data) + if isinstance(want, bytes): + assert out == want + else: + assert out in want + + +def test_escape_string_badconn(pgconn): + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_string(b"hi") + + +def test_escape_string_badenc(pgconn): + res = pgconn.exec_(b"set client_encoding to 'UTF8'") + assert res.status == pq.ExecStatus.COMMAND_OK + data = "\u20ac".encode()[:-1] + esc = pq.Escaping(pgconn) + with pytest.raises(psycopg.OperationalError): + esc.escape_string(data) + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_escape_bytea(pgconn, data): + exp = rb"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.escape_bytea(data) + assert rv == exp + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.escape_bytea(data) + + +def test_escape_noconn(pgconn): + data = bytes(range(256)) + esc = pq.Escaping() + escdata = esc.escape_bytea(data) + res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1) + assert res.status == pq.ExecStatus.TUPLES_OK + assert res.get_value(0, 0) == data + + +def test_escape_1char(pgconn): + esc = pq.Escaping(pgconn) + for c in range(256): + rv = esc.escape_bytea(bytes([c])) + exp = rb"\x%02x" % c + assert rv == exp + + +@pytest.mark.parametrize("data", [b"hello\00world", b"\00\00\00\00"]) +def test_unescape_bytea(pgconn, data): + enc = rb"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.unescape_bytea(enc) + assert rv == data + + pgconn.finish() + with pytest.raises(psycopg.OperationalError): + esc.unescape_bytea(data) |