diff options
Diffstat (limited to '')
-rw-r--r-- | tests/test_conninfo.py | 450 |
1 files changed, 450 insertions, 0 deletions
diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py new file mode 100644 index 0000000..e2c2c01 --- /dev/null +++ b/tests/test_conninfo.py @@ -0,0 +1,450 @@ +import socket +import asyncio +import datetime as dt + +import pytest + +import psycopg +from psycopg import ProgrammingError +from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from psycopg.conninfo import resolve_hostaddr_async +from psycopg._encodings import pg2pyenc + +from .fix_crdb import crdb_encoding + +snowman = "\u2603" + + +class MyString(str): + pass + + +@pytest.mark.parametrize( + "conninfo, kwargs, exp", + [ + ("", {}, ""), + ("dbname=foo", {}, "dbname=foo"), + ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"), + ("dbname=sony", {"password": ""}, "dbname=sony password="), + ("dbname=foo", {"dbname": "bar"}, "dbname=bar"), + ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"), + ("", {"dbname": "foo"}, "dbname=foo"), + ("", {"dbname": "foo", "user": None}, "dbname=foo"), + ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"), + ("", {"dbname": "a'b"}, r"dbname='a\'b'"), + (f"dbname={snowman}", {}, f"dbname={snowman}"), + ("", {"dbname": snowman}, f"dbname={snowman}"), + ( + "postgresql://host1/test", + {"host": "host2"}, + "dbname=test host=host2", + ), + (MyString(""), {}, ""), + ], +) +def test_make_conninfo(conninfo, kwargs, exp): + out = make_conninfo(conninfo, **kwargs) + assert conninfo_to_dict(out) == conninfo_to_dict(exp) + + +@pytest.mark.parametrize( + "conninfo, kwargs", + [ + ("hello", {}), + ("dbname=foo bar", {}), + ("foo=bar", {}), + ("dbname=foo", {"bar": "baz"}), + ("postgresql://tester:secret@/test?port=5433=x", {}), + (f"{snowman}={snowman}", {}), + ], +) +def test_make_conninfo_bad(conninfo, kwargs): + with pytest.raises(ProgrammingError): + make_conninfo(conninfo, **kwargs) + + +@pytest.mark.parametrize( + "conninfo, exp", + [ + ("", {}), + ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}), + ("dbname=sony password=", {"dbname": "sony", "password": ""}), + ("dbname='foo bar'", {"dbname": "foo bar"}), + ("dbname='a\"b'", {"dbname": 'a"b'}), + (r"dbname='a\'b'", {"dbname": "a'b"}), + (r"dbname='a\\b'", {"dbname": r"a\b"}), + (f"dbname={snowman}", {"dbname": snowman}), + ( + "postgresql://tester:secret@/test?port=5433", + { + "user": "tester", + "password": "secret", + "dbname": "test", + "port": "5433", + }, + ), + ], +) +def test_conninfo_to_dict(conninfo, exp): + assert conninfo_to_dict(conninfo) == exp + + +def test_no_munging(): + dsnin = "dbname=a host=b user=c password=d" + dsnout = make_conninfo(dsnin) + assert dsnin == dsnout + + +class TestConnectionInfo: + @pytest.mark.parametrize( + "attr", + [("dbname", "db"), "host", "hostaddr", "user", "password", "options"], + ) + def test_attrs(self, conn, attr): + if isinstance(attr, tuple): + info_attr, pgconn_attr = attr + else: + info_attr = pgconn_attr = attr + + if info_attr == "hostaddr" and psycopg.pq.version() < 120000: + pytest.skip("hostaddr not supported on libpq < 12") + + info_val = getattr(conn.info, info_attr) + pgconn_val = getattr(conn.pgconn, pgconn_attr).decode() + assert info_val == pgconn_val + + conn.close() + with pytest.raises(psycopg.OperationalError): + getattr(conn.info, info_attr) + + @pytest.mark.libpq("< 12") + def test_hostaddr_not_supported(self, conn): + with pytest.raises(psycopg.NotSupportedError): + conn.info.hostaddr + + def test_port(self, conn): + assert conn.info.port == int(conn.pgconn.port.decode()) + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.port + + def test_get_params(self, conn, dsn): + info = conn.info.get_parameters() + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert info.get(k) == v + else: + assert k not in info + + def test_dsn(self, conn, dsn): + dsn = conn.info.dsn + assert "password" not in dsn + for k, v in conninfo_to_dict(dsn).items(): + if k != "password": + assert f"{k}=" in dsn + + def test_get_params_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name" not in conn.info.get_parameters() + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert conn.info.get_parameters()["application_name"] == "hello test" + + def test_dsn_env(self, conn_cls, dsn, monkeypatch): + dsn = conninfo_to_dict(dsn) + dsn.pop("application_name", None) + + monkeypatch.delenv("PGAPPNAME", raising=False) + with conn_cls.connect(**dsn) as conn: + assert "application_name=" not in conn.info.dsn + + monkeypatch.setenv("PGAPPNAME", "hello test") + with conn_cls.connect(**dsn) as conn: + assert "application_name='hello test'" in conn.info.dsn + + def test_status(self, conn): + assert conn.info.status.name == "OK" + conn.close() + assert conn.info.status.name == "BAD" + + def test_transaction_status(self, conn): + assert conn.info.transaction_status.name == "IDLE" + conn.close() + assert conn.info.transaction_status.name == "UNKNOWN" + + @pytest.mark.pipeline + def test_pipeline_status(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + with conn.pipeline(): + assert conn.info.pipeline_status + assert conn.info.pipeline_status.name == "ON" + + @pytest.mark.libpq("< 14") + def test_pipeline_status_no_pipeline(self, conn): + assert not conn.info.pipeline_status + assert conn.info.pipeline_status.name == "OFF" + + def test_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in info.get_parameters() + assert info.get_parameters()["dbname"] == info.dbname + + def test_dsn_no_password(self, dsn): + dsn2 = make_conninfo(dsn, password="the-pass-word") + pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode()) + info = ConnectionInfo(pgconn) + assert info.password == "the-pass-word" + assert "password" not in info.dsn + assert f"dbname={info.dbname}" in info.dsn + + def test_parameter_status(self, conn): + assert conn.info.parameter_status("nosuchparam") is None + tz = conn.info.parameter_status("TimeZone") + assert tz and isinstance(tz, str) + assert tz == conn.execute("show timezone").fetchone()[0] + + @pytest.mark.crdb("skip") + def test_server_version(self, conn): + assert conn.info.server_version == conn.pgconn.server_version + + def test_error_message(self, conn): + assert conn.info.error_message == "" + with pytest.raises(psycopg.ProgrammingError) as ex: + conn.execute("wat") + + assert conn.info.error_message + assert str(ex.value) in conn.info.error_message + assert ex.value.diag.severity in conn.info.error_message + + conn.close() + assert "NULL" in conn.info.error_message + + @pytest.mark.crdb_skip("backend pid") + def test_backend_pid(self, conn): + assert conn.info.backend_pid + assert conn.info.backend_pid == conn.pgconn.backend_pid + conn.close() + with pytest.raises(psycopg.OperationalError): + conn.info.backend_pid + + def test_timezone(self, conn): + conn.execute("set timezone to 'Europe/Rome'") + tz = conn.info.timezone + assert isinstance(tz, dt.tzinfo) + offset = tz.utcoffset(dt.datetime(2000, 1, 1)) + assert offset and offset.total_seconds() == 3600 + offset = tz.utcoffset(dt.datetime(2000, 7, 1)) + assert offset and offset.total_seconds() == 7200 + + @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones") + def test_timezone_warn(self, conn, caplog): + conn.execute("set timezone to 'FOOBAR0'") + assert len(caplog.records) == 0 + tz = conn.info.timezone + assert tz == dt.timezone.utc + assert len(caplog.records) == 1 + assert "FOOBAR0" in caplog.records[0].message + + conn.info.timezone + assert len(caplog.records) == 1 + + conn.execute("set timezone to 'FOOBAAR0'") + assert len(caplog.records) == 1 + conn.info.timezone + assert len(caplog.records) == 2 + assert "FOOBAAR0" in caplog.records[1].message + + def test_encoding(self, conn): + enc = conn.execute("show client_encoding").fetchone()[0] + assert conn.info.encoding == pg2pyenc(enc.encode()) + + @pytest.mark.crdb("skip", reason="encoding not normalized") + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + ("eucjp", "EUC_JP", "euc_jp"), + ("euc-jp", "EUC_JP", "euc_jp"), + ("latin9", "LATIN9", "iso8859-15"), + ], + ) + def test_normalize_encoding(self, conn, enc, out, codec): + conn.execute("select set_config('client_encoding', %s, false)", [enc]) + assert conn.info.parameter_status("client_encoding") == out + assert conn.info.encoding == codec + + @pytest.mark.parametrize( + "enc, out, codec", + [ + ("utf8", "UTF8", "utf-8"), + ("utf-8", "UTF8", "utf-8"), + ("utf_8", "UTF8", "utf-8"), + crdb_encoding("eucjp", "EUC_JP", "euc_jp"), + crdb_encoding("euc-jp", "EUC_JP", "euc_jp"), + ], + ) + def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec): + monkeypatch.setenv("PGCLIENTENCODING", enc) + with conn_cls.connect(dsn) as conn: + clienc = conn.info.parameter_status("client_encoding") + assert clienc + if conn.info.vendor == "PostgreSQL": + assert clienc == out + else: + assert clienc.replace("-", "").replace("_", "").upper() == out + assert conn.info.encoding == codec + + @pytest.mark.crdb_skip("encoding") + def test_set_encoding_unsupported(self, conn): + cur = conn.cursor() + cur.execute("set client_encoding to EUC_TW") + with pytest.raises(psycopg.NotSupportedError): + cur.execute("select 'x'") + + def test_vendor(self, conn): + assert conn.info.vendor + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ("", "", None), + ("host='' user=bar", "host='' user=bar", None), + ( + "host=127.0.0.1 user=bar", + "host=127.0.0.1 user=bar hostaddr=127.0.0.1", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 user=bar", + "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=1.1.1.1,2.2.2.2 port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "port=5432", + "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2", + {"PGHOST": "1.1.1.1,2.2.2.2"}, + ), + ( + "host=foo.com port=5432", + "host=foo.com port=5432", + {"PGHOSTADDR": "1.2.3.4"}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_no_resolve( + setpgenv, conninfo, want, env, fail_resolve +): + setpgenv(env) + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, want, env", + [ + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + None, + ), + ( + "host=foo.com,qux.com port=5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433", + None, + ), + ( + "host=foo.com,qux.com port=5432,5433", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433", + None, + ), + ( + "host=foo.com,nosuchhost.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com, port=5432,5433", + "host=foo.com, hostaddr=1.1.1.1, port=5432,5433", + None, + ), + ( + "host=nosuchhost.com,foo.com", + "host=foo.com hostaddr=1.1.1.1", + None, + ), + ( + "host=foo.com,qux.com", + "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2", + {}, + ), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve): + params = conninfo_to_dict(conninfo) + params = await resolve_hostaddr_async(params) + assert conninfo_to_dict(want) == params + + +@pytest.mark.parametrize( + "conninfo, env", + [ + ("host=bad1.com,bad2.com", None), + ("host=foo.com port=1,2", None), + ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None), + ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}), + ], +) +@pytest.mark.asyncio +async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve): + setpgenv(env) + params = conninfo_to_dict(conninfo) + with pytest.raises(psycopg.Error): + await resolve_hostaddr_async(params) + + +@pytest.fixture +async def fake_resolve(monkeypatch): + fake_hosts = { + "localhost": "127.0.0.1", + "foo.com": "1.1.1.1", + "qux.com": "2.2.2.2", + } + + async def fake_getaddrinfo(host, port, **kwargs): + assert isinstance(port, int) or (isinstance(port, str) and port.isdigit()) + try: + addr = fake_hosts[host] + except KeyError: + raise OSError(f"unknown test host: {host}") + else: + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))] + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo) + + +@pytest.fixture +async def fail_resolve(monkeypatch): + async def fail_getaddrinfo(host, port, **kwargs): + pytest.fail(f"shouldn't try to resolve {host}") + + monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo) |