summaryrefslogtreecommitdiffstats
path: root/tests/test_conninfo.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_conninfo.py')
-rw-r--r--tests/test_conninfo.py450
1 files changed, 450 insertions, 0 deletions
diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py
new file mode 100644
index 0000000..e2c2c01
--- /dev/null
+++ b/tests/test_conninfo.py
@@ -0,0 +1,450 @@
+import socket
+import asyncio
+import datetime as dt
+
+import pytest
+
+import psycopg
+from psycopg import ProgrammingError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg.conninfo import resolve_hostaddr_async
+from psycopg._encodings import pg2pyenc
+
+from .fix_crdb import crdb_encoding
+
+snowman = "\u2603"
+
+
+class MyString(str):
+ pass
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs, exp",
+ [
+ ("", {}, ""),
+ ("dbname=foo", {}, "dbname=foo"),
+ ("dbname=foo", {"user": "bar"}, "dbname=foo user=bar"),
+ ("dbname=sony", {"password": ""}, "dbname=sony password="),
+ ("dbname=foo", {"dbname": "bar"}, "dbname=bar"),
+ ("user=bar", {"dbname": "foo bar"}, "dbname='foo bar' user=bar"),
+ ("", {"dbname": "foo"}, "dbname=foo"),
+ ("", {"dbname": "foo", "user": None}, "dbname=foo"),
+ ("", {"dbname": "foo", "port": 15432}, "dbname=foo port=15432"),
+ ("", {"dbname": "a'b"}, r"dbname='a\'b'"),
+ (f"dbname={snowman}", {}, f"dbname={snowman}"),
+ ("", {"dbname": snowman}, f"dbname={snowman}"),
+ (
+ "postgresql://host1/test",
+ {"host": "host2"},
+ "dbname=test host=host2",
+ ),
+ (MyString(""), {}, ""),
+ ],
+)
+def test_make_conninfo(conninfo, kwargs, exp):
+ out = make_conninfo(conninfo, **kwargs)
+ assert conninfo_to_dict(out) == conninfo_to_dict(exp)
+
+
+@pytest.mark.parametrize(
+ "conninfo, kwargs",
+ [
+ ("hello", {}),
+ ("dbname=foo bar", {}),
+ ("foo=bar", {}),
+ ("dbname=foo", {"bar": "baz"}),
+ ("postgresql://tester:secret@/test?port=5433=x", {}),
+ (f"{snowman}={snowman}", {}),
+ ],
+)
+def test_make_conninfo_bad(conninfo, kwargs):
+ with pytest.raises(ProgrammingError):
+ make_conninfo(conninfo, **kwargs)
+
+
+@pytest.mark.parametrize(
+ "conninfo, exp",
+ [
+ ("", {}),
+ ("dbname=foo user=bar", {"dbname": "foo", "user": "bar"}),
+ ("dbname=sony password=", {"dbname": "sony", "password": ""}),
+ ("dbname='foo bar'", {"dbname": "foo bar"}),
+ ("dbname='a\"b'", {"dbname": 'a"b'}),
+ (r"dbname='a\'b'", {"dbname": "a'b"}),
+ (r"dbname='a\\b'", {"dbname": r"a\b"}),
+ (f"dbname={snowman}", {"dbname": snowman}),
+ (
+ "postgresql://tester:secret@/test?port=5433",
+ {
+ "user": "tester",
+ "password": "secret",
+ "dbname": "test",
+ "port": "5433",
+ },
+ ),
+ ],
+)
+def test_conninfo_to_dict(conninfo, exp):
+ assert conninfo_to_dict(conninfo) == exp
+
+
+def test_no_munging():
+ dsnin = "dbname=a host=b user=c password=d"
+ dsnout = make_conninfo(dsnin)
+ assert dsnin == dsnout
+
+
+class TestConnectionInfo:
+ @pytest.mark.parametrize(
+ "attr",
+ [("dbname", "db"), "host", "hostaddr", "user", "password", "options"],
+ )
+ def test_attrs(self, conn, attr):
+ if isinstance(attr, tuple):
+ info_attr, pgconn_attr = attr
+ else:
+ info_attr = pgconn_attr = attr
+
+ if info_attr == "hostaddr" and psycopg.pq.version() < 120000:
+ pytest.skip("hostaddr not supported on libpq < 12")
+
+ info_val = getattr(conn.info, info_attr)
+ pgconn_val = getattr(conn.pgconn, pgconn_attr).decode()
+ assert info_val == pgconn_val
+
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ getattr(conn.info, info_attr)
+
+ @pytest.mark.libpq("< 12")
+ def test_hostaddr_not_supported(self, conn):
+ with pytest.raises(psycopg.NotSupportedError):
+ conn.info.hostaddr
+
+ def test_port(self, conn):
+ assert conn.info.port == int(conn.pgconn.port.decode())
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.port
+
+ def test_get_params(self, conn, dsn):
+ info = conn.info.get_parameters()
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert info.get(k) == v
+ else:
+ assert k not in info
+
+ def test_dsn(self, conn, dsn):
+ dsn = conn.info.dsn
+ assert "password" not in dsn
+ for k, v in conninfo_to_dict(dsn).items():
+ if k != "password":
+ assert f"{k}=" in dsn
+
+ def test_get_params_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name" not in conn.info.get_parameters()
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert conn.info.get_parameters()["application_name"] == "hello test"
+
+ def test_dsn_env(self, conn_cls, dsn, monkeypatch):
+ dsn = conninfo_to_dict(dsn)
+ dsn.pop("application_name", None)
+
+ monkeypatch.delenv("PGAPPNAME", raising=False)
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name=" not in conn.info.dsn
+
+ monkeypatch.setenv("PGAPPNAME", "hello test")
+ with conn_cls.connect(**dsn) as conn:
+ assert "application_name='hello test'" in conn.info.dsn
+
+ def test_status(self, conn):
+ assert conn.info.status.name == "OK"
+ conn.close()
+ assert conn.info.status.name == "BAD"
+
+ def test_transaction_status(self, conn):
+ assert conn.info.transaction_status.name == "IDLE"
+ conn.close()
+ assert conn.info.transaction_status.name == "UNKNOWN"
+
+ @pytest.mark.pipeline
+ def test_pipeline_status(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+ with conn.pipeline():
+ assert conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "ON"
+
+ @pytest.mark.libpq("< 14")
+ def test_pipeline_status_no_pipeline(self, conn):
+ assert not conn.info.pipeline_status
+ assert conn.info.pipeline_status.name == "OFF"
+
+ def test_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.get_parameters()
+ assert info.get_parameters()["dbname"] == info.dbname
+
+ def test_dsn_no_password(self, dsn):
+ dsn2 = make_conninfo(dsn, password="the-pass-word")
+ pgconn = psycopg.pq.PGconn.connect_start(dsn2.encode())
+ info = ConnectionInfo(pgconn)
+ assert info.password == "the-pass-word"
+ assert "password" not in info.dsn
+ assert f"dbname={info.dbname}" in info.dsn
+
+ def test_parameter_status(self, conn):
+ assert conn.info.parameter_status("nosuchparam") is None
+ tz = conn.info.parameter_status("TimeZone")
+ assert tz and isinstance(tz, str)
+ assert tz == conn.execute("show timezone").fetchone()[0]
+
+ @pytest.mark.crdb("skip")
+ def test_server_version(self, conn):
+ assert conn.info.server_version == conn.pgconn.server_version
+
+ def test_error_message(self, conn):
+ assert conn.info.error_message == ""
+ with pytest.raises(psycopg.ProgrammingError) as ex:
+ conn.execute("wat")
+
+ assert conn.info.error_message
+ assert str(ex.value) in conn.info.error_message
+ assert ex.value.diag.severity in conn.info.error_message
+
+ conn.close()
+ assert "NULL" in conn.info.error_message
+
+ @pytest.mark.crdb_skip("backend pid")
+ def test_backend_pid(self, conn):
+ assert conn.info.backend_pid
+ assert conn.info.backend_pid == conn.pgconn.backend_pid
+ conn.close()
+ with pytest.raises(psycopg.OperationalError):
+ conn.info.backend_pid
+
+ def test_timezone(self, conn):
+ conn.execute("set timezone to 'Europe/Rome'")
+ tz = conn.info.timezone
+ assert isinstance(tz, dt.tzinfo)
+ offset = tz.utcoffset(dt.datetime(2000, 1, 1))
+ assert offset and offset.total_seconds() == 3600
+ offset = tz.utcoffset(dt.datetime(2000, 7, 1))
+ assert offset and offset.total_seconds() == 7200
+
+ @pytest.mark.crdb("skip", reason="crdb doesn't allow invalid timezones")
+ def test_timezone_warn(self, conn, caplog):
+ conn.execute("set timezone to 'FOOBAR0'")
+ assert len(caplog.records) == 0
+ tz = conn.info.timezone
+ assert tz == dt.timezone.utc
+ assert len(caplog.records) == 1
+ assert "FOOBAR0" in caplog.records[0].message
+
+ conn.info.timezone
+ assert len(caplog.records) == 1
+
+ conn.execute("set timezone to 'FOOBAAR0'")
+ assert len(caplog.records) == 1
+ conn.info.timezone
+ assert len(caplog.records) == 2
+ assert "FOOBAAR0" in caplog.records[1].message
+
+ def test_encoding(self, conn):
+ enc = conn.execute("show client_encoding").fetchone()[0]
+ assert conn.info.encoding == pg2pyenc(enc.encode())
+
+ @pytest.mark.crdb("skip", reason="encoding not normalized")
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ ("eucjp", "EUC_JP", "euc_jp"),
+ ("euc-jp", "EUC_JP", "euc_jp"),
+ ("latin9", "LATIN9", "iso8859-15"),
+ ],
+ )
+ def test_normalize_encoding(self, conn, enc, out, codec):
+ conn.execute("select set_config('client_encoding', %s, false)", [enc])
+ assert conn.info.parameter_status("client_encoding") == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.parametrize(
+ "enc, out, codec",
+ [
+ ("utf8", "UTF8", "utf-8"),
+ ("utf-8", "UTF8", "utf-8"),
+ ("utf_8", "UTF8", "utf-8"),
+ crdb_encoding("eucjp", "EUC_JP", "euc_jp"),
+ crdb_encoding("euc-jp", "EUC_JP", "euc_jp"),
+ ],
+ )
+ def test_encoding_env_var(self, conn_cls, dsn, monkeypatch, enc, out, codec):
+ monkeypatch.setenv("PGCLIENTENCODING", enc)
+ with conn_cls.connect(dsn) as conn:
+ clienc = conn.info.parameter_status("client_encoding")
+ assert clienc
+ if conn.info.vendor == "PostgreSQL":
+ assert clienc == out
+ else:
+ assert clienc.replace("-", "").replace("_", "").upper() == out
+ assert conn.info.encoding == codec
+
+ @pytest.mark.crdb_skip("encoding")
+ def test_set_encoding_unsupported(self, conn):
+ cur = conn.cursor()
+ cur.execute("set client_encoding to EUC_TW")
+ with pytest.raises(psycopg.NotSupportedError):
+ cur.execute("select 'x'")
+
+ def test_vendor(self, conn):
+ assert conn.info.vendor
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ ("", "", None),
+ ("host='' user=bar", "host='' user=bar", None),
+ (
+ "host=127.0.0.1 user=bar",
+ "host=127.0.0.1 user=bar hostaddr=127.0.0.1",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 user=bar",
+ "host=1.1.1.1,2.2.2.2 user=bar hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=1.1.1.1,2.2.2.2 port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "port=5432",
+ "host=1.1.1.1,2.2.2.2 port=5432 hostaddr=1.1.1.1,2.2.2.2",
+ {"PGHOST": "1.1.1.1,2.2.2.2"},
+ ),
+ (
+ "host=foo.com port=5432",
+ "host=foo.com port=5432",
+ {"PGHOSTADDR": "1.2.3.4"},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_no_resolve(
+ setpgenv, conninfo, want, env, fail_resolve
+):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, want, env",
+ [
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5433",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com port=5432,5433",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2 port=5432,5433",
+ None,
+ ),
+ (
+ "host=foo.com,nosuchhost.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com, port=5432,5433",
+ "host=foo.com, hostaddr=1.1.1.1, port=5432,5433",
+ None,
+ ),
+ (
+ "host=nosuchhost.com,foo.com",
+ "host=foo.com hostaddr=1.1.1.1",
+ None,
+ ),
+ (
+ "host=foo.com,qux.com",
+ "host=foo.com,qux.com hostaddr=1.1.1.1,2.2.2.2",
+ {},
+ ),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
+ params = conninfo_to_dict(conninfo)
+ params = await resolve_hostaddr_async(params)
+ assert conninfo_to_dict(want) == params
+
+
+@pytest.mark.parametrize(
+ "conninfo, env",
+ [
+ ("host=bad1.com,bad2.com", None),
+ ("host=foo.com port=1,2", None),
+ ("host=1.1.1.1,2.2.2.2 port=5432,5433,5434", None),
+ ("host=1.1.1.1,2.2.2.2", {"PGPORT": "1,2,3"}),
+ ],
+)
+@pytest.mark.asyncio
+async def test_resolve_hostaddr_async_bad(setpgenv, conninfo, env, fake_resolve):
+ setpgenv(env)
+ params = conninfo_to_dict(conninfo)
+ with pytest.raises(psycopg.Error):
+ await resolve_hostaddr_async(params)
+
+
+@pytest.fixture
+async def fake_resolve(monkeypatch):
+ fake_hosts = {
+ "localhost": "127.0.0.1",
+ "foo.com": "1.1.1.1",
+ "qux.com": "2.2.2.2",
+ }
+
+ async def fake_getaddrinfo(host, port, **kwargs):
+ assert isinstance(port, int) or (isinstance(port, str) and port.isdigit())
+ try:
+ addr = fake_hosts[host]
+ except KeyError:
+ raise OSError(f"unknown test host: {host}")
+ else:
+ return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 432))]
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fake_getaddrinfo)
+
+
+@pytest.fixture
+async def fail_resolve(monkeypatch):
+ async def fail_getaddrinfo(host, port, **kwargs):
+ pytest.fail(f"shouldn't try to resolve {host}")
+
+ monkeypatch.setattr(asyncio.get_running_loop(), "getaddrinfo", fail_getaddrinfo)