summaryrefslogtreecommitdiffstats
path: root/tests/test_dns_srv.py
blob: 15b370685b819377ca930ff34be60a40736eb027 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from typing import List, Union

import pytest

import psycopg
from psycopg.conninfo import conninfo_to_dict

from .test_dns import import_dnspython

# Apply the "dns" marker to every test in this module.
pytestmark = [pytest.mark.dns]

# Cases for the successful SRV resolution tests. Each entry is
# (conninfo, want, env): the connection string given as input, the connection
# string expected after resolution, and the environment mapping handed to
# setpgenv (None when no variable is needed). The DNS answers are the canned
# records defined in get_fake_srv_function().
samples_ok = [
    # Empty input: expected output is empty too.
    ("", "", None),
    # Single SRV record: host and port come from the record.
    ("host=_pg._tcp.foo.com", "host=db1.example.com port=5432", None),
    # Same lookup, but the SRV name is supplied through PGHOST.
    ("", "host=db1.example.com port=5432", {"PGHOST": "_pg._tcp.foo.com"}),
    # A plain host next to an SRV name: the plain host is kept with an empty
    # port, the SRV name is replaced by its target and port.
    (
        "host=foo.com,_pg._tcp.foo.com",
        "host=foo.com,db1.example.com port=,5432",
        None,
    ),
    # The dot.com name (whose only canned record has "." as target) does not
    # appear in the expected output; the other hosts are still returned.
    (
        "host=_pg._tcp.dot.com,foo.com,_pg._tcp.foo.com",
        "host=foo.com,db1.example.com port=,5432",
        None,
    ),
    # Several records for one name: the expected order lists the priority 0
    # entry first, then the priority 1 entries by decreasing weight.
    (
        "host=_pg._tcp.bar.com",
        "host=db1.example.com,db4.example.com,db3.example.com,db2.example.com"
        " port=5432,5432,5433,5432",
        None,
    ),
    # port=srv: the host has no "_pg._tcp." prefix, yet the expected output
    # carries the target and port of its canned SRV record.
    (
        "host=service.foo.com port=srv",
        "host=service.example.com port=15432",
        None,
    ),
    # No resolution
    (
        "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
        "host=_pg._tcp.foo.com hostaddr=1.1.1.1",
        None,
    ),
]


@pytest.mark.flakey("random weight order, might cause wrong order")
@pytest.mark.parametrize("conninfo, want, env", samples_ok)
def test_srv(conninfo, want, env, fake_srv, setpgenv):
    """Resolve *conninfo* through the faked resolver and compare with *want*.

    *env* is applied with ``setpgenv`` before the connection string is parsed.
    """
    setpgenv(env)
    resolve = psycopg._dns.resolve_srv  # type: ignore[attr-defined]
    resolved = resolve(conninfo_to_dict(conninfo))
    expected = conninfo_to_dict(want)
    assert expected == resolved


@pytest.mark.asyncio
@pytest.mark.flakey("random weight order, might cause wrong order")
@pytest.mark.parametrize("conninfo, want, env", samples_ok)
async def test_srv_async(conninfo, want, env, afake_srv, setpgenv):
    """Async counterpart of ``test_srv``: resolve *conninfo* and compare with *want*.

    It runs over the same ``samples_ok`` cases as the sync test, including the
    multi-record one whose result order depends on weights, so it carries the
    same ``flakey`` marker.
    *env* is applied with ``setpgenv`` before the connection string is parsed.
    """
    setpgenv(env)
    params = conninfo_to_dict(conninfo)
    params = await (
        psycopg._dns.resolve_srv_async(params)  # type: ignore[attr-defined]
    )
    assert conninfo_to_dict(want) == params


# Cases expected to fail. Each entry is (conninfo, env); the tests using this
# table assert that resolving them raises psycopg.OperationalError.
samples_bad = [
    # The only canned SRV record for this name has "." as its target.
    ("host=_pg._tcp.dot.com", None),
    # A single host paired with two ports.
    ("host=_pg._tcp.foo.com port=1,2", None),
]


@pytest.mark.parametrize("conninfo,  env", samples_bad)
def test_srv_bad(conninfo, env, fake_srv, setpgenv):
    """Check that resolving an invalid *conninfo* raises ``OperationalError``."""
    setpgenv(env)
    bad_params = conninfo_to_dict(conninfo)
    resolve = psycopg._dns.resolve_srv  # type: ignore[attr-defined]
    with pytest.raises(psycopg.OperationalError):
        resolve(bad_params)


@pytest.mark.asyncio
@pytest.mark.parametrize("conninfo,  env", samples_bad)
async def test_srv_bad_async(conninfo, env, afake_srv, setpgenv):
    """Async check that an invalid *conninfo* raises ``OperationalError``."""
    setpgenv(env)
    bad_params = conninfo_to_dict(conninfo)
    resolve_async = psycopg._dns.resolve_srv_async  # type: ignore[attr-defined]
    with pytest.raises(psycopg.OperationalError):
        await resolve_async(bad_params)


@pytest.fixture
def fake_srv(monkeypatch):
    """Patch ``psycopg._dns.resolver.resolve`` with the canned SRV lookup."""
    # Build the fake before touching psycopg._dns: get_fake_srv_function()
    # starts by calling import_dnspython().
    canned_resolve = get_fake_srv_function(monkeypatch)
    target = psycopg._dns.resolver  # type: ignore[attr-defined]
    monkeypatch.setattr(target, "resolve", canned_resolve)


@pytest.fixture
def afake_srv(monkeypatch):
    """Patch ``psycopg._dns.async_resolver.resolve`` with an async canned lookup."""
    # Build the fake before touching psycopg._dns: get_fake_srv_function()
    # starts by calling import_dnspython().
    canned_resolve = get_fake_srv_function(monkeypatch)

    async def canned_resolve_async(qname, rdtype):
        # Coroutine wrapper around the sync fake; it returns the same answers.
        return canned_resolve(qname, rdtype)

    target = psycopg._dns.async_resolver  # type: ignore[attr-defined]
    monkeypatch.setattr(target, "resolve", canned_resolve_async)


def get_fake_srv_function(monkeypatch):
    """Return a ``resolve(qname, rdtype)`` stand-in answering from canned records.

    The returned function looks up ``(qname, rdtype)`` in a fixed table and
    returns a list of dnspython ``A`` or ``SRV`` rdata objects built from the
    matching entries. Names missing from the table raise ``DNSException``.

    *monkeypatch* is accepted but not used by this function.
    """
    import_dnspython()

    from dns.rdtypes.IN.A import A
    from dns.rdtypes.IN.SRV import SRV
    from dns.exception import DNSException

    # Canned answers keyed by (query name, record type); each SRV entry is
    # written as "priority weight port target".
    fake_hosts = {
        ("_pg._tcp.dot.com", "SRV"): ["0 0 5432 ."],
        ("_pg._tcp.foo.com", "SRV"): ["0 0 5432 db1.example.com."],
        ("_pg._tcp.bar.com", "SRV"): [
            "1 0 5432 db2.example.com.",
            "1 255 5433 db3.example.com.",
            "0 0 5432 db1.example.com.",
            "1 65535 5432 db4.example.com.",
        ],
        ("service.foo.com", "SRV"): ["0 0 15432 service.example.com."],
    }

    def srv_from_text(record):
        # Split "priority weight port target" and build the SRV rdata.
        priority, weight, port, target = record.split()
        return SRV("IN", "SRV", int(priority), int(weight), int(port), target)

    def fake_srv_(qname, rdtype):
        try:
            records = fake_hosts[qname, rdtype]
        except KeyError:
            raise DNSException(f"unknown test host: {qname} {rdtype}")

        answers: List[Union[A, SRV]]
        if rdtype == "A":
            answers = [A("IN", "A", address) for address in records]
        else:
            answers = [srv_from_text(record) for record in records]
        return answers

    return fake_srv_