diff options
Diffstat (limited to '')
-rw-r--r-- | src/net/textproto/header.go | 56 | ||||
-rw-r--r-- | src/net/textproto/header_test.go | 54 | ||||
-rw-r--r-- | src/net/textproto/pipeline.go | 118 | ||||
-rw-r--r-- | src/net/textproto/reader.go | 822 | ||||
-rw-r--r-- | src/net/textproto/reader_test.go | 525 | ||||
-rw-r--r-- | src/net/textproto/textproto.go | 152 | ||||
-rw-r--r-- | src/net/textproto/writer.go | 119 | ||||
-rw-r--r-- | src/net/textproto/writer_test.go | 61 |
8 files changed, 1907 insertions, 0 deletions
diff --git a/src/net/textproto/header.go b/src/net/textproto/header.go new file mode 100644 index 0000000..a58df7a --- /dev/null +++ b/src/net/textproto/header.go @@ -0,0 +1,56 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +// A MIMEHeader represents a MIME-style header mapping +// keys to sets of values. +type MIMEHeader map[string][]string + +// Add adds the key, value pair to the header. +// It appends to any existing values associated with key. +func (h MIMEHeader) Add(key, value string) { + key = CanonicalMIMEHeaderKey(key) + h[key] = append(h[key], value) +} + +// Set sets the header entries associated with key to +// the single element value. It replaces any existing +// values associated with key. +func (h MIMEHeader) Set(key, value string) { + h[CanonicalMIMEHeaderKey(key)] = []string{value} +} + +// Get gets the first value associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. +// If there are no values associated with the key, Get returns "". +// To use non-canonical keys, access the map directly. +func (h MIMEHeader) Get(key string) string { + if h == nil { + return "" + } + v := h[CanonicalMIMEHeaderKey(key)] + if len(v) == 0 { + return "" + } + return v[0] +} + +// Values returns all values associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is +// used to canonicalize the provided key. To use non-canonical +// keys, access the map directly. +// The returned slice is not a copy. +func (h MIMEHeader) Values(key string) []string { + if h == nil { + return nil + } + return h[CanonicalMIMEHeaderKey(key)] +} + +// Del deletes the values associated with key. +func (h MIMEHeader) Del(key string) { + delete(h, CanonicalMIMEHeaderKey(key)) +} diff --git a/src/net/textproto/header_test.go b/src/net/textproto/header_test.go new file mode 100644 index 0000000..de9405c --- /dev/null +++ b/src/net/textproto/header_test.go @@ -0,0 +1,54 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import "testing" + +type canonicalHeaderKeyTest struct { + in, out string +} + +var canonicalHeaderKeyTests = []canonicalHeaderKeyTest{ + {"a-b-c", "A-B-C"}, + {"a-1-c", "A-1-C"}, + {"User-Agent", "User-Agent"}, + {"uSER-aGENT", "User-Agent"}, + {"user-agent", "User-Agent"}, + {"USER-AGENT", "User-Agent"}, + + // Other valid tchar bytes in tokens: + {"foo-bar_baz", "Foo-Bar_baz"}, + {"foo-bar$baz", "Foo-Bar$baz"}, + {"foo-bar~baz", "Foo-Bar~baz"}, + {"foo-bar*baz", "Foo-Bar*baz"}, + + // Non-ASCII or anything with spaces or non-token chars is unchanged: + {"üser-agenT", "üser-agenT"}, + {"a B", "a B"}, + + // This caused a panic due to mishandling of a space: + {"C Ontent-Transfer-Encoding", "C Ontent-Transfer-Encoding"}, + {"foo bar", "foo bar"}, +} + +func TestCanonicalMIMEHeaderKey(t *testing.T) { + for _, tt := range canonicalHeaderKeyTests { + if s := CanonicalMIMEHeaderKey(tt.in); s != tt.out { + t.Errorf("CanonicalMIMEHeaderKey(%q) = %q, want %q", tt.in, s, tt.out) + } + } +} + +// Issue #34799 add a Header method to get multiple values []string, with canonicalized key +func TestMIMEHeaderMultipleValues(t *testing.T) { + testHeader := MIMEHeader{ + "Set-Cookie": {"cookie 1", "cookie 2"}, + } + values := testHeader.Values("set-cookie") + n := len(values) + if n != 2 { + t.Errorf("count: %d; want 2", n) + } +} diff --git a/src/net/textproto/pipeline.go b/src/net/textproto/pipeline.go new file mode 100644 index 0000000..1928a30 --- /dev/null +++ b/src/net/textproto/pipeline.go @@ -0,0 +1,118 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "sync" +) + +// A Pipeline manages a pipelined in-order request/response sequence. +// +// To use a Pipeline p to manage multiple clients on a connection, +// each client should run: +// +// id := p.Next() // take a number +// +// p.StartRequest(id) // wait for turn to send request +// «send request» +// p.EndRequest(id) // notify Pipeline that request is sent +// +// p.StartResponse(id) // wait for turn to read response +// «read response» +// p.EndResponse(id) // notify Pipeline that response is read +// +// A pipelined server can use the same calls to ensure that +// responses computed in parallel are written in the correct order. +type Pipeline struct { + mu sync.Mutex + id uint + request sequencer + response sequencer +} + +// Next returns the next id for a request/response pair. +func (p *Pipeline) Next() uint { + p.mu.Lock() + id := p.id + p.id++ + p.mu.Unlock() + return id +} + +// StartRequest blocks until it is time to send (or, if this is a server, receive) +// the request with the given id. +func (p *Pipeline) StartRequest(id uint) { + p.request.Start(id) +} + +// EndRequest notifies p that the request with the given id has been sent +// (or, if this is a server, received). +func (p *Pipeline) EndRequest(id uint) { + p.request.End(id) +} + +// StartResponse blocks until it is time to receive (or, if this is a server, send) +// the request with the given id. +func (p *Pipeline) StartResponse(id uint) { + p.response.Start(id) +} + +// EndResponse notifies p that the response with the given id has been received +// (or, if this is a server, sent). +func (p *Pipeline) EndResponse(id uint) { + p.response.End(id) +} + +// A sequencer schedules a sequence of numbered events that must +// happen in order, one after the other. The event numbering must start +// at 0 and increment without skipping. The event number wraps around +// safely as long as there are not 2^32 simultaneous events pending. +type sequencer struct { + mu sync.Mutex + id uint + wait map[uint]chan struct{} +} + +// Start waits until it is time for the event numbered id to begin. +// That is, except for the first event, it waits until End(id-1) has +// been called. +func (s *sequencer) Start(id uint) { + s.mu.Lock() + if s.id == id { + s.mu.Unlock() + return + } + c := make(chan struct{}) + if s.wait == nil { + s.wait = make(map[uint]chan struct{}) + } + s.wait[id] = c + s.mu.Unlock() + <-c +} + +// End notifies the sequencer that the event numbered id has completed, +// allowing it to schedule the event numbered id+1. It is a run-time error +// to call End with an id that is not the number of the active event. +func (s *sequencer) End(id uint) { + s.mu.Lock() + if s.id != id { + s.mu.Unlock() + panic("out of sync") + } + id++ + s.id = id + if s.wait == nil { + s.wait = make(map[uint]chan struct{}) + } + c, ok := s.wait[id] + if ok { + delete(s.wait, id) + } + s.mu.Unlock() + if ok { + close(c) + } +} diff --git a/src/net/textproto/reader.go b/src/net/textproto/reader.go new file mode 100644 index 0000000..fc2590b --- /dev/null +++ b/src/net/textproto/reader.go @@ -0,0 +1,822 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "math" + "strconv" + "strings" + "sync" +) + +// A Reader implements convenience methods for reading requests +// or responses from a text protocol network connection. +type Reader struct { + R *bufio.Reader + dot *dotReader + buf []byte // a re-usable buffer for readContinuedLineSlice +} + +// NewReader returns a new Reader reading from r. +// +// To avoid denial of service attacks, the provided bufio.Reader +// should be reading from an io.LimitReader or similar Reader to bound +// the size of responses. +func NewReader(r *bufio.Reader) *Reader { + return &Reader{R: r} +} + +// ReadLine reads a single line from r, +// eliding the final \n or \r\n from the returned string. +func (r *Reader) ReadLine() (string, error) { + line, err := r.readLineSlice() + return string(line), err +} + +// ReadLineBytes is like ReadLine but returns a []byte instead of a string. +func (r *Reader) ReadLineBytes() ([]byte, error) { + line, err := r.readLineSlice() + if line != nil { + line = bytes.Clone(line) + } + return line, err +} + +func (r *Reader) readLineSlice() ([]byte, error) { + r.closeDot() + var line []byte + for { + l, more, err := r.R.ReadLine() + if err != nil { + return nil, err + } + // Avoid the copy if the first call produced a full line. + if line == nil && !more { + return l, nil + } + line = append(line, l...) + if !more { + break + } + } + return line, nil +} + +// ReadContinuedLine reads a possibly continued line from r, +// eliding the final trailing ASCII white space. +// Lines after the first are considered continuations if they +// begin with a space or tab character. In the returned data, +// continuation lines are separated from the previous line +// only by a single space: the newline and leading white space +// are removed. +// +// For example, consider this input: +// +// Line 1 +// continued... +// Line 2 +// +// The first call to ReadContinuedLine will return "Line 1 continued..." +// and the second will return "Line 2". +// +// Empty lines are never continued. +func (r *Reader) ReadContinuedLine() (string, error) { + line, err := r.readContinuedLineSlice(noValidation) + return string(line), err +} + +// trim returns s with leading and trailing spaces and tabs removed. +// It does not assume Unicode or UTF-8. +func trim(s []byte) []byte { + i := 0 + for i < len(s) && (s[i] == ' ' || s[i] == '\t') { + i++ + } + n := len(s) + for n > i && (s[n-1] == ' ' || s[n-1] == '\t') { + n-- + } + return s[i:n] +} + +// ReadContinuedLineBytes is like ReadContinuedLine but +// returns a []byte instead of a string. +func (r *Reader) ReadContinuedLineBytes() ([]byte, error) { + line, err := r.readContinuedLineSlice(noValidation) + if line != nil { + line = bytes.Clone(line) + } + return line, err +} + +// readContinuedLineSlice reads continued lines from the reader buffer, +// returning a byte slice with all lines. The validateFirstLine function +// is run on the first read line, and if it returns an error then this +// error is returned from readContinuedLineSlice. +func (r *Reader) readContinuedLineSlice(validateFirstLine func([]byte) error) ([]byte, error) { + if validateFirstLine == nil { + return nil, fmt.Errorf("missing validateFirstLine func") + } + + // Read the first line. + line, err := r.readLineSlice() + if err != nil { + return nil, err + } + if len(line) == 0 { // blank line - no continuation + return line, nil + } + + if err := validateFirstLine(line); err != nil { + return nil, err + } + + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), or a blank + // line, so we can avoid copying that buffered data around in memory + // and skipping over non-existent whitespace. + if r.R.Buffered() > 1 { + peek, _ := r.R.Peek(2) + if len(peek) > 0 && (isASCIILetter(peek[0]) || peek[0] == '\n') || + len(peek) == 2 && peek[0] == '\r' && peek[1] == '\n' { + return trim(line), nil + } + } + + // ReadByte or the next readLineSlice will flush the read buffer; + // copy the slice into buf. + r.buf = append(r.buf[:0], trim(line)...) + + // Read continuation lines. + for r.skipSpace() > 0 { + line, err := r.readLineSlice() + if err != nil { + break + } + r.buf = append(r.buf, ' ') + r.buf = append(r.buf, trim(line)...) + } + return r.buf, nil +} + +// skipSpace skips R over all spaces and returns the number of bytes skipped. +func (r *Reader) skipSpace() int { + n := 0 + for { + c, err := r.R.ReadByte() + if err != nil { + // Bufio will keep err until next read. + break + } + if c != ' ' && c != '\t' { + r.R.UnreadByte() + break + } + n++ + } + return n +} + +func (r *Reader) readCodeLine(expectCode int) (code int, continued bool, message string, err error) { + line, err := r.ReadLine() + if err != nil { + return + } + return parseCodeLine(line, expectCode) +} + +func parseCodeLine(line string, expectCode int) (code int, continued bool, message string, err error) { + if len(line) < 4 || line[3] != ' ' && line[3] != '-' { + err = ProtocolError("short response: " + line) + return + } + continued = line[3] == '-' + code, err = strconv.Atoi(line[0:3]) + if err != nil || code < 100 { + err = ProtocolError("invalid response code: " + line) + return + } + message = line[4:] + if 1 <= expectCode && expectCode < 10 && code/100 != expectCode || + 10 <= expectCode && expectCode < 100 && code/10 != expectCode || + 100 <= expectCode && expectCode < 1000 && code != expectCode { + err = &Error{code, message} + } + return +} + +// ReadCodeLine reads a response code line of the form +// +// code message +// +// where code is a three-digit status code and the message +// extends to the rest of the line. An example of such a line is: +// +// 220 plan9.bell-labs.com ESMTP +// +// If the prefix of the status does not match the digits in expectCode, +// ReadCodeLine returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// If the response is multi-line, ReadCodeLine returns an error. +// +// An expectCode <= 0 disables the check of the status code. +func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + if err == nil && continued { + err = ProtocolError("unexpected multi-line response: " + message) + } + return +} + +// ReadResponse reads a multi-line response of the form: +// +// code-message line 1 +// code-message line 2 +// ... +// code message line n +// +// where code is a three-digit status code. The first line starts with the +// code and a hyphen. The response is terminated by a line that starts +// with the same code followed by a space. Each line in message is +// separated by a newline (\n). +// +// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for +// details of another form of response accepted: +// +// code-message line 1 +// message line 2 +// ... +// code message line n +// +// If the prefix of the status does not match the digits in expectCode, +// ReadResponse returns with err set to &Error{code, message}. +// For example, if expectCode is 31, an error will be returned if +// the status is not in the range [310,319]. +// +// An expectCode <= 0 disables the check of the status code. +func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) { + code, continued, message, err := r.readCodeLine(expectCode) + multi := continued + for continued { + line, err := r.ReadLine() + if err != nil { + return 0, "", err + } + + var code2 int + var moreMessage string + code2, continued, moreMessage, err = parseCodeLine(line, 0) + if err != nil || code2 != code { + message += "\n" + strings.TrimRight(line, "\r\n") + continued = true + continue + } + message += "\n" + moreMessage + } + if err != nil && multi && message != "" { + // replace one line error message with all lines (full message) + err = &Error{code, message} + } + return +} + +// DotReader returns a new Reader that satisfies Reads using the +// decoded text of a dot-encoded block read from r. +// The returned Reader is only valid until the next call +// to a method on r. +// +// Dot encoding is a common framing used for data blocks +// in text protocols such as SMTP. The data consists of a sequence +// of lines, each of which ends in "\r\n". The sequence itself +// ends at a line containing just a dot: ".\r\n". Lines beginning +// with a dot are escaped with an additional dot to avoid +// looking like the end of the sequence. +// +// The decoded form returned by the Reader's Read method +// rewrites the "\r\n" line endings into the simpler "\n", +// removes leading dot escapes if present, and stops with error io.EOF +// after consuming (and discarding) the end-of-sequence line. +func (r *Reader) DotReader() io.Reader { + r.closeDot() + r.dot = &dotReader{r: r} + return r.dot +} + +type dotReader struct { + r *Reader + state int +} + +// Read satisfies reads by decoding dot-encoded data read from d.r. +func (d *dotReader) Read(b []byte) (n int, err error) { + // Run data through a simple state machine to + // elide leading dots, rewrite trailing \r\n into \n, + // and detect ending .\r\n line. + const ( + stateBeginLine = iota // beginning of line; initial state; must be zero + stateDot // read . at beginning of line + stateDotCR // read .\r at beginning of line + stateCR // read \r (possibly at end of line) + stateData // reading data in middle of line + stateEOF // reached .\r\n end marker line + ) + br := d.r.R + for n < len(b) && d.state != stateEOF { + var c byte + c, err = br.ReadByte() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + switch d.state { + case stateBeginLine: + if c == '.' { + d.state = stateDot + continue + } + if c == '\r' { + d.state = stateCR + continue + } + d.state = stateData + + case stateDot: + if c == '\r' { + d.state = stateDotCR + continue + } + if c == '\n' { + d.state = stateEOF + continue + } + d.state = stateData + + case stateDotCR: + if c == '\n' { + d.state = stateEOF + continue + } + // Not part of .\r\n. + // Consume leading dot and emit saved \r. + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateCR: + if c == '\n' { + d.state = stateBeginLine + break + } + // Not part of \r\n. Emit saved \r + br.UnreadByte() + c = '\r' + d.state = stateData + + case stateData: + if c == '\r' { + d.state = stateCR + continue + } + if c == '\n' { + d.state = stateBeginLine + } + } + b[n] = c + n++ + } + if err == nil && d.state == stateEOF { + err = io.EOF + } + if err != nil && d.r.dot == d { + d.r.dot = nil + } + return +} + +// closeDot drains the current DotReader if any, +// making sure that it reads until the ending dot line. +func (r *Reader) closeDot() { + if r.dot == nil { + return + } + buf := make([]byte, 128) + for r.dot != nil { + // When Read reaches EOF or an error, + // it will set r.dot == nil. + r.dot.Read(buf) + } +} + +// ReadDotBytes reads a dot-encoding and returns the decoded data. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotBytes() ([]byte, error) { + return io.ReadAll(r.DotReader()) +} + +// ReadDotLines reads a dot-encoding and returns a slice +// containing the decoded lines, with the final \r\n or \n elided from each. +// +// See the documentation for the DotReader method for details about dot-encoding. +func (r *Reader) ReadDotLines() ([]string, error) { + // We could use ReadDotBytes and then Split it, + // but reading a line at a time avoids needing a + // large contiguous block of memory and is simpler. + var v []string + var err error + for { + var line string + line, err = r.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + break + } + + // Dot by itself marks end; otherwise cut one dot. + if len(line) > 0 && line[0] == '.' { + if len(line) == 1 { + break + } + line = line[1:] + } + v = append(v, line) + } + return v, err +} + +var colon = []byte(":") + +// ReadMIMEHeader reads a MIME-style header from r. +// The header is a sequence of possibly continued Key: Value lines +// ending in a blank line. +// The returned map m maps CanonicalMIMEHeaderKey(key) to a +// sequence of values in the same order encountered in the input. +// +// For example, consider this input: +// +// My-Key: Value 1 +// Long-Key: Even +// Longer Value +// My-Key: Value 2 +// +// Given that input, ReadMIMEHeader returns the map: +// +// map[string][]string{ +// "My-Key": {"Value 1", "Value 2"}, +// "Long-Key": {"Even Longer Value"}, +// } +func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { + return readMIMEHeader(r, math.MaxInt64, math.MaxInt64) +} + +// readMIMEHeader is a version of ReadMIMEHeader which takes a limit on the header size. +// It is called by the mime/multipart package. +func readMIMEHeader(r *Reader, maxMemory, maxHeaders int64) (MIMEHeader, error) { + // Avoid lots of small slice allocations later by allocating one + // large one ahead of time which we'll cut up into smaller + // slices. If this isn't big enough later, we allocate small ones. + var strs []string + hint := r.upcomingHeaderKeys() + if hint > 0 { + if hint > 1000 { + hint = 1000 // set a cap to avoid overallocation + } + strs = make([]string, hint) + } + + m := make(MIMEHeader, hint) + + // Account for 400 bytes of overhead for the MIMEHeader, plus 200 bytes per entry. + // Benchmarking map creation as of go1.20, a one-entry MIMEHeader is 416 bytes and large + // MIMEHeaders average about 200 bytes per entry. + maxMemory -= 400 + const mapEntryOverhead = 200 + + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + line, err := r.readLineSlice() + if err != nil { + return m, err + } + return m, ProtocolError("malformed MIME header initial line: " + string(line)) + } + + for { + kv, err := r.readContinuedLineSlice(mustHaveFieldNameColon) + if len(kv) == 0 { + return m, err + } + + // Key ends at first colon. + k, v, ok := bytes.Cut(kv, colon) + if !ok { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + key, ok := canonicalMIMEHeaderKey(k) + if !ok { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + for _, c := range v { + if !validHeaderValueByte(c) { + return m, ProtocolError("malformed MIME header line: " + string(kv)) + } + } + + // As per RFC 7230 field-name is a token, tokens consist of one or more chars. + // We could return a ProtocolError here, but better to be liberal in what we + // accept, so if we get an empty key, skip it. + if key == "" { + continue + } + + maxHeaders-- + if maxHeaders < 0 { + return nil, errors.New("message too large") + } + + // Skip initial spaces in value. + value := string(bytes.TrimLeft(v, " \t")) + + vv := m[key] + if vv == nil { + maxMemory -= int64(len(key)) + maxMemory -= mapEntryOverhead + } + maxMemory -= int64(len(value)) + if maxMemory < 0 { + // TODO: This should be a distinguishable error (ErrMessageTooLarge) + // to allow mime/multipart to detect it. + return m, errors.New("message too large") + } + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = value + m[key] = vv + } else { + m[key] = append(vv, value) + } + + if err != nil { + return m, err + } + } +} + +// noValidation is a no-op validation func for readContinuedLineSlice +// that permits any lines. +func noValidation(_ []byte) error { return nil } + +// mustHaveFieldNameColon ensures that, per RFC 7230, the +// field-name is on a single line, so the first line must +// contain a colon. +func mustHaveFieldNameColon(line []byte) error { + if bytes.IndexByte(line, ':') < 0 { + return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line)) + } + return nil +} + +var nl = []byte("\n") + +// upcomingHeaderKeys returns an approximation of the number of keys +// that will be in this header. If it gets confused, it returns 0. +func (r *Reader) upcomingHeaderKeys() (n int) { + // Try to determine the 'hint' size. + r.R.Peek(1) // force a buffer load if empty + s := r.R.Buffered() + if s == 0 { + return + } + peek, _ := r.R.Peek(s) + for len(peek) > 0 && n < 1000 { + var line []byte + line, peek, _ = bytes.Cut(peek, nl) + if len(line) == 0 || (len(line) == 1 && line[0] == '\r') { + // Blank line separating headers from the body. + break + } + if line[0] == ' ' || line[0] == '\t' { + // Folded continuation of the previous line. + continue + } + n++ + } + return n +} + +// CanonicalMIMEHeaderKey returns the canonical format of the +// MIME header key s. The canonicalization converts the first +// letter and any letter following a hyphen to upper case; +// the rest are converted to lowercase. For example, the +// canonical key for "accept-encoding" is "Accept-Encoding". +// MIME header keys are assumed to be ASCII only. +// If s contains a space or invalid header field bytes, it is +// returned without modifications. +func CanonicalMIMEHeaderKey(s string) string { + // Quick check for canonical encoding. + upper := true + for i := 0; i < len(s); i++ { + c := s[i] + if !validHeaderFieldByte(c) { + return s + } + if upper && 'a' <= c && c <= 'z' { + s, _ = canonicalMIMEHeaderKey([]byte(s)) + return s + } + if !upper && 'A' <= c && c <= 'Z' { + s, _ = canonicalMIMEHeaderKey([]byte(s)) + return s + } + upper = c == '-' + } + return s +} + +const toLower = 'a' - 'A' + +// validHeaderFieldByte reports whether c is a valid byte in a header +// field name. RFC 7230 says: +// +// header-field = field-name ":" OWS field-value OWS +// field-name = token +// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +// token = 1*tchar +func validHeaderFieldByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<<c and 1<<(c-64) will both be zero, + // and this function will return false. + const mask = 0 | + (1<<(10)-1)<<'0' | + (1<<(26)-1)<<'a' | + (1<<(26)-1)<<'A' | + 1<<'!' | + 1<<'#' | + 1<<'$' | + 1<<'%' | + 1<<'&' | + 1<<'\'' | + 1<<'*' | + 1<<'+' | + 1<<'-' | + 1<<'.' | + 1<<'^' | + 1<<'_' | + 1<<'`' | + 1<<'|' | + 1<<'~' + return ((uint64(1)<<c)&(mask&(1<<64-1)) | + (uint64(1)<<(c-64))&(mask>>64)) != 0 +} + +// validHeaderValueByte reports whether c is a valid byte in a header +// field value. RFC 7230 says: +// +// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] +// field-vchar = VCHAR / obs-text +// obs-text = %x80-FF +// +// RFC 5234 says: +// +// HTAB = %x09 +// SP = %x20 +// VCHAR = %x21-7E +func validHeaderValueByte(c byte) bool { + // mask is a 128-bit bitmap with 1s for allowed bytes, + // so that the byte c can be tested with a shift and an and. + // If c >= 128, then 1<<c and 1<<(c-64) will both be zero. + // Since this is the obs-text range, we invert the mask to + // create a bitmap with 1s for disallowed bytes. + const mask = 0 | + (1<<(0x7f-0x21)-1)<<0x21 | // VCHAR: %x21-7E + 1<<0x20 | // SP: %x20 + 1<<0x09 // HTAB: %x09 + return ((uint64(1)<<c)&^(mask&(1<<64-1)) | + (uint64(1)<<(c-64))&^(mask>>64)) == 0 +} + +// canonicalMIMEHeaderKey is like CanonicalMIMEHeaderKey but is +// allowed to mutate the provided byte slice before returning the +// string. +// +// For invalid inputs (if a contains spaces or non-token bytes), a +// is unchanged and a string copy is returned. +// +// ok is true if the header key contains only valid characters and spaces. +// ReadMIMEHeader accepts header keys containing spaces, but does not +// canonicalize them. +func canonicalMIMEHeaderKey(a []byte) (_ string, ok bool) { + // See if a looks like a header key. If not, return it unchanged. + noCanon := false + for _, c := range a { + if validHeaderFieldByte(c) { + continue + } + // Don't canonicalize. + if c == ' ' { + // We accept invalid headers with a space before the + // colon, but must not canonicalize them. + // See https://go.dev/issue/34540. + noCanon = true + continue + } + return string(a), false + } + if noCanon { + return string(a), true + } + + upper := true + for i, c := range a { + // Canonicalize: first letter upper case + // and upper case after each dash. + // (Host, User-Agent, If-Modified-Since). + // MIME headers are ASCII only, so no Unicode issues. + if upper && 'a' <= c && c <= 'z' { + c -= toLower + } else if !upper && 'A' <= c && c <= 'Z' { + c += toLower + } + a[i] = c + upper = c == '-' // for next time + } + commonHeaderOnce.Do(initCommonHeader) + // The compiler recognizes m[string(byteSlice)] as a special + // case, so a copy of a's bytes into a new string does not + // happen in this map lookup: + if v := commonHeader[string(a)]; v != "" { + return v, true + } + return string(a), true +} + +// commonHeader interns common header strings. +var commonHeader map[string]string + +var commonHeaderOnce sync.Once + +func initCommonHeader() { + commonHeader = make(map[string]string) + for _, v := range []string{ + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Accept-Ranges", + "Cache-Control", + "Cc", + "Connection", + "Content-Id", + "Content-Language", + "Content-Length", + "Content-Transfer-Encoding", + "Content-Type", + "Cookie", + "Date", + "Dkim-Signature", + "Etag", + "Expires", + "From", + "Host", + "If-Modified-Since", + "If-None-Match", + "In-Reply-To", + "Last-Modified", + "Location", + "Message-Id", + "Mime-Version", + "Pragma", + "Received", + "Return-Path", + "Server", + "Set-Cookie", + "Subject", + "To", + "User-Agent", + "Via", + "X-Forwarded-For", + "X-Imforwards", + "X-Powered-By", + } { + commonHeader[v] = v + } +} diff --git a/src/net/textproto/reader_test.go b/src/net/textproto/reader_test.go new file mode 100644 index 0000000..696ae40 --- /dev/null +++ b/src/net/textproto/reader_test.go @@ -0,0 +1,525 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "bytes" + "io" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +func reader(s string) *Reader { + return NewReader(bufio.NewReader(strings.NewReader(s))) +} + +func TestReadLine(t *testing.T) { + r := reader("line1\nline2\n") + s, err := r.ReadLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "line2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadLine() + if s != "" || err != io.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadContinuedLine(t *testing.T) { + r := reader("line1\nline\n 2\nline3\n") + s, err := r.ReadContinuedLine() + if s != "line1" || err != nil { + t.Fatalf("Line 1: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line 2" || err != nil { + t.Fatalf("Line 2: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "line3" || err != nil { + t.Fatalf("Line 3: %s, %v", s, err) + } + s, err = r.ReadContinuedLine() + if s != "" || err != io.EOF { + t.Fatalf("EOF: %s, %v", s, err) + } +} + +func TestReadCodeLine(t *testing.T) { + r := reader("123 hi\n234 bye\n345 no way\n") + code, msg, err := r.ReadCodeLine(0) + if code != 123 || msg != "hi" || err != nil { + t.Fatalf("Line 1: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(23) + if code != 234 || msg != "bye" || err != nil { + t.Fatalf("Line 2: %d, %s, %v", code, msg, err) + } + code, msg, err = r.ReadCodeLine(346) + if code != 345 || msg != "no way" || err == nil { + t.Fatalf("Line 3: %d, %s, %v", code, msg, err) + } + if e, ok := err.(*Error); !ok || e.Code != code || e.Msg != msg { + t.Fatalf("Line 3: wrong error %v\n", err) + } + code, msg, err = r.ReadCodeLine(1) + if code != 0 || msg != "" || err != io.EOF { + t.Fatalf("EOF: %d, %s, %v", code, msg, err) + } +} + +func TestReadDotLines(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanother\n") + s, err := r.ReadDotLines() + want := []string{"dotlines", "foo", ".bar", "..baz", "quux", ""} + if !reflect.DeepEqual(s, want) || err != nil { + t.Fatalf("ReadDotLines: %v, %v", s, err) + } + + s, err = r.ReadDotLines() + want = []string{"another"} + if !reflect.DeepEqual(s, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotLines2: %v, %v", s, err) + } +} + +func TestReadDotBytes(t *testing.T) { + r := reader("dotlines\r\n.foo\r\n..bar\n...baz\nquux\r\n\r\n.\r\nanot.her\r\n") + b, err := r.ReadDotBytes() + want := []byte("dotlines\nfoo\n.bar\n..baz\nquux\n\n") + if !reflect.DeepEqual(b, want) || err != nil { + t.Fatalf("ReadDotBytes: %q, %v", b, err) + } + + b, err = r.ReadDotBytes() + want = []byte("anot.her\n") + if !reflect.DeepEqual(b, want) || err != io.ErrUnexpectedEOF { + t.Fatalf("ReadDotBytes2: %q, %v", b, err) + } +} + +func TestReadMIMEHeader(t *testing.T) { + r := reader("my-key: Value 1 \r\nLong-key: Even \n Longer Value\r\nmy-Key: Value 2\r\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{ + "My-Key": {"Value 1", "Value 2"}, + "Long-Key": {"Even Longer Value"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +func TestReadMIMEHeaderSingle(t *testing.T) { + r := reader("Foo: bar\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{"Foo": {"bar"}} + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +// TestReaderUpcomingHeaderKeys is testing an internal function, but it's very +// difficult to test well via the external API. +func TestReaderUpcomingHeaderKeys(t *testing.T) { + for _, test := range []struct { + input string + want int + }{{ + input: "", + want: 0, + }, { + input: "A: v", + want: 1, + }, { + input: "A: v\r\nB: v\r\n", + want: 2, + }, { + input: "A: v\nB: v\n", + want: 2, + }, { + input: "A: v\r\n continued\r\n still continued\r\nB: v\r\n\r\n", + want: 2, + }, { + input: "A: v\r\n\r\nB: v\r\nC: v\r\n", + want: 1, + }, { + input: "A: v" + strings.Repeat("\n", 1000), + want: 1, + }} { + r := reader(test.input) + got := r.upcomingHeaderKeys() + if test.want != got { + t.Fatalf("upcomingHeaderKeys(%q): %v; want %v", test.input, got, test.want) + } + } +} + +func TestReadMIMEHeaderNoKey(t *testing.T) { + r := reader(": bar\ntest-1: 1\n\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{"Test-1": {"1"}} + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader: %v, %v; want %v", m, err, want) + } +} + +func TestLargeReadMIMEHeader(t *testing.T) { + data := make([]byte, 16*1024) + for i := 0; i < len(data); i++ { + data[i] = 'x' + } + sdata := string(data) + r := reader("Cookie: " + sdata + "\r\n\n") + m, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + cookie := m.Get("Cookie") + if cookie != sdata { + t.Fatalf("ReadMIMEHeader: %v bytes, want %v bytes", len(cookie), len(sdata)) + } +} + +// TestReadMIMEHeaderNonCompliant checks that we don't normalize headers +// with spaces before colons, and accept spaces in keys. +func TestReadMIMEHeaderNonCompliant(t *testing.T) { + // These invalid headers will be rejected by net/http according to RFC 7230. + r := reader("Foo: bar\r\n" + + "Content-Language: en\r\n" + + "SID : 0\r\n" + + "Audio Mode : None\r\n" + + "Privilege : 127\r\n\r\n") + m, err := r.ReadMIMEHeader() + want := MIMEHeader{ + "Foo": {"bar"}, + "Content-Language": {"en"}, + "SID ": {"0"}, + "Audio Mode ": {"None"}, + "Privilege ": {"127"}, + } + if !reflect.DeepEqual(m, want) || err != nil { + t.Fatalf("ReadMIMEHeader =\n%v, %v; want:\n%v", m, err, want) + } +} + +func TestReadMIMEHeaderMalformed(t *testing.T) { + inputs := []string{ + "No colon first line\r\nFoo: foo\r\n\r\n", + " No colon first line with leading space\r\nFoo: foo\r\n\r\n", + "\tNo colon first line with leading tab\r\nFoo: foo\r\n\r\n", + " First: line with leading space\r\nFoo: foo\r\n\r\n", + "\tFirst: line with leading tab\r\nFoo: foo\r\n\r\n", + "Foo: foo\r\nNo colon second line\r\n\r\n", + "Foo-\n\tBar: foo\r\n\r\n", + "Foo-\r\n\tBar: foo\r\n\r\n", + "Foo\r\n\t: foo\r\n\r\n", + "Foo-\n\tBar", + "Foo \tBar: foo\r\n\r\n", + } + for _, input := range inputs { + r := reader(input) + if m, err := r.ReadMIMEHeader(); err == nil || err == io.EOF { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want nil, err", input, m, err) + } + } +} + +func TestReadMIMEHeaderBytes(t *testing.T) { + for i := 0; i <= 0xff; i++ { + s := "Foo" + string(rune(i)) + "Bar: foo\r\n\r\n" + r := reader(s) + wantErr := true + switch { + case i >= '0' && i <= '9': + wantErr = false + case i >= 'a' && i <= 'z': + wantErr = false + case i >= 'A' && i <= 'Z': + wantErr = false + case i == '!' || i == '#' || i == '$' || i == '%' || i == '&' || i == '\'' || i == '*' || i == '+' || i == '-' || i == '.' || i == '^' || i == '_' || i == '`' || i == '|' || i == '~': + wantErr = false + case i == ':': + // Special case: "Foo:Bar: foo" is the header "Foo". + wantErr = false + case i == ' ': + wantErr = false + } + m, err := r.ReadMIMEHeader() + if err != nil != wantErr { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr) + } + } + for i := 0; i <= 0xff; i++ { + s := "Foo: foo" + string(rune(i)) + "bar\r\n\r\n" + r := reader(s) + wantErr := true + switch { + case i >= 0x21 && i <= 0x7e: + wantErr = false + case i == ' ': + wantErr = false + case i == '\t': + wantErr = false + case i >= 0x80 && i <= 0xff: + wantErr = false + } + m, err := r.ReadMIMEHeader() + if (err != nil) != wantErr { + t.Errorf("ReadMIMEHeader(%q) = %v, %v; want error=%v", s, m, err, wantErr) + } + } +} + +// Test that continued lines are properly trimmed. Issue 11204. +func TestReadMIMEHeaderTrimContinued(t *testing.T) { + // In this header, \n and \r\n terminated lines are mixed on purpose. + // We expect each line to be trimmed (prefix and suffix) before being concatenated. + // Keep the spaces as they are. + r := reader("" + // for code formatting purpose. + "a:\n" + + " 0 \r\n" + + "b:1 \t\r\n" + + "c: 2\r\n" + + " 3\t\n" + + " \t 4 \r\n\n") + m, err := r.ReadMIMEHeader() + if err != nil { + t.Fatal(err) + } + want := MIMEHeader{ + "A": {"0"}, + "B": {"1"}, + "C": {"2 3 4"}, + } + if !reflect.DeepEqual(m, want) { + t.Fatalf("ReadMIMEHeader mismatch.\n got: %q\nwant: %q", m, want) + } +} + +// Test that reading a header doesn't overallocate. Issue 58975. +func TestReadMIMEHeaderAllocations(t *testing.T) { + var totalAlloc uint64 + const count = 200 + for i := 0; i < count; i++ { + r := reader("A: b\r\n\r\n" + strings.Repeat("\n", 4096)) + var m1, m2 runtime.MemStats + runtime.ReadMemStats(&m1) + _, err := r.ReadMIMEHeader() + if err != nil { + t.Fatalf("ReadMIMEHeader: %v", err) + } + runtime.ReadMemStats(&m2) + totalAlloc += m2.TotalAlloc - m1.TotalAlloc + } + // 32k is large and we actually allocate substantially less, + // but prior to the fix for #58975 we allocated ~400k in this case. + if got, want := totalAlloc/count, uint64(32768); got > want { + t.Fatalf("ReadMIMEHeader allocated %v bytes, want < %v", got, want) + } +} + +type readResponseTest struct { + in string + inCode int + wantCode int + wantMsg string +} + +var readResponseTests = []readResponseTest{ + {"230-Anonymous access granted, restrictions apply\n" + + "Read the file README.txt,\n" + + "230 please", + 23, + 230, + "Anonymous access granted, restrictions apply\nRead the file README.txt,\n please", + }, + + {"230 Anonymous access granted, restrictions apply\n", + 23, + 230, + "Anonymous access granted, restrictions apply", + }, + + {"400-A\n400-B\n400 C", + 4, + 400, + "A\nB\nC", + }, + + {"400-A\r\n400-B\r\n400 C\r\n", + 4, + 400, + "A\nB\nC", + }, +} + +// See https://www.ietf.org/rfc/rfc959.txt page 36. +func TestRFC959Lines(t *testing.T) { + for i, tt := range readResponseTests { + r := reader(tt.in + "\nFOLLOWING DATA") + code, msg, err := r.ReadResponse(tt.inCode) + if err != nil { + t.Errorf("#%d: ReadResponse: %v", i, err) + continue + } + if code != tt.wantCode { + t.Errorf("#%d: code=%d, want %d", i, code, tt.wantCode) + } + if msg != tt.wantMsg { + t.Errorf("#%d: msg=%q, want %q", i, msg, tt.wantMsg) + } + } +} + +// Test that multi-line errors are appropriately and fully read. Issue 10230. +func TestReadMultiLineError(t *testing.T) { + r := reader("550-5.1.1 The email account that you tried to reach does not exist. Please try\n" + + "550-5.1.1 double-checking the recipient's email address for typos or\n" + + "550-5.1.1 unnecessary spaces. Learn more at\n" + + "Unexpected but legal text!\n" + + "550 5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp\n") + + wantMsg := "5.1.1 The email account that you tried to reach does not exist. Please try\n" + + "5.1.1 double-checking the recipient's email address for typos or\n" + + "5.1.1 unnecessary spaces. Learn more at\n" + + "Unexpected but legal text!\n" + + "5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp" + + code, msg, err := r.ReadResponse(250) + if err == nil { + t.Errorf("ReadResponse: no error, want error") + } + if code != 550 { + t.Errorf("ReadResponse: code=%d, want %d", code, 550) + } + if msg != wantMsg { + t.Errorf("ReadResponse: msg=%q, want %q", msg, wantMsg) + } + if err != nil && err.Error() != "550 "+wantMsg { + t.Errorf("ReadResponse: error=%q, want %q", err.Error(), "550 "+wantMsg) + } +} + +func TestCommonHeaders(t *testing.T) { + commonHeaderOnce.Do(initCommonHeader) + for h := range commonHeader { + if h != CanonicalMIMEHeaderKey(h) { + t.Errorf("Non-canonical header %q in commonHeader", h) + } + } + b := []byte("content-Length") + want := "Content-Length" + n := testing.AllocsPerRun(200, func() { + if x, _ := canonicalMIMEHeaderKey(b); x != want { + t.Fatalf("canonicalMIMEHeaderKey(%q) = %q; want %q", b, x, want) + } + }) + if n > 0 { + t.Errorf("canonicalMIMEHeaderKey allocs = %v; want 0", n) + } +} + +func TestIssue46363(t *testing.T) { + // Regression test for data race reported in issue 46363: + // ReadMIMEHeader reads commonHeader before commonHeader has been initialized. + // Run this test with the race detector enabled to catch the reported data race. + + // Reset commonHeaderOnce, so that commonHeader will have to be initialized + commonHeaderOnce = sync.Once{} + commonHeader = nil + + // Test for data race by calling ReadMIMEHeader and CanonicalMIMEHeaderKey concurrently + + // Send MIME header over net.Conn + r, w := net.Pipe() + go func() { + // ReadMIMEHeader calls canonicalMIMEHeaderKey, which reads from commonHeader + NewConn(r).ReadMIMEHeader() + }() + w.Write([]byte("A: 1\r\nB: 2\r\nC: 3\r\n\r\n")) + + // CanonicalMIMEHeaderKey calls commonHeaderOnce.Do(initCommonHeader) which initializes commonHeader + CanonicalMIMEHeaderKey("a") + + if commonHeader == nil { + t.Fatal("CanonicalMIMEHeaderKey should initialize commonHeader") + } +} + +var clientHeaders = strings.Replace(`Host: golang.org +Connection: keep-alive +Cache-Control: max-age=0 +Accept: application/xml,application/xhtml+xml,text/html;q=0.9,text/plain;q=0.8,image/png,*/*;q=0.5 +User-Agent: Mozilla/5.0 (X11; U; Linux x86_64; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.63 Safari/534.3 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8,fr-CH;q=0.6 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +COOKIE: __utma=000000000.0000000000.0000000000.0000000000.0000000000.00; __utmb=000000000.0.00.0000000000; __utmc=000000000; __utmz=000000000.0000000000.00.0.utmcsr=code.google.com|utmccn=(referral)|utmcmd=referral|utmcct=/p/go/issues/detail +Non-Interned: test + +`, "\n", "\r\n", -1) + +var serverHeaders = strings.Replace(`Content-Type: text/html; charset=utf-8 +Content-Encoding: gzip +Date: Thu, 27 Sep 2012 09:03:33 GMT +Server: Google Frontend +Cache-Control: private +Content-Length: 2298 +VIA: 1.1 proxy.example.com:80 (XXX/n.n.n-nnn) +Connection: Close +Non-Interned: test + +`, "\n", "\r\n", -1) + +func BenchmarkReadMIMEHeader(b *testing.B) { + b.ReportAllocs() + for _, set := range []struct { + name string + headers string + }{ + {"client_headers", clientHeaders}, + {"server_headers", serverHeaders}, + } { + b.Run(set.name, func(b *testing.B) { + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + + for i := 0; i < b.N; i++ { + buf.WriteString(set.headers) + if _, err := r.ReadMIMEHeader(); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkUncommon(b *testing.B) { + b.ReportAllocs() + var buf bytes.Buffer + br := bufio.NewReader(&buf) + r := NewReader(br) + for i := 0; i < b.N; i++ { + buf.WriteString("uncommon-header-for-benchmark: foo\r\n\r\n") + h, err := r.ReadMIMEHeader() + if err != nil { + b.Fatal(err) + } + if _, ok := h["Uncommon-Header-For-Benchmark"]; !ok { + b.Fatal("Missing result header.") + } + } +} diff --git a/src/net/textproto/textproto.go b/src/net/textproto/textproto.go new file mode 100644 index 0000000..70038d5 --- /dev/null +++ b/src/net/textproto/textproto.go @@ -0,0 +1,152 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package textproto implements generic support for text-based request/response +// protocols in the style of HTTP, NNTP, and SMTP. +// +// The package provides: +// +// Error, which represents a numeric error response from +// a server. +// +// Pipeline, to manage pipelined requests and responses +// in a client. +// +// Reader, to read numeric response code lines, +// key: value headers, lines wrapped with leading spaces +// on continuation lines, and whole text blocks ending +// with a dot on a line by itself. +// +// Writer, to write dot-encoded text blocks. +// +// Conn, a convenient packaging of Reader, Writer, and Pipeline for use +// with a single network connection. +package textproto + +import ( + "bufio" + "fmt" + "io" + "net" +) + +// An Error represents a numeric error response from a server. +type Error struct { + Code int + Msg string +} + +func (e *Error) Error() string { + return fmt.Sprintf("%03d %s", e.Code, e.Msg) +} + +// A ProtocolError describes a protocol violation such +// as an invalid response or a hung-up connection. +type ProtocolError string + +func (p ProtocolError) Error() string { + return string(p) +} + +// A Conn represents a textual network protocol connection. +// It consists of a Reader and Writer to manage I/O +// and a Pipeline to sequence concurrent requests on the connection. +// These embedded types carry methods with them; +// see the documentation of those types for details. +type Conn struct { + Reader + Writer + Pipeline + conn io.ReadWriteCloser +} + +// NewConn returns a new Conn using conn for I/O. +func NewConn(conn io.ReadWriteCloser) *Conn { + return &Conn{ + Reader: Reader{R: bufio.NewReader(conn)}, + Writer: Writer{W: bufio.NewWriter(conn)}, + conn: conn, + } +} + +// Close closes the connection. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, error) { + c, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return NewConn(c), nil +} + +// Cmd is a convenience method that sends a command after +// waiting its turn in the pipeline. The command text is the +// result of formatting format with args and appending \r\n. +// Cmd returns the id of the command, for use with StartResponse and EndResponse. +// +// For example, a client might run a HELP command that returns a dot-body +// by using: +// +// id, err := c.Cmd("HELP") +// if err != nil { +// return nil, err +// } +// +// c.StartResponse(id) +// defer c.EndResponse(id) +// +// if _, _, err = c.ReadCodeLine(110); err != nil { +// return nil, err +// } +// text, err := c.ReadDotBytes() +// if err != nil { +// return nil, err +// } +// return c.ReadCodeLine(250) +func (c *Conn) Cmd(format string, args ...any) (id uint, err error) { + id = c.Next() + c.StartRequest(id) + err = c.PrintfLine(format, args...) + c.EndRequest(id) + if err != nil { + return 0, err + } + return id, nil +} + +// TrimString returns s without leading and trailing ASCII space. +func TrimString(s string) string { + for len(s) > 0 && isASCIISpace(s[0]) { + s = s[1:] + } + for len(s) > 0 && isASCIISpace(s[len(s)-1]) { + s = s[:len(s)-1] + } + return s +} + +// TrimBytes returns b without leading and trailing ASCII space. +func TrimBytes(b []byte) []byte { + for len(b) > 0 && isASCIISpace(b[0]) { + b = b[1:] + } + for len(b) > 0 && isASCIISpace(b[len(b)-1]) { + b = b[:len(b)-1] + } + return b +} + +func isASCIISpace(b byte) bool { + return b == ' ' || b == '\t' || b == '\n' || b == '\r' +} + +func isASCIILetter(b byte) bool { + b |= 0x20 // make lower case + return 'a' <= b && b <= 'z' +} diff --git a/src/net/textproto/writer.go b/src/net/textproto/writer.go new file mode 100644 index 0000000..2ece3f5 --- /dev/null +++ b/src/net/textproto/writer.go @@ -0,0 +1,119 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "fmt" + "io" +) + +// A Writer implements convenience methods for writing +// requests or responses to a text protocol network connection. +type Writer struct { + W *bufio.Writer + dot *dotWriter +} + +// NewWriter returns a new Writer writing to w. +func NewWriter(w *bufio.Writer) *Writer { + return &Writer{W: w} +} + +var crnl = []byte{'\r', '\n'} +var dotcrnl = []byte{'.', '\r', '\n'} + +// PrintfLine writes the formatted output followed by \r\n. +func (w *Writer) PrintfLine(format string, args ...any) error { + w.closeDot() + fmt.Fprintf(w.W, format, args...) + w.W.Write(crnl) + return w.W.Flush() +} + +// DotWriter returns a writer that can be used to write a dot-encoding to w. +// It takes care of inserting leading dots when necessary, +// translating line-ending \n into \r\n, and adding the final .\r\n line +// when the DotWriter is closed. The caller should close the +// DotWriter before the next call to a method on w. +// +// See the documentation for Reader's DotReader method for details about dot-encoding. +func (w *Writer) DotWriter() io.WriteCloser { + w.closeDot() + w.dot = &dotWriter{w: w} + return w.dot +} + +func (w *Writer) closeDot() { + if w.dot != nil { + w.dot.Close() // sets w.dot = nil + } +} + +type dotWriter struct { + w *Writer + state int +} + +const ( + wstateBegin = iota // initial state; must be zero + wstateBeginLine // beginning of line + wstateCR // wrote \r (possibly at end of line) + wstateData // writing data in middle of line +) + +func (d *dotWriter) Write(b []byte) (n int, err error) { + bw := d.w.W + for n < len(b) { + c := b[n] + switch d.state { + case wstateBegin, wstateBeginLine: + d.state = wstateData + if c == '.' { + // escape leading dot + bw.WriteByte('.') + } + fallthrough + + case wstateData: + if c == '\r' { + d.state = wstateCR + } + if c == '\n' { + bw.WriteByte('\r') + d.state = wstateBeginLine + } + + case wstateCR: + d.state = wstateData + if c == '\n' { + d.state = wstateBeginLine + } + } + if err = bw.WriteByte(c); err != nil { + break + } + n++ + } + return +} + +func (d *dotWriter) Close() error { + if d.w.dot == d { + d.w.dot = nil + } + bw := d.w.W + switch d.state { + default: + bw.WriteByte('\r') + fallthrough + case wstateCR: + bw.WriteByte('\n') + fallthrough + case wstateBeginLine: + bw.Write(dotcrnl) + } + return bw.Flush() +} diff --git a/src/net/textproto/writer_test.go b/src/net/textproto/writer_test.go new file mode 100644 index 0000000..8f11b10 --- /dev/null +++ b/src/net/textproto/writer_test.go @@ -0,0 +1,61 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package textproto + +import ( + "bufio" + "strings" + "testing" +) + +func TestPrintfLine(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + err := w.PrintfLine("foo %d", 123) + if s := buf.String(); s != "foo 123\r\n" || err != nil { + t.Fatalf("s=%q; err=%s", s, err) + } +} + +func TestDotWriter(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte("abc\n.def\n..ghi\n.jkl\n.")) + if n != 21 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "abc\r\n..def\r\n...ghi\r\n..jkl\r\n..\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q", s) + } +} + +func TestDotWriterCloseEmptyWrite(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + n, err := d.Write([]byte{}) + if n != 0 || err != nil { + t.Fatalf("Write: %d, %s", n, err) + } + d.Close() + want := "\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q; want %q", s, want) + } +} + +func TestDotWriterCloseNoWrite(t *testing.T) { + var buf strings.Builder + w := NewWriter(bufio.NewWriter(&buf)) + d := w.DotWriter() + d.Close() + want := "\r\n.\r\n" + if s := buf.String(); s != want { + t.Fatalf("wrote %q; want %q", s, want) + } +} |