diff options
Diffstat (limited to 'src/net/dnsclient_unix.go')
-rw-r--r-- | src/net/dnsclient_unix.go | 789 |
1 files changed, 789 insertions, 0 deletions
diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go new file mode 100644 index 0000000..d7db0c8 --- /dev/null +++ b/src/net/dnsclient_unix.go @@ -0,0 +1,789 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +// DNS client: see RFC 1035. +// Has to be linked into package net for Dial. + +// TODO(rsc): +// Could potentially handle many outstanding lookups faster. +// Random UDP source port (net.Dial should do that for us). +// Random request IDs. + +package net + +import ( + "context" + "errors" + "io" + "os" + "sync" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +const ( + // to be used as a useTCP parameter to exchange + useTCPOnly = true + useUDPOrTCP = false +) + +var ( + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + + // errServerTemporarilyMisbehaving is like errServerMisbehaving, except + // that when it gets translated to a DNSError, the IsTemporary field + // gets set to true. + errServerTemporarilyMisbehaving = errors.New("server misbehaving") +) + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = uint16(randInt()) + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + + b = make([]byte, 512) // see RFC 1035 + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + // Ignore invalid responses as they may be malicious + // forgery attempts. Instead continue waiting until + // timeout. See golang.org/issue/13281. + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + + b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +// exchange sends a query on the connection and hopes for a response. +func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + var networks []string + if useTCP { + networks = []string{"tcp"} + } else { + networks = []string{"udp", "tcp"} + } + for _, network := range networks { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + c, err := r.dial(ctx, network, server) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + c.SetDeadline(d) + } + var p dnsmessage.Parser + var h dnsmessage.Header + if _, ok := c.(PacketConn); ok { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err) + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { // see RFC 5966 + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +// checkHeader performs basic sanity checks on the header. +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + + // libresolv continues to the next server when it receives + // an invalid referral response. See golang.org/issue/15434. + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + // None of the error codes make sense + // for the query we sent. If we didn't get + // a name error and we didn't get success, + // the server is behaving incorrectly or + // having temporary trouble. + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +// Do a lookup for a single name, which must be rooted +// (otherwise answer will not find the answers). +func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + serverOffset := cfg.serverOffset() + sLen := uint32(len(cfg.servers)) + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < cfg.attempts; i++ { + for j := uint32(0); j < sLen; j++ { + server := cfg.servers[(serverOffset+j)%sLen] + + p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP) + if err != nil { + dnsErr := &DNSError{ + Err: err.Error(), + Name: name, + Server: server, + } + if nerr, ok := err.(Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + // Set IsTemporary for socket-level errors. Note that this flag + // may also be used to indicate a SERVFAIL response. + if _, ok := err.(*OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &DNSError{ + Err: err.Error(), + Name: name, + Server: server, + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + // The name does not exist, so trying + // another server won't help. + + dnsErr.IsNotFound = true + return p, server, dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server, nil + } + lastErr = &DNSError{ + Err: err.Error(), + Name: name, + Server: server, + } + if err == errNoSuchHost { + // The name does not exist, so trying another + // server won't help. + + lastErr.(*DNSError).IsNotFound = true + return p, server, lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +// A resolverConfig represents a DNS stub resolver configuration. +type resolverConfig struct { + initOnce sync.Once // guards init of resolverConfig + + // ch is used as a semaphore that only allows one lookup at a + // time to recheck resolv.conf. + ch chan struct{} // guards lastChecked and modTime + lastChecked time.Time // last time resolv.conf was checked + + mu sync.RWMutex // protects dnsConfig + dnsConfig *dnsConfig // parsed resolv.conf structure used in lookups +} + +var resolvConf resolverConfig + +// init initializes conf and is only called via conf.initOnce. +func (conf *resolverConfig) init() { + // Set dnsConfig and lastChecked so we don't parse + // resolv.conf twice the first time. + conf.dnsConfig = systemConf().resolv + if conf.dnsConfig == nil { + conf.dnsConfig = dnsReadConfig("/etc/resolv.conf") + } + conf.lastChecked = time.Now() + + // Prepare ch so that only one update of resolverConfig may + // run at once. + conf.ch = make(chan struct{}, 1) +} + +// tryUpdate tries to update conf with the named resolv.conf file. +// The name variable only exists for testing. It is otherwise always +// "/etc/resolv.conf". +func (conf *resolverConfig) tryUpdate(name string) { + conf.initOnce.Do(conf.init) + + // Ensure only one update at a time checks resolv.conf. + if !conf.tryAcquireSema() { + return + } + defer conf.releaseSema() + + now := time.Now() + if conf.lastChecked.After(now.Add(-5 * time.Second)) { + return + } + conf.lastChecked = now + + var mtime time.Time + if fi, err := os.Stat(name); err == nil { + mtime = fi.ModTime() + } + if mtime.Equal(conf.dnsConfig.mtime) { + return + } + + dnsConf := dnsReadConfig(name) + conf.mu.Lock() + conf.dnsConfig = dnsConf + conf.mu.Unlock() +} + +func (conf *resolverConfig) tryAcquireSema() bool { + select { + case conf.ch <- struct{}{}: + return true + default: + return false + } +} + +func (conf *resolverConfig) releaseSema() { + <-conf.ch +} + +func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + if !isDomainName(name) { + // We used to use "invalid domain name" as the error, + // but that is a detail of the specific lookup mechanism. + // Other lookups might allow broader name syntax + // (for example Multicast DNS allows UTF-8; see RFC 6762). + // For consistency with libc resolvers, report no such host. + return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + } + resolvConf.tryUpdate("/etc/resolv.conf") + resolvConf.mu.RLock() + conf := resolvConf.dnsConfig + resolvConf.mu.RUnlock() + var ( + p dnsmessage.Parser + server string + err error + ) + for _, fqdn := range conf.nameList(name) { + p, server, err = r.tryOneName(ctx, conf, fqdn, qtype) + if err == nil { + break + } + if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() { + // If we hit a temporary error with StrictErrors enabled, + // stop immediately instead of trying more names. + break + } + } + if err == nil { + return p, server, nil + } + if err, ok := err.(*DNSError); ok { + // Show original name passed to lookup, not suffixed one. + // In general we might have tried many suffixes; showing + // just one is misleading. See also golang.org/issue/6324. + err.Name = name + } + return dnsmessage.Parser{}, "", err +} + +// avoidDNS reports whether this is a hostname for which we should not +// use DNS. Currently this includes only .onion, per RFC 7686. See +// golang.org/issue/13705. Does not cover .local names (RFC 6762), +// see golang.org/issue/16739. +func avoidDNS(name string) bool { + if name == "" { + return true + } + if name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + return stringsHasSuffixFold(name, ".onion") +} + +// nameList returns a list of names for sequential DNS queries. +func (conf *dnsConfig) nameList(name string) []string { + if avoidDNS(name) { + return nil + } + + // Check name length (see isDomainName). + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && rooted { + return nil + } + + // If name is rooted (trailing dot), try only that name. + if rooted { + return []string{name} + } + + hasNdots := count(name, '.') >= conf.ndots + name += "." + l++ + + // Build list of search choices. + names := make([]string, 0, 1+len(conf.search)) + // If name has enough dots, try unsuffixed first. + if hasNdots { + names = append(names, name) + } + // Try suffixes that are not too long (see isDomainName). + for _, suffix := range conf.search { + if l+len(suffix) <= 254 { + names = append(names, name+suffix) + } + } + // Try unsuffixed, if not tried first above. + if !hasNdots { + names = append(names, name) + } + return names +} + +// hostLookupOrder specifies the order of LookupHost lookup strategies. +// It is basically a simplified representation of nsswitch.conf. +// "files" means /etc/hosts. +type hostLookupOrder int + +const ( + // hostLookupCgo means defer to cgo. + hostLookupCgo hostLookupOrder = iota + hostLookupFilesDNS // files first + hostLookupDNSFiles // dns first + hostLookupFiles // only files + hostLookupDNS // only DNS +) + +var lookupOrderName = map[hostLookupOrder]string{ + hostLookupCgo: "cgo", + hostLookupFilesDNS: "files,dns", + hostLookupDNSFiles: "dns,files", + hostLookupFiles: "files", + hostLookupDNS: "dns", +} + +func (o hostLookupOrder) String() string { + if s, ok := lookupOrderName[o]; ok { + return s + } + return "hostLookupOrder=" + itoa(int(o)) + "??" +} + +// goLookupHost is the native Go implementation of LookupHost. +// Used only if cgoLookupHost refuses to handle the request +// (that is, only if cgoLookupHost is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of +// depending on our lookup code, so that Go and C get the same +// answers. +func (r *Resolver) goLookupHost(ctx context.Context, name string) (addrs []string, err error) { + return r.goLookupHostOrder(ctx, name, hostLookupFilesDNS) +} + +func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { + if order == hostLookupFilesDNS || order == hostLookupFiles { + // Use entries from /etc/hosts if they match. + addrs = lookupStaticHost(name) + if len(addrs) > 0 || order == hostLookupFiles { + return + } + } + ips, _, err := r.goLookupIPCNAMEOrder(ctx, name, order) + if err != nil { + return + } + addrs = make([]string, 0, len(ips)) + for _, ip := range ips { + addrs = append(addrs, ip.String()) + } + return +} + +// lookup entries from /etc/hosts +func goLookupIPFiles(name string) (addrs []IPAddr) { + for _, haddr := range lookupStaticHost(name) { + haddr, zone := splitHostZone(haddr) + if ip := ParseIP(haddr); ip != nil { + addr := IPAddr{IP: ip, Zone: zone} + addrs = append(addrs, addr) + } + } + sortByRFC6724(addrs) + return +} + +// goLookupIP is the native Go implementation of LookupIP. +// The libc versions are in cgo_*.go. +func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + order := systemConf().hostLookupOrder(r, host) + addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order) + return +} + +func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) { + if order == hostLookupFilesDNS || order == hostLookupFiles { + addrs = goLookupIPFiles(name) + if len(addrs) > 0 || order == hostLookupFiles { + return addrs, dnsmessage.Name{}, nil + } + } + if !isDomainName(name) { + // See comment in func lookup above about use of errNoSuchHost. + return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + } + resolvConf.tryUpdate("/etc/resolv.conf") + resolvConf.mu.RLock() + conf := resolvConf.dnsConfig + resolvConf.mu.RUnlock() + type result struct { + p dnsmessage.Parser + server string + error + } + lane := make(chan result, 1) + qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA} + var queryFn func(fqdn string, qtype dnsmessage.Type) + var responseFn func(fqdn string, qtype dnsmessage.Type) result + if conf.singleRequest { + queryFn = func(fqdn string, qtype dnsmessage.Type) {} + responseFn = func(fqdn string, qtype dnsmessage.Type) result { + dnsWaitGroup.Add(1) + defer dnsWaitGroup.Done() + p, server, err := r.tryOneName(ctx, conf, fqdn, qtype) + return result{p, server, err} + } + } else { + queryFn = func(fqdn string, qtype dnsmessage.Type) { + dnsWaitGroup.Add(1) + go func(qtype dnsmessage.Type) { + p, server, err := r.tryOneName(ctx, conf, fqdn, qtype) + lane <- result{p, server, err} + dnsWaitGroup.Done() + }(qtype) + } + responseFn = func(fqdn string, qtype dnsmessage.Type) result { + return <-lane + } + } + var lastErr error + for _, fqdn := range conf.nameList(name) { + for _, qtype := range qtypes { + queryFn(fqdn, qtype) + } + hitStrictError := false + for _, qtype := range qtypes { + result := responseFn(fqdn, qtype) + if result.error != nil { + if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() { + // This error will abort the nameList loop. + hitStrictError = true + lastErr = result.error + } else if lastErr == nil || fqdn == name+"." { + // Prefer error for original name. + lastErr = result.error + } + continue + } + + // Presotto says it's okay to assume that servers listed in + // /etc/resolv.conf are recursive resolvers. + // + // We asked for recursion, so it should have included all the + // answers we need in this one packet. + // + // Further, RFC 1035 section 4.3.1 says that "the recursive + // response to a query will be... The answer to the query, + // possibly preface by one or more CNAME RRs that specify + // aliases encountered on the way to an answer." + // + // Therefore, we should be able to assume that we can ignore + // CNAMEs and that the A and AAAA records we requested are + // for the canonical name. + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: result.server, + } + break loop + } + addrs = append(addrs, IPAddr{IP: IP(a.A[:])}) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: result.server, + } + break loop + } + addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])}) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: result.server, + } + break loop + } + continue + } + if cname.Length == 0 && h.Name.Length != 0 { + cname = h.Name + } + } + } + if hitStrictError { + // If either family hit an error with StrictErrors enabled, + // discard all addresses. This ensures that network flakiness + // cannot turn a dualstack hostname IPv4/IPv6-only. + addrs = nil + break + } + if len(addrs) > 0 { + break + } + } + if lastErr, ok := lastErr.(*DNSError); ok { + // Show original name passed to lookup, not suffixed one. + // In general we might have tried many suffixes; showing + // just one is misleading. See also golang.org/issue/6324. + lastErr.Name = name + } + sortByRFC6724(addrs) + if len(addrs) == 0 { + if order == hostLookupDNSFiles { + addrs = goLookupIPFiles(name) + } + if len(addrs) == 0 && lastErr != nil { + return nil, dnsmessage.Name{}, lastErr + } + } + return addrs, cname, nil +} + +// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. +func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) { + order := systemConf().hostLookupOrder(r, host) + _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order) + return cname.String(), err +} + +// goLookupPTR is the native Go implementation of LookupAddr. +// Used only if cgoLookupPTR refuses to handle the request (that is, +// only if cgoLookupPTR is the stub in cgo_stub.go). +// Normally we let cgo use the C library resolver instead of depending +// on our lookup code, so that Go and C get the same answers. +func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, error) { + names := lookupStaticAddr(addr) + if len(names) > 0 { + return names, nil + } + arpa, err := reverseaddr(addr) + if err != nil { + return nil, err + } + p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR) + if err != nil { + return nil, err + } + var ptrs []string + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, &DNSError{ + Err: "cannot marshal DNS message", + Name: addr, + Server: server, + } + } + if h.Type != dnsmessage.TypePTR { + err := p.SkipAnswer() + if err != nil { + return nil, &DNSError{ + Err: "cannot marshal DNS message", + Name: addr, + Server: server, + } + } + continue + } + ptr, err := p.PTRResource() + if err != nil { + return nil, &DNSError{ + Err: "cannot marshal DNS message", + Name: addr, + Server: server, + } + } + ptrs = append(ptrs, ptr.PTR.String()) + + } + return ptrs, nil +} |