diff options
Diffstat (limited to 'src/net/resolverdialfunc_test.go')
-rw-r--r-- | src/net/resolverdialfunc_test.go | 327 |
1 files changed, 327 insertions, 0 deletions
diff --git a/src/net/resolverdialfunc_test.go b/src/net/resolverdialfunc_test.go new file mode 100644 index 0000000..1fb02b1 --- /dev/null +++ b/src/net/resolverdialfunc_test.go @@ -0,0 +1,327 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !js + +// Test that Resolver.Dial can be a func returning an in-memory net.Conn +// speaking DNS. + +package net + +import ( + "bytes" + "context" + "errors" + "fmt" + "reflect" + "sort" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +func TestResolverDialFunc(t *testing.T) { + r := &Resolver{ + PreferGo: true, + Dial: newResolverDialFunc(&resolverDialHandler{ + StartDial: func(network, address string) error { + t.Logf("StartDial(%q, %q) ...", network, address) + return nil + }, + Question: func(h dnsmessage.Header, q dnsmessage.Question) { + t.Logf("Header: %+v for %q (type=%v, class=%v)", h, + q.Name.String(), q.Type, q.Class) + }, + // TODO: add test without HandleA* hooks specified at all, that Go + // doesn't issue retries; map to something terminal. + HandleA: func(w AWriter, name string) error { + w.AddIP([4]byte{1, 2, 3, 4}) + w.AddIP([4]byte{5, 6, 7, 8}) + return nil + }, + HandleAAAA: func(w AAAAWriter, name string) error { + w.AddIP([16]byte{1: 1, 15: 15}) + w.AddIP([16]byte{2: 2, 14: 14}) + return nil + }, + HandleSRV: func(w SRVWriter, name string) error { + w.AddSRV(1, 2, 80, "foo.bar.") + w.AddSRV(2, 3, 81, "bar.baz.") + return nil + }, + }), + } + ctx := context.Background() + const fakeDomain = "something-that-is-a-not-a-real-domain.fake-tld." + + t.Run("LookupIP", func(t *testing.T) { + ips, err := r.LookupIP(ctx, "ip", fakeDomain) + if err != nil { + t.Fatal(err) + } + if got, want := sortedIPStrings(ips), []string{"0:200::e00", "1.2.3.4", "1::f", "5.6.7.8"}; !reflect.DeepEqual(got, want) { + t.Errorf("LookupIP wrong.\n got: %q\nwant: %q\n", got, want) + } + }) + + t.Run("LookupSRV", func(t *testing.T) { + _, got, err := r.LookupSRV(ctx, "some-service", "tcp", fakeDomain) + if err != nil { + t.Fatal(err) + } + want := []*SRV{ + { + Target: "foo.bar.", + Port: 80, + Priority: 1, + Weight: 2, + }, + { + Target: "bar.baz.", + Port: 81, + Priority: 2, + Weight: 3, + }, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("wrong result. got:") + for _, r := range got { + t.Logf(" - %+v", r) + } + } + }) +} + +func sortedIPStrings(ips []IP) []string { + ret := make([]string, len(ips)) + for i, ip := range ips { + ret[i] = ip.String() + } + sort.Strings(ret) + return ret +} + +func newResolverDialFunc(h *resolverDialHandler) func(ctx context.Context, network, address string) (Conn, error) { + return func(ctx context.Context, network, address string) (Conn, error) { + a := &resolverFuncConn{ + h: h, + network: network, + address: address, + ttl: 10, // 10 second default if unset + } + if h.StartDial != nil { + if err := h.StartDial(network, address); err != nil { + return nil, err + } + } + return a, nil + } +} + +type resolverDialHandler struct { + // StartDial, if non-nil, is called when Go first calls Resolver.Dial. + // Any error returned aborts the dial and is returned unwrapped. + StartDial func(network, address string) error + + Question func(dnsmessage.Header, dnsmessage.Question) + + // err may be ErrNotExist or ErrRefused; others map to SERVFAIL (RCode2). + // A nil error means success. + HandleA func(w AWriter, name string) error + HandleAAAA func(w AAAAWriter, name string) error + HandleSRV func(w SRVWriter, name string) error +} + +type ResponseWriter struct{ a *resolverFuncConn } + +func (w ResponseWriter) header() dnsmessage.ResourceHeader { + q := w.a.q + return dnsmessage.ResourceHeader{ + Name: q.Name, + Type: q.Type, + Class: q.Class, + TTL: w.a.ttl, + } +} + +// SetTTL sets the TTL for subsequent written resources. +// Once a resource has been written, SetTTL calls are no-ops. +// That is, it can only be called at most once, before anything +// else is written. +func (w ResponseWriter) SetTTL(seconds uint32) { + // ... intention is last one wins and mutates all previously + // written records too, but that's a little annoying. + // But it's also annoying if the requirement is it needs to be set + // last. + // And it's also annoying if it's possible for users to set + // different TTLs per Answer. + if w.a.wrote { + return + } + w.a.ttl = seconds + +} + +type AWriter struct{ ResponseWriter } + +func (w AWriter) AddIP(v4 [4]byte) { + w.a.wrote = true + err := w.a.builder.AResource(w.header(), dnsmessage.AResource{A: v4}) + if err != nil { + panic(err) + } +} + +type AAAAWriter struct{ ResponseWriter } + +func (w AAAAWriter) AddIP(v6 [16]byte) { + w.a.wrote = true + err := w.a.builder.AAAAResource(w.header(), dnsmessage.AAAAResource{AAAA: v6}) + if err != nil { + panic(err) + } +} + +type SRVWriter struct{ ResponseWriter } + +// AddSRV adds a SRV record. The target name must end in a period and +// be 63 bytes or fewer. +func (w SRVWriter) AddSRV(priority, weight, port uint16, target string) error { + targetName, err := dnsmessage.NewName(target) + if err != nil { + return err + } + w.a.wrote = true + err = w.a.builder.SRVResource(w.header(), dnsmessage.SRVResource{ + Priority: priority, + Weight: weight, + Port: port, + Target: targetName, + }) + if err != nil { + panic(err) // internal fault, not user + } + return nil +} + +var ( + ErrNotExist = errors.New("name does not exist") // maps to RCode3, NXDOMAIN + ErrRefused = errors.New("refused") // maps to RCode5, REFUSED +) + +type resolverFuncConn struct { + h *resolverDialHandler + network string + address string + builder *dnsmessage.Builder + q dnsmessage.Question + ttl uint32 + wrote bool + + rbuf bytes.Buffer +} + +func (*resolverFuncConn) Close() error { return nil } +func (*resolverFuncConn) LocalAddr() Addr { return someaddr{} } +func (*resolverFuncConn) RemoteAddr() Addr { return someaddr{} } +func (*resolverFuncConn) SetDeadline(t time.Time) error { return nil } +func (*resolverFuncConn) SetReadDeadline(t time.Time) error { return nil } +func (*resolverFuncConn) SetWriteDeadline(t time.Time) error { return nil } + +func (a *resolverFuncConn) Read(p []byte) (n int, err error) { + return a.rbuf.Read(p) +} + +func (a *resolverFuncConn) Write(packet []byte) (n int, err error) { + if len(packet) < 2 { + return 0, fmt.Errorf("short write of %d bytes; want 2+", len(packet)) + } + reqLen := int(packet[0])<<8 | int(packet[1]) + req := packet[2:] + if len(req) != reqLen { + return 0, fmt.Errorf("packet declared length %d doesn't match body length %d", reqLen, len(req)) + } + + var parser dnsmessage.Parser + h, err := parser.Start(req) + if err != nil { + // TODO: hook + return 0, err + } + q, err := parser.Question() + hadQ := (err == nil) + if err == nil && a.h.Question != nil { + a.h.Question(h, q) + } + if err != nil && err != dnsmessage.ErrSectionDone { + return 0, err + } + + resh := h + resh.Response = true + resh.Authoritative = true + if hadQ { + resh.RCode = dnsmessage.RCodeSuccess + } else { + resh.RCode = dnsmessage.RCodeNotImplemented + } + a.rbuf.Grow(514) + a.rbuf.WriteByte('X') // reserved header for beu16 length + a.rbuf.WriteByte('Y') // reserved header for beu16 length + builder := dnsmessage.NewBuilder(a.rbuf.Bytes(), resh) + a.builder = &builder + if hadQ { + a.q = q + a.builder.StartQuestions() + err := a.builder.Question(q) + if err != nil { + return 0, fmt.Errorf("Question: %w", err) + } + a.builder.StartAnswers() + switch q.Type { + case dnsmessage.TypeA: + if a.h.HandleA != nil { + resh.RCode = mapRCode(a.h.HandleA(AWriter{ResponseWriter{a}}, q.Name.String())) + } + case dnsmessage.TypeAAAA: + if a.h.HandleAAAA != nil { + resh.RCode = mapRCode(a.h.HandleAAAA(AAAAWriter{ResponseWriter{a}}, q.Name.String())) + } + case dnsmessage.TypeSRV: + if a.h.HandleSRV != nil { + resh.RCode = mapRCode(a.h.HandleSRV(SRVWriter{ResponseWriter{a}}, q.Name.String())) + } + } + } + tcpRes, err := builder.Finish() + if err != nil { + return 0, fmt.Errorf("Finish: %w", err) + } + + n = len(tcpRes) - 2 + tcpRes[0] = byte(n >> 8) + tcpRes[1] = byte(n) + a.rbuf.Write(tcpRes[2:]) + + return len(packet), nil +} + +type someaddr struct{} + +func (someaddr) Network() string { return "unused" } +func (someaddr) String() string { return "unused-someaddr" } + +func mapRCode(err error) dnsmessage.RCode { + switch err { + case nil: + return dnsmessage.RCodeSuccess + case ErrNotExist: + return dnsmessage.RCodeNameError + case ErrRefused: + return dnsmessage.RCodeRefused + default: + return dnsmessage.RCodeServerFailure + } +} |