summaryrefslogtreecommitdiffstats
path: root/src/net/dnsclient_unix_test.go
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 13:14:23 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 13:14:23 +0000
commit73df946d56c74384511a194dd01dbe099584fd1a (patch)
treefd0bcea490dd81327ddfbb31e215439672c9a068 /src/net/dnsclient_unix_test.go
parentInitial commit. (diff)
downloadgolang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.tar.xz
golang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.zip
Adding upstream version 1.16.10.upstream/1.16.10upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/net/dnsclient_unix_test.go2121
1 files changed, 2121 insertions, 0 deletions
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go
new file mode 100644
index 0000000..e7f7621
--- /dev/null
+++ b/src/net/dnsclient_unix_test.go
@@ -0,0 +1,2121 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
+
+package net
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path"
+ "reflect"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+var goResolver = Resolver{PreferGo: true}
+
+// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
+var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
+
+// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
+var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+
+func mustNewName(name string) dnsmessage.Name {
+ nn, err := dnsmessage.NewName(name)
+ if err != nil {
+ panic(fmt.Sprint("creating name: ", err))
+ }
+ return nn
+}
+
+func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
+ return dnsmessage.Question{
+ Name: mustNewName(name),
+ Type: qtype,
+ Class: class,
+ }
+}
+
+var dnsTransportFallbackTests = []struct {
+ server string
+ question dnsmessage.Question
+ timeout int
+ rcode dnsmessage.RCode
+}{
+ // Querying "com." with qtype=255 usually makes an answer
+ // which requires more than 512 bytes.
+ {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
+ {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
+}
+
+func TestDNSTransportFallback(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if n == "udp" {
+ r.Header.Truncated = true
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ for _, tt := range dnsTransportFallbackTests {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
+ continue
+ }
+ }
+}
+
+// See RFC 6761 for further information about the reserved, pseudo
+// domain names.
+var specialDomainNameTests = []struct {
+ question dnsmessage.Question
+ rcode dnsmessage.RCode
+}{
+ // Name resolution APIs and libraries should not recognize the
+ // followings as special.
+ {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
+
+ // Name resolution APIs and libraries should recognize the
+ // followings as special and should not send any queries.
+ // Though, we test those names here for verifying negative
+ // answers at DNS query-response interaction level.
+ {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+}
+
+func TestSpecialDomainName(t *testing.T) {
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ switch q.Questions[0].Name.String() {
+ case "example.com.":
+ r.Header.RCode = dnsmessage.RCodeSuccess
+ default:
+ r.Header.RCode = dnsmessage.RCodeNameError
+ }
+
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ server := "8.8.8.8:53"
+ for _, tt := range specialDomainNameTests {
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
+ continue
+ }
+ }
+}
+
+// Issue 13705: don't try to resolve onion addresses, etc
+func TestAvoidDNSName(t *testing.T) {
+ tests := []struct {
+ name string
+ avoid bool
+ }{
+ {"foo.com", false},
+ {"foo.com.", false},
+
+ {"foo.onion.", true},
+ {"foo.onion", true},
+ {"foo.ONION", true},
+ {"foo.ONION.", true},
+
+ // But do resolve *.local address; Issue 16739
+ {"foo.local.", false},
+ {"foo.local", false},
+ {"foo.LOCAL", false},
+ {"foo.LOCAL.", false},
+
+ {"", true}, // will be rejected earlier too
+
+ // Without stuff before onion/local, they're fine to
+ // use DNS. With a search path,
+ // "onion.vegetables.com" can use DNS. Without a
+ // search path (or with a trailing dot), the queries
+ // are just kinda useless, but don't reveal anything
+ // private.
+ {"local", false},
+ {"onion", false},
+ {"local.", false},
+ {"onion.", false},
+ }
+ for _, tt := range tests {
+ got := avoidDNS(tt.name)
+ if got != tt.avoid {
+ t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid)
+ }
+ }
+}
+
+var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+ return r, nil
+}}
+
+// Issue 13705: don't try to resolve onion addresses, etc
+func TestLookupTorOnion(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
+ addrs, err := r.LookupIPAddr(context.Background(), "foo.onion")
+ if err != nil {
+ t.Fatalf("lookup = %v; want nil", err)
+ }
+ if len(addrs) > 0 {
+ t.Errorf("unexpected addresses: %v", addrs)
+ }
+}
+
+type resolvConfTest struct {
+ dir string
+ path string
+ *resolverConfig
+}
+
+func newResolvConfTest() (*resolvConfTest, error) {
+ dir, err := os.MkdirTemp("", "go-resolvconftest")
+ if err != nil {
+ return nil, err
+ }
+ conf := &resolvConfTest{
+ dir: dir,
+ path: path.Join(dir, "resolv.conf"),
+ resolverConfig: &resolvConf,
+ }
+ conf.initOnce.Do(conf.init)
+ return conf, nil
+}
+
+func (conf *resolvConfTest) writeAndUpdate(lines []string) error {
+ f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ if err != nil {
+ return err
+ }
+ if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil {
+ f.Close()
+ return err
+ }
+ f.Close()
+ if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error {
+ dnsConf := dnsReadConfig(name)
+ conf.mu.Lock()
+ conf.dnsConfig = dnsConf
+ conf.mu.Unlock()
+ for i := 0; i < 5; i++ {
+ if conf.tryAcquireSema() {
+ conf.lastChecked = lastChecked
+ conf.releaseSema()
+ return nil
+ }
+ }
+ return fmt.Errorf("tryAcquireSema for %s failed", name)
+}
+
+func (conf *resolvConfTest) servers() []string {
+ conf.mu.RLock()
+ servers := conf.dnsConfig.servers
+ conf.mu.RUnlock()
+ return servers
+}
+
+func (conf *resolvConfTest) teardown() error {
+ err := conf.forceUpdate("/etc/resolv.conf", time.Time{})
+ os.RemoveAll(conf.dir)
+ return err
+}
+
+var updateResolvConfTests = []struct {
+ name string // query name
+ lines []string // resolver configuration lines
+ servers []string // expected name servers
+}{
+ {
+ name: "golang.org",
+ lines: []string{"nameserver 8.8.8.8"},
+ servers: []string{"8.8.8.8:53"},
+ },
+ {
+ name: "",
+ lines: nil, // an empty resolv.conf should use defaultNS as name servers
+ servers: defaultNS,
+ },
+ {
+ name: "www.example.com",
+ lines: []string{"nameserver 8.8.4.4"},
+ servers: []string{"8.8.4.4:53"},
+ },
+}
+
+func TestUpdateResolvConf(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ for i, tt := range updateResolvConfTests {
+ if err := conf.writeAndUpdate(tt.lines); err != nil {
+ t.Error(err)
+ continue
+ }
+ if tt.name != "" {
+ var wg sync.WaitGroup
+ const N = 10
+ wg.Add(N)
+ for j := 0; j < N; j++ {
+ go func(name string) {
+ defer wg.Done()
+ ips, err := r.LookupIPAddr(context.Background(), name)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if len(ips) == 0 {
+ t.Errorf("no records for %s", name)
+ return
+ }
+ }(tt.name)
+ }
+ wg.Wait()
+ }
+ servers := conf.servers()
+ if !reflect.DeepEqual(servers, tt.servers) {
+ t.Errorf("#%d: got %v; want %v", i, servers, tt.servers)
+ continue
+ }
+ }
+}
+
+var goLookupIPWithResolverConfigTests = []struct {
+ name string
+ lines []string // resolver configuration lines
+ error
+ a, aaaa bool // whether response contains A, AAAA-record
+}{
+ // no records, transport timeout
+ {
+ "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
+ []string{
+ "options timeout:1 attempts:1",
+ "nameserver 255.255.255.255", // please forgive us for abuse of limited broadcast address
+ },
+ &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true},
+ false, false,
+ },
+
+ // no records, non-existent domain
+ {
+ "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j",
+ []string{
+ "options timeout:3 attempts:1",
+ "nameserver 8.8.8.8",
+ },
+ &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false},
+ false, false,
+ },
+
+ // a few A records, no AAAA records
+ {
+ "ipv4.google.com.",
+ []string{
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ true, false,
+ },
+ {
+ "ipv4.google.com",
+ []string{
+ "domain golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, false,
+ },
+ {
+ "ipv4.google.com",
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, false,
+ },
+
+ // no A records, a few AAAA records
+ {
+ "ipv6.google.com.",
+ []string{
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ false, true,
+ },
+ {
+ "ipv6.google.com",
+ []string{
+ "domain golang.org",
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ false, true,
+ },
+ {
+ "ipv6.google.com",
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 8.8.8.8",
+ "nameserver 2001:4860:4860::8888",
+ },
+ nil,
+ false, true,
+ },
+
+ // both A and AAAA records
+ {
+ "hostname.as112.net", // see RFC 7534
+ []string{
+ "domain golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, true,
+ },
+ {
+ "hostname.as112.net", // see RFC 7534
+ []string{
+ "search x.golang.org y.golang.org",
+ "nameserver 2001:4860:4860::8888",
+ "nameserver 8.8.8.8",
+ },
+ nil,
+ true, true,
+ },
+}
+
+func TestGoLookupIPWithResolverConfig(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ switch s {
+ case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
+ break
+ default:
+ time.Sleep(10 * time.Millisecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ switch question.Name.String() {
+ case "hostname.as112.net.":
+ break
+ case "ipv4.google.com.":
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ })
+ default:
+
+ }
+ case dnsmessage.TypeAAAA:
+ switch question.Name.String() {
+ case "hostname.as112.net.":
+ break
+ case "ipv6.google.com.":
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ })
+ }
+ }
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ for _, tt := range goLookupIPWithResolverConfigTests {
+ if err := conf.writeAndUpdate(tt.lines); err != nil {
+ t.Error(err)
+ continue
+ }
+ addrs, err := r.LookupIPAddr(context.Background(), tt.name)
+ if err != nil {
+ if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
+ t.Errorf("got %v; want %v", err, tt.error)
+ }
+ continue
+ }
+ if len(addrs) == 0 {
+ t.Errorf("no records for %s", tt.name)
+ }
+ if !tt.a && !tt.aaaa && len(addrs) > 0 {
+ t.Errorf("unexpected %v for %s", addrs, tt.name)
+ }
+ for _, addr := range addrs {
+ if !tt.a && addr.IP.To4() != nil {
+ t.Errorf("got %v; must not be IPv4 address", addr)
+ }
+ if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil {
+ t.Errorf("got %v; must not be IPv6 address", addr)
+ }
+ }
+ }
+}
+
+// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
+func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // Add a config that simulates no dns servers being available.
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{}); err != nil {
+ t.Fatal(err)
+ }
+ // Redirect host file lookups.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/hosts"
+
+ for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
+ name := fmt.Sprintf("order %v", order)
+
+ // First ensure that we get an error when contacting a non-existent host.
+ _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
+ if err == nil {
+ t.Errorf("%s: expected error while looking up name not in hosts file", name)
+ continue
+ }
+
+ // Now check that we get an address when the name appears in the hosts file.
+ addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
+ if err != nil {
+ t.Errorf("%s: expected to successfully lookup host entry", name)
+ continue
+ }
+ if len(addrs) != 1 {
+ t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
+ continue
+ }
+ if got, want := addrs[0].String(), "127.1.1.1"; got != want {
+ t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
+ }
+ }
+}
+
+// Issue 12712.
+// When using search domains, return the error encountered
+// querying the original name instead of an error encountered
+// querying a generated name.
+func TestErrorForOriginalNameWhenSearching(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ const fqdn = "doesnotexist.domain"
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
+ t.Fatal(err)
+ }
+
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ switch q.Questions[0].Name.String() {
+ case fqdn + ".servfail.":
+ r.Header.RCode = dnsmessage.RCodeServerFailure
+ default:
+ r.Header.RCode = dnsmessage.RCodeNameError
+ }
+
+ return r, nil
+ }}
+
+ cases := []struct {
+ strictErrors bool
+ wantErr *DNSError
+ }{
+ {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}},
+ {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error(), IsNotFound: true}},
+ }
+ for _, tt := range cases {
+ r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext}
+ _, err = r.LookupIPAddr(context.Background(), fqdn)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+
+ want := tt.wantErr
+ if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary {
+ t.Errorf("got %v; want %v", err, want)
+ }
+ }
+}
+
+// Issue 15434. If a name server gives a lame referral, continue to the next.
+func TestIgnoreLameReferrals(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral
+ "nameserver 192.0.2.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+
+ if s == "192.0.2.2:53" {
+ r.Header.RecursionAvailable = true
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ }
+ }
+
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if got := len(addrs); got != 1 {
+ t.Fatalf("got %d addresses, want 1", got)
+ }
+
+ if got, want := addrs[0].String(), "192.0.2.1"; got != want {
+ t.Fatalf("got address %v, want %v", got, want)
+ }
+}
+
+func BenchmarkGoLookupIP(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "www.example.com")
+ }
+}
+
+func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "some.nonexistent")
+ }
+}
+
+func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer conf.teardown()
+
+ lines := []string{
+ "nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737
+ "nameserver 8.8.8.8",
+ }
+ if err := conf.writeAndUpdate(lines); err != nil {
+ b.Fatal(err)
+ }
+ ctx := context.Background()
+ b.ReportAllocs()
+
+ for i := 0; i < b.N; i++ {
+ goResolver.LookupIPAddr(ctx, "www.example.com")
+ }
+}
+
+type fakeDNSServer struct {
+ rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+ alwaysTCP bool
+}
+
+func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
+ if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
+ return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
+ }
+ return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
+}
+
+type fakeDNSConn struct {
+ Conn
+ tcp bool
+ server *fakeDNSServer
+ n string
+ s string
+ q dnsmessage.Message
+ t time.Time
+ buf []byte
+}
+
+func (f *fakeDNSConn) Close() error {
+ return nil
+}
+
+func (f *fakeDNSConn) Read(b []byte) (int, error) {
+ if len(f.buf) > 0 {
+ n := copy(b, f.buf)
+ f.buf = f.buf[n:]
+ return n, nil
+ }
+
+ resp, err := f.server.rh(f.n, f.s, f.q, f.t)
+ if err != nil {
+ return 0, err
+ }
+
+ bb := make([]byte, 2, 514)
+ bb, err = resp.AppendPack(bb)
+ if err != nil {
+ return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
+ }
+
+ if f.tcp {
+ l := len(bb) - 2
+ bb[0] = byte(l >> 8)
+ bb[1] = byte(l)
+ f.buf = bb
+ return f.Read(b)
+ }
+
+ bb = bb[2:]
+ if len(b) < len(bb) {
+ return 0, errors.New("read would fragment DNS message")
+ }
+
+ copy(b, bb)
+ return len(bb), nil
+}
+
+func (f *fakeDNSConn) Write(b []byte) (int, error) {
+ if f.tcp && len(b) >= 2 {
+ b = b[2:]
+ }
+ if f.q.Unpack(b) != nil {
+ return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
+ }
+ return len(b), nil
+}
+
+func (f *fakeDNSConn) SetDeadline(t time.Time) error {
+ f.t = t
+ return nil
+}
+
+type fakeDNSPacketConn struct {
+ PacketConn
+ fakeDNSConn
+}
+
+func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
+ return f.fakeDNSConn.SetDeadline(t)
+}
+
+func (f *fakeDNSPacketConn) Close() error {
+ return f.fakeDNSConn.Close()
+}
+
+// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
+func TestIgnoreDNSForgeries(t *testing.T) {
+ c, s := Pipe()
+ go func() {
+ b := make([]byte, 512)
+ n, err := s.Read(b)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+
+ var msg dnsmessage.Message
+ if msg.Unpack(b[:n]) != nil {
+ t.Error("invalid DNS query:", err)
+ return
+ }
+
+ s.Write([]byte("garbage DNS response packet"))
+
+ msg.Header.Response = true
+ msg.Header.ID++ // make invalid ID
+
+ if b, err = msg.Pack(); err != nil {
+ t.Error("failed to pack DNS response:", err)
+ return
+ }
+ s.Write(b)
+
+ msg.Header.ID-- // restore original ID
+ msg.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+
+ b, err = msg.Pack()
+ if err != nil {
+ t.Error("failed to pack DNS response:", err)
+ return
+ }
+ s.Write(b)
+ }()
+
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: 42,
+ },
+ Questions: []dnsmessage.Question{
+ {
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ },
+ }
+
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+
+ p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
+ if err != nil {
+ t.Fatalf("dnsPacketRoundTrip failed: %v", err)
+ }
+
+ p.SkipAllQuestions()
+ as, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal("AllAnswers failed:", err)
+ }
+ if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
+ t.Errorf("got address %v, want %v", got, TestAddr)
+ }
+}
+
+// Issue 16865. If a name server times out, continue to the next.
+func TestRetryTimeout(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ testConf := []string{
+ "nameserver 192.0.2.1", // the one that will timeout
+ "nameserver 192.0.2.2",
+ }
+ if err := conf.writeAndUpdate(testConf); err != nil {
+ t.Fatal(err)
+ }
+
+ var deadline0 time.Time
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q, deadline)
+
+ if deadline.IsZero() {
+ t.Error("zero deadline")
+ }
+
+ if s == "192.0.2.1:53" {
+ deadline0 = deadline
+ time.Sleep(10 * time.Millisecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }
+
+ if deadline.Equal(deadline0) {
+ t.Error("deadline didn't change")
+ }
+
+ return mockTXTResponse(q), nil
+ }}
+ r := &Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ _, err = r.LookupTXT(context.Background(), "www.golang.org")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if deadline0.IsZero() {
+ t.Error("deadline0 still zero", deadline0)
+ }
+}
+
+func TestRotate(t *testing.T) {
+ // without rotation, always uses the first server
+ testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
+
+ // with rotation, rotates through back to first
+ testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
+}
+
+func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ var confLines []string
+ for _, ns := range nameservers {
+ confLines = append(confLines, "nameserver "+ns)
+ }
+ if rotate {
+ confLines = append(confLines, "options rotate")
+ }
+
+ if err := conf.writeAndUpdate(confLines); err != nil {
+ t.Fatal(err)
+ }
+
+ var usedServers []string
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ usedServers = append(usedServers, s)
+ return mockTXTResponse(q), nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // len(nameservers) + 1 to allow rotation to get back to start
+ for i := 0; i < len(nameservers)+1; i++ {
+ if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if !reflect.DeepEqual(usedServers, wantServers) {
+ t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
+ }
+}
+
+func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RecursionAvailable: true,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"ok"},
+ },
+ },
+ },
+ }
+
+ return r
+}
+
+// Issue 17448. With StrictErrors enabled, temporary errors should make
+// LookupIP fail rather than return a partial result.
+func TestStrictErrorsLookupIP(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ confData := []string{
+ "nameserver 192.0.2.53",
+ "search x.golang.org y.golang.org",
+ }
+ if err := conf.writeAndUpdate(confData); err != nil {
+ t.Fatal(err)
+ }
+
+ const name = "test-issue19592"
+ const server = "192.0.2.53:53"
+ const searchX = "test-issue19592.x.golang.org."
+ const searchY = "test-issue19592.y.golang.org."
+ const ip4 = "192.0.2.1"
+ const ip6 = "2001:db8::1"
+
+ type resolveWhichEnum int
+ const (
+ resolveOK resolveWhichEnum = iota
+ resolveOpError
+ resolveServfail
+ resolveTimeout
+ )
+
+ makeTempError := func(err string) error {
+ return &DNSError{
+ Err: err,
+ Name: name,
+ Server: server,
+ IsTemporary: true,
+ }
+ }
+ makeTimeout := func() error {
+ return &DNSError{
+ Err: os.ErrDeadlineExceeded.Error(),
+ Name: name,
+ Server: server,
+ IsTimeout: true,
+ }
+ }
+ makeNxDomain := func() error {
+ return &DNSError{
+ Err: errNoSuchHost.Error(),
+ Name: name,
+ Server: server,
+ IsNotFound: true,
+ }
+ }
+
+ cases := []struct {
+ desc string
+ resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
+ wantStrictErr error
+ wantLaxErr error
+ wantIPs []string
+ }{
+ {
+ desc: "No errors",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ return resolveOK
+ },
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX error fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX IPv4-only timeout fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchX IPv6-only servfail fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
+ return resolveServfail
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTempError("server misbehaving"),
+ wantIPs: []string{ip4, ip6},
+ },
+ {
+ desc: "searchY error always fails",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantLaxErr: makeNxDomain(), // This one reaches the "test." FQDN.
+ },
+ {
+ desc: "searchY IPv4-only socket error fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
+ return resolveOpError
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTempError("write: socket on fire"),
+ wantIPs: []string{ip6},
+ },
+ {
+ desc: "searchY IPv6-only timeout fails in strict mode",
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
+ return resolveTimeout
+ }
+ return resolveOK
+ },
+ wantStrictErr: makeTimeout(),
+ wantIPs: []string{ip4},
+ },
+ }
+
+ for i, tt := range cases {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+
+ switch tt.resolveWhich(q.Questions[0]) {
+ case resolveOK:
+ // Handle below.
+ case resolveOpError:
+ return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
+ case resolveServfail:
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ Questions: q.Questions,
+ }, nil
+ case resolveTimeout:
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ default:
+ t.Fatal("Impossible resolveWhich")
+ }
+
+ switch q.Questions[0].Name.String() {
+ case searchX, name + ".":
+ // Return NXDOMAIN to utilize the search list.
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ },
+ Questions: q.Questions,
+ }, nil
+ case searchY:
+ // Return records below.
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
+ }
+
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ }
+ case dnsmessage.TypeAAAA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ },
+ }
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
+ }
+ return r, nil
+ }}
+
+ for _, strict := range []bool{true, false} {
+ r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext}
+ ips, err := r.LookupIPAddr(context.Background(), name)
+
+ var wantErr error
+ if strict {
+ wantErr = tt.wantStrictErr
+ } else {
+ wantErr = tt.wantLaxErr
+ }
+ if !reflect.DeepEqual(err, wantErr) {
+ t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr)
+ }
+
+ gotIPs := map[string]struct{}{}
+ for _, ip := range ips {
+ gotIPs[ip.String()] = struct{}{}
+ }
+ wantIPs := map[string]struct{}{}
+ if wantErr == nil {
+ for _, ip := range tt.wantIPs {
+ wantIPs[ip] = struct{}{}
+ }
+ }
+ if !reflect.DeepEqual(gotIPs, wantIPs) {
+ t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs)
+ }
+ }
+ }
+}
+
+// Issue 17448. With StrictErrors enabled, temporary errors should make
+// LookupTXT stop walking the search list.
+func TestStrictErrorsLookupTXT(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ confData := []string{
+ "nameserver 192.0.2.53",
+ "search x.golang.org y.golang.org",
+ }
+ if err := conf.writeAndUpdate(confData); err != nil {
+ t.Fatal(err)
+ }
+
+ const name = "test"
+ const server = "192.0.2.53:53"
+ const searchX = "test.x.golang.org."
+ const searchY = "test.y.golang.org."
+ const txt = "Hello World"
+
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ t.Log(s, q)
+
+ switch q.Questions[0].Name.String() {
+ case searchX:
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ case searchY:
+ return mockTXTResponse(q), nil
+ default:
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
+ }
+ }}
+
+ for _, strict := range []bool{true, false} {
+ r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
+ p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
+ var wantErr error
+ var wantRRs int
+ if strict {
+ wantErr = &DNSError{
+ Err: os.ErrDeadlineExceeded.Error(),
+ Name: name,
+ Server: server,
+ IsTimeout: true,
+ }
+ } else {
+ wantRRs = 1
+ }
+ if !reflect.DeepEqual(err, wantErr) {
+ t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
+ }
+ a, err := p.AllAnswers()
+ if err != nil {
+ a = nil
+ }
+ if len(a) != wantRRs {
+ t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
+ }
+ }
+}
+
+// Test for a race between uninstalling the test hooks and closing a
+// socket connection. This used to fail when testing with -race.
+func TestDNSGoroutineRace(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
+ time.Sleep(10 * time.Microsecond)
+ return dnsmessage.Message{}, os.ErrDeadlineExceeded
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ // The timeout here is less than the timeout used by the server,
+ // so the goroutine started to query the (fake) server will hang
+ // around after this test is done if we don't call dnsWaitGroup.Wait.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond)
+ defer cancel()
+ _, err := r.LookupIPAddr(ctx, "where.are.they.now")
+ if err == nil {
+ t.Fatal("fake DNS lookup unexpectedly succeeded")
+ }
+}
+
+func lookupWithFake(fake fakeDNSServer, name string, typ dnsmessage.Type) error {
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ resolvConf.mu.RLock()
+ conf := resolvConf.dnsConfig
+ resolvConf.mu.RUnlock()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, _, err := r.tryOneName(ctx, conf, name, typ)
+ return err
+}
+
+// Issue 8434: verify that Temporary returns true on an error when rcode
+// is SERVFAIL
+func TestIssue8434(t *testing.T) {
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ }, "golang.org.", dnsmessage.TypeALL)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ } else if !ne.Temporary() {
+ t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsTemporary {
+ t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
+ }
+}
+
+func TestIssueNoSuchHostExists(t *testing.T) {
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ }, "golang.org.", dnsmessage.TypeALL)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if _, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsNotFound {
+ t.Fatalf("IsNotFound = false for err = %#v; want IsNotFound == true", err)
+ }
+}
+
+// TestNoSuchHost verifies that tryOneName works correctly when the domain does
+// not exist.
+//
+// Issue 12778: verify that NXDOMAIN without RA bit errors as "no such host"
+// and not "server misbehaving"
+//
+// Issue 25336: verify that NXDOMAIN errors fail fast.
+//
+// Issue 27525: verify that empty answers fail fast.
+func TestNoSuchHost(t *testing.T) {
+ tests := []struct {
+ name string
+ f func(string, string, dnsmessage.Message, time.Time) (dnsmessage.Message, error)
+ }{
+ {
+ "NXDOMAIN",
+ func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ RecursionAvailable: false,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ },
+ {
+ "no answers",
+ func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ RecursionAvailable: false,
+ Authoritative: true,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ lookups := 0
+ err := lookupWithFake(fakeDNSServer{
+ rh: func(n, s string, q dnsmessage.Message, d time.Time) (dnsmessage.Message, error) {
+ lookups++
+ return test.f(n, s, q, d)
+ },
+ }, ".", dnsmessage.TypeALL)
+
+ if lookups != 1 {
+ t.Errorf("got %d lookups, wanted 1", lookups)
+ }
+
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ de, ok := err.(*DNSError)
+ if !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ }
+ if de.Err != errNoSuchHost.Error() {
+ t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
+ }
+ if !de.IsNotFound {
+ t.Fatalf("IsNotFound = %v wanted true", de.IsNotFound)
+ }
+ })
+ }
+}
+
+// Issue 26573: verify that Conns that don't implement PacketConn are treated
+// as streams even when udp was requested.
+func TestDNSDialTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ },
+ alwaysTCP: true,
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx := context.Background()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP)
+ if err != nil {
+ t.Fatal("exhange failed:", err)
+ }
+}
+
+// Issue 27763: verify that two strings in one TXT record are concatenated.
+func TestTXTRecordTwoStrings(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"string1 ", "string2"},
+ },
+ },
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"onestring"},
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ txt, err := r.lookupTXT(context.Background(), "golang.org")
+ if err != nil {
+ t.Fatal("LookupTXT failed:", err)
+ }
+ if want := 2; len(txt) != want {
+ t.Fatalf("len(txt), got %d, want %d", len(txt), want)
+ }
+ if want := "string1 string2"; txt[0] != want {
+ t.Errorf("txt[0], got %q, want %q", txt[0], want)
+ }
+ if want := "onestring"; txt[1] != want {
+ t.Errorf("txt[1], got %q, want %q", txt[1], want)
+ }
+}
+
+// Issue 29644: support single-request resolv.conf option in pure Go resolver.
+// The A and AAAA queries will be sent sequentially, not in parallel.
+func TestSingleRequestLookup(t *testing.T) {
+ defer dnsWaitGroup.Wait()
+ var (
+ firstcalled int32
+ ipv4 int32 = 1
+ ipv6 int32 = 2
+ )
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ },
+ Questions: q.Questions,
+ }
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ if question.Name.String() == "slowipv4.example.net." {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if !atomic.CompareAndSwapInt32(&firstcalled, 0, ipv4) {
+ t.Errorf("the A query was received after the AAAA query !")
+ }
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ })
+ case dnsmessage.TypeAAAA:
+ atomic.CompareAndSwapInt32(&firstcalled, 0, ipv6)
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: TestAddr6,
+ },
+ })
+ }
+ }
+ return r, nil
+ }}
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+ if err := conf.writeAndUpdate([]string{"options single-request"}); err != nil {
+ t.Fatal(err)
+ }
+ for _, name := range []string{"hostname.example.net", "slowipv4.example.net"} {
+ firstcalled = 0
+ _, err := r.LookupIPAddr(context.Background(), name)
+ if err != nil {
+ t.Error(err)
+ }
+ }
+}
+
+// Issue 29358. Add configuration knob to force TCP-only DNS requests in the pure Go resolver.
+func TestDNSUseTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ if n == "udp" {
+ t.Fatal("udp protocol was used instead of tcp")
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly)
+ if err != nil {
+ t.Fatal("exchange failed:", err)
+ }
+}
+
+// Issue 34660: PTR response with non-PTR answers should ignore non-PTR
+func TestPTRandNonPTR(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("golang.org."),
+ },
+ },
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"PTR 8 6 60 ..."}, // fake RRSIG
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ names, err := r.lookupAddr(context.Background(), "192.0.2.123")
+ if err != nil {
+ t.Fatalf("LookupAddr: %v", err)
+ }
+ if want := []string{"golang.org."}; !reflect.DeepEqual(names, want) {
+ t.Errorf("names = %q; want %q", names, want)
+ }
+}
+
+func TestCVE202133195(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ RecursionAvailable: true,
+ },
+ Questions: q.Questions,
+ }
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeCNAME:
+ r.Answers = []dnsmessage.Resource{}
+ case dnsmessage.TypeA: // CNAME lookup uses a A/AAAA as a proxy
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
+ },
+ },
+ )
+ case dnsmessage.TypeSRV:
+ n := q.Questions[0].Name
+ if n.String() == "_hdr._tcp.golang.org." {
+ n = dnsmessage.MustNewName("<html>.golang.org.")
+ }
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: n,
+ Type: dnsmessage.TypeSRV,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.SRVResource{
+ Target: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: n,
+ Type: dnsmessage.TypeSRV,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.SRVResource{
+ Target: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypeMX:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypeNS:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypeNS,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.NSResource{
+ NS: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ case dnsmessage.TypePTR:
+ r.Answers = append(r.Answers,
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("<html>.golang.org."),
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("<html>.golang.org."),
+ },
+ },
+ dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: dnsmessage.MustNewName("good.golang.org."),
+ Type: dnsmessage.TypePTR,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.PTRResource{
+ PTR: dnsmessage.MustNewName("good.golang.org."),
+ },
+ },
+ )
+ }
+ return r, nil
+ },
+ }
+
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ // Change the default resolver to match our manipulated resolver
+ originalDefault := DefaultResolver
+ DefaultResolver = &r
+ defer func() { DefaultResolver = originalDefault }()
+ // Redirect host file lookups.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/hosts"
+
+ tests := []struct {
+ name string
+ f func(*testing.T)
+ }{
+ {
+ name: "CNAME",
+ f: func(t *testing.T) {
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ _, err := r.LookupCNAME(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ _, err = LookupCNAME("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ },
+ },
+ {
+ name: "SRV (bad record)",
+ f: func(t *testing.T) {
+ expected := []*SRV{
+ {
+ Target: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ _, records, err = LookupSRV("target", "tcp", "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "SRV (bad header)",
+ f: func(t *testing.T) {
+ _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.")
+ if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+ t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected)
+ }
+ _, _, err = LookupSRV("hdr", "tcp", "golang.org.")
+ if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected {
+ t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected)
+ }
+ },
+ },
+ {
+ name: "MX",
+ f: func(t *testing.T) {
+ expected := []*MX{
+ {
+ Host: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ records, err := r.LookupMX(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupMX("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "NS",
+ f: func(t *testing.T) {
+ expected := []*NS{
+ {
+ Host: "good.golang.org.",
+ },
+ }
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"}
+ records, err := r.LookupNS(context.Background(), "golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupNS("golang.org")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ {
+ name: "Addr",
+ f: func(t *testing.T) {
+ expected := []string{"good.golang.org."}
+ expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"}
+ records, err := r.LookupAddr(context.Background(), "192.0.2.42")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ records, err = LookupAddr("192.0.2.42")
+ if err.Error() != expectedErr.Error() {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if !reflect.DeepEqual(records, expected) {
+ t.Error("Unexpected record set")
+ }
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, tc.f)
+ }
+}
+
+func TestNullMX(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeMX,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.MXResource{
+ MX: dnsmessage.MustNewName("."),
+ },
+ },
+ },
+ }
+ return r, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ rrset, err := r.LookupMX(context.Background(), "golang.org")
+ if err != nil {
+ t.Fatalf("LookupMX: %v", err)
+ }
+ if want := []*MX{&MX{Host: "."}}; !reflect.DeepEqual(rrset, want) {
+ records := []string{}
+ for _, rr := range rrset {
+ records = append(records, fmt.Sprintf("%v", rr))
+ }
+ t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0])
+ }
+}