summaryrefslogtreecommitdiffstats
path: root/src/net/resolverdialfunc_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/resolverdialfunc_test.go')
-rw-r--r--src/net/resolverdialfunc_test.go327
1 files changed, 327 insertions, 0 deletions
diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go
new file mode 100644
index 0000000..1fb02b1
--- /dev/null
+++ b/src/net/resolverdialfunc_test.go
@@ -0,0 +1,327 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !js
+
+// Test that Resolver.Dial can be a func returning an in-memory net.Conn
+// speaking DNS.
+
+package net
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "reflect"
+ "sort"
+ "testing"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+func TestResolverDialFunc(t *testing.T) {
+ r := &Resolver{
+ PreferGo: true,
+ Dial: newResolverDialFunc(&resolverDialHandler{
+ StartDial: func(network, address string) error {
+ t.Logf("StartDial(%q, %q) ...", network, address)
+ return nil
+ },
+ Question: func(h dnsmessage.Header, q dnsmessage.Question) {
+ t.Logf("Header: %+v for %q (type=%v, class=%v)", h,
+ q.Name.String(), q.Type, q.Class)
+ },
+ // TODO: add test without HandleA* hooks specified at all, that Go
+ // doesn't issue retries; map to something terminal.
+ HandleA: func(w AWriter, name string) error {
+ w.AddIP([4]byte{1, 2, 3, 4})
+ w.AddIP([4]byte{5, 6, 7, 8})
+ return nil
+ },
+ HandleAAAA: func(w AAAAWriter, name string) error {
+ w.AddIP([16]byte{1: 1, 15: 15})
+ w.AddIP([16]byte{2: 2, 14: 14})
+ return nil
+ },
+ HandleSRV: func(w SRVWriter, name string) error {
+ w.AddSRV(1, 2, 80, "foo.bar.")
+ w.AddSRV(2, 3, 81, "bar.baz.")
+ return nil
+ },
+ }),
+ }
+ ctx := context.Background()
+ const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld."
+
+ t.Run("LookupIP", func(t *testing.T) {
+ ips, err := r.LookupIP(ctx, "ip", fakeDomain)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) {
+ t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want)
+ }
+ })
+
+ t.Run("LookupSRV", func(t *testing.T) {
+ _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := []*SRV{
+ {
+ Target: "foo.bar.",
+ Port: 80,
+ Priority: 1,
+ Weight: 2,
+ },
+ {
+ Target: "bar.baz.",
+ Port: 81,
+ Priority: 2,
+ Weight: 3,
+ },
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("wrong result. got:")
+ for _, r := range got {
+ t.Logf(" - %+v", r)
+ }
+ }
+ })
+}
+
+func sortedIPStrings(ips []IP) []string {
+ ret := make([]string, len(ips))
+ for i, ip := range ips {
+ ret[i] = ip.String()
+ }
+ sort.Strings(ret)
+ return ret
+}
+
+func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) {
+ return func(ctx context.Context, network, address string) (Conn, error) {
+ a := &resolverFuncConn{
+ h: h,
+ network: network,
+ address: address,
+ ttl: 10, // 10 second default if unset
+ }
+ if h.StartDial != nil {
+ if err := h.StartDial(network, address); err != nil {
+ return nil, err
+ }
+ }
+ return a, nil
+ }
+}
+
+type resolverDialHandler struct {
+ // StartDial, if non-nil, is called when Go first calls Resolver.Dial.
+ // Any error returned aborts the dial and is returned unwrapped.
+ StartDial func(network, address string) error
+
+ Question func(dnsmessage.Header, dnsmessage.Question)
+
+ // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2).
+ // A nil error means success.
+ HandleA func(w AWriter, name string) error
+ HandleAAAA func(w AAAAWriter, name string) error
+ HandleSRV func(w SRVWriter, name string) error
+}
+
+type ResponseWriter struct{ a *resolverFuncConn }
+
+func (w ResponseWriter) header() dnsmessage.ResourceHeader {
+ q := w.a.q
+ return dnsmessage.ResourceHeader{
+ Name: q.Name,
+ Type: q.Type,
+ Class: q.Class,
+ TTL: w.a.ttl,
+ }
+}
+
+// SetTTL sets the TTL for subsequent written resources.
+// Once a resource has been written, SetTTL calls are no-ops.
+// That is, it can only be called at most once, before anything
+// else is written.
+func (w ResponseWriter) SetTTL(seconds uint32) {
+ // ... intention is last one wins and mutates all previously
+ // written records too, but that's a little annoying.
+ // But it's also annoying if the requirement is it needs to be set
+ // last.
+ // And it's also annoying if it's possible for users to set
+ // different TTLs per Answer.
+ if w.a.wrote {
+ return
+ }
+ w.a.ttl = seconds
+
+}
+
+type AWriter struct{ ResponseWriter }
+
+func (w AWriter) AddIP(v4 [4]byte) {
+ w.a.wrote = true
+ err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4})
+ if err != nil {
+ panic(err)
+ }
+}
+
+type AAAAWriter struct{ ResponseWriter }
+
+func (w AAAAWriter) AddIP(v6 [16]byte) {
+ w.a.wrote = true
+ err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6})
+ if err != nil {
+ panic(err)
+ }
+}
+
+type SRVWriter struct{ ResponseWriter }
+
+// AddSRV adds a SRV record. The target name must end in a period and
+// be 63 bytes or fewer.
+func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error {
+ targetName, err := dnsmessage.NewName(target)
+ if err != nil {
+ return err
+ }
+ w.a.wrote = true
+ err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{
+ Priority: priority,
+ Weight: weight,
+ Port: port,
+ Target: targetName,
+ })
+ if err != nil {
+ panic(err) // internal fault, not user
+ }
+ return nil
+}
+
+var (
+ ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN
+ ErrRefused = errors.New("refused") // maps to RCode5, REFUSED
+)
+
+type resolverFuncConn struct {
+ h *resolverDialHandler
+ network string
+ address string
+ builder *dnsmessage.Builder
+ q dnsmessage.Question
+ ttl uint32
+ wrote bool
+
+ rbuf bytes.Buffer
+}
+
+func (*resolverFuncConn) Close() error { return nil }
+func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} }
+func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} }
+func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil }
+func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil }
+func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func (a *resolverFuncConn) Read(p []byte) (n int, err error) {
+ return a.rbuf.Read(p)
+}
+
+func (a *resolverFuncConn) Write(packet []byte) (n int, err error) {
+ if len(packet) < 2 {
+ return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet))
+ }
+ reqLen := int(packet[0])<<8 | int(packet[1])
+ req := packet[2:]
+ if len(req) != reqLen {
+ return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req))
+ }
+
+ var parser dnsmessage.Parser
+ h, err := parser.Start(req)
+ if err != nil {
+ // TODO: hook
+ return 0, err
+ }
+ q, err := parser.Question()
+ hadQ := (err == nil)
+ if err == nil && a.h.Question != nil {
+ a.h.Question(h, q)
+ }
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return 0, err
+ }
+
+ resh := h
+ resh.Response = true
+ resh.Authoritative = true
+ if hadQ {
+ resh.RCode = dnsmessage.RCodeSuccess
+ } else {
+ resh.RCode = dnsmessage.RCodeNotImplemented
+ }
+ a.rbuf.Grow(514)
+ a.rbuf.WriteByte('X') // reserved header for beu16 length
+ a.rbuf.WriteByte('Y') // reserved header for beu16 length
+ builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh)
+ a.builder = &builder
+ if hadQ {
+ a.q = q
+ a.builder.StartQuestions()
+ err := a.builder.Question(q)
+ if err != nil {
+ return 0, fmt.Errorf("Question: %w", err)
+ }
+ a.builder.StartAnswers()
+ switch q.Type {
+ case dnsmessage.TypeA:
+ if a.h.HandleA != nil {
+ resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ case dnsmessage.TypeAAAA:
+ if a.h.HandleAAAA != nil {
+ resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ case dnsmessage.TypeSRV:
+ if a.h.HandleSRV != nil {
+ resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String()))
+ }
+ }
+ }
+ tcpRes, err := builder.Finish()
+ if err != nil {
+ return 0, fmt.Errorf("Finish: %w", err)
+ }
+
+ n = len(tcpRes) - 2
+ tcpRes[0] = byte(n >> 8)
+ tcpRes[1] = byte(n)
+ a.rbuf.Write(tcpRes[2:])
+
+ return len(packet), nil
+}
+
+type someaddr struct{}
+
+func (someaddr) Network() string { return "unused" }
+func (someaddr) String() string { return "unused-someaddr" }
+
+func mapRCode(err error) dnsmessage.RCode {
+ switch err {
+ case nil:
+ return dnsmessage.RCodeSuccess
+ case ErrNotExist:
+ return dnsmessage.RCodeNameError
+ case ErrRefused:
+ return dnsmessage.RCodeRefused
+ default:
+ return dnsmessage.RCodeServerFailure
+ }
+}