diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:23:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:23:18 +0000 |
commit | 43a123c1ae6613b3efeed291fa552ecd909d3acf (patch) | |
tree | fd92518b7024bc74031f78a1cf9e454b65e73665 /src/net/http/httptest | |
parent | Initial commit. (diff) | |
download | golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.tar.xz golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.zip |
Adding upstream version 1.20.14.upstream/1.20.14upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | src/net/http/httptest/example_test.go | 99 | ||||
-rw-r--r-- | src/net/http/httptest/httptest.go | 90 | ||||
-rw-r--r-- | src/net/http/httptest/httptest_test.go | 179 | ||||
-rw-r--r-- | src/net/http/httptest/recorder.go | 255 | ||||
-rw-r--r-- | src/net/http/httptest/recorder_test.go | 371 | ||||
-rw-r--r-- | src/net/http/httptest/server.go | 385 | ||||
-rw-r--r-- | src/net/http/httptest/server_test.go | 294 |
7 files changed, 1673 insertions, 0 deletions
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go new file mode 100644 index 0000000..a673843 --- /dev/null +++ b/src/net/http/httptest/example_test.go @@ -0,0 +1,99 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest_test + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" +) + +func ExampleResponseRecorder() { + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "<html><body>Hello World!</body></html>") + } + + req := httptest.NewRequest("GET", "http://example.com/foo", nil) + w := httptest.NewRecorder() + handler(w, req) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + fmt.Println(resp.StatusCode) + fmt.Println(resp.Header.Get("Content-Type")) + fmt.Println(string(body)) + + // Output: + // 200 + // text/html; charset=utf-8 + // <html><body>Hello World!</body></html> +} + +func ExampleServer() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} + +func ExampleServer_hTTP2() { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %s", r.Proto) + })) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + res, err := ts.Client().Get(ts.URL) + if err != nil { + log.Fatal(err) + } + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s", greeting) + + // Output: Hello, HTTP/2.0 +} + +func ExampleNewTLSServer() { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go new file mode 100644 index 0000000..9bedefd --- /dev/null +++ b/src/net/http/httptest/httptest.go @@ -0,0 +1,90 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "net/http" + "strings" +) + +// NewRequest returns a new incoming server Request, suitable +// for passing to an http.Handler for testing. +// +// The target is the RFC 7230 "request-target": it may be either a +// path or an absolute URL. If target is an absolute URL, the host name +// from the URL is used. Otherwise, "example.com" is used. +// +// The TLS field is set to a non-nil dummy value if target has scheme +// "https". +// +// The Request.Proto is always HTTP/1.1. +// +// An empty method means "GET". +// +// The provided body may be nil. If the body is of type *bytes.Reader, +// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is +// set. +// +// NewRequest panics on error for ease of use in testing, where a +// panic is acceptable. +// +// To generate a client HTTP request instead of a server request, see +// the NewRequest function in the net/http package. +func NewRequest(method, target string, body io.Reader) *http.Request { + if method == "" { + method = "GET" + } + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n"))) + if err != nil { + panic("invalid NewRequest arguments; " + err.Error()) + } + + // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here. + req.Proto = "HTTP/1.1" + req.ProtoMinor = 1 + req.Close = false + + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) + default: + req.ContentLength = -1 + } + if rc, ok := body.(io.ReadCloser); ok { + req.Body = rc + } else { + req.Body = io.NopCloser(body) + } + } + + // 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in + // documentation and example source code and should not be + // used publicly. + req.RemoteAddr = "192.0.2.1:1234" + + if req.Host == "" { + req.Host = "example.com" + } + + if strings.HasPrefix(target, "https://") { + req.TLS = &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: req.Host, + } + } + + return req +} diff --git a/src/net/http/httptest/httptest_test.go b/src/net/http/httptest/httptest_test.go new file mode 100644 index 0000000..071add6 --- /dev/null +++ b/src/net/http/httptest/httptest_test.go @@ -0,0 +1,179 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "crypto/tls" + "io" + "net/http" + "net/url" + "reflect" + "strings" + "testing" +) + +func TestNewRequest(t *testing.T) { + for _, tt := range [...]struct { + name string + + method, uri string + body io.Reader + + want *http.Request + wantBody string + }{ + { + name: "Empty method means GET", + method: "", + uri: "/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "", + }, + + { + name: "GET with full URL", + method: "GET", + uri: "http://foo.com/path/%2f/bar/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "http", + Path: "/path///bar/", + RawPath: "/path/%2f/bar/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "http://foo.com/path/%2f/bar/", + }, + wantBody: "", + }, + + { + name: "GET with full https URL", + method: "GET", + uri: "https://foo.com/path/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "https", + Path: "/path/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "https://foo.com/path/", + TLS: &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: "foo.com", + }, + }, + wantBody: "", + }, + + { + name: "Post with known length", + method: "POST", + uri: "/", + body: strings.NewReader("foo"), + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: 3, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + { + name: "Post with unknown length", + method: "POST", + uri: "/", + body: struct{ io.Reader }{strings.NewReader("foo")}, + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: -1, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + { + name: "OPTIONS *", + method: "OPTIONS", + uri: "*", + want: &http.Request{ + Method: "OPTIONS", + Host: "example.com", + URL: &url.URL{Path: "*"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "*", + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got := NewRequest(tt.method, tt.uri, tt.body) + slurp, err := io.ReadAll(got.Body) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + if string(slurp) != tt.wantBody { + t.Errorf("Body = %q; want %q", slurp, tt.wantBody) + } + got.Body = nil // before DeepEqual + if !reflect.DeepEqual(got.URL, tt.want.URL) { + t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL) + } + if !reflect.DeepEqual(got.Header, tt.want.Header) { + t.Errorf("Request.Header mismatch:\n got: %#v\nwant: %#v", got.Header, tt.want.Header) + } + if !reflect.DeepEqual(got.TLS, tt.want.TLS) { + t.Errorf("Request.TLS mismatch:\n got: %#v\nwant: %#v", got.TLS, tt.want.TLS) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, tt.want) + } + }) + } +} diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go new file mode 100644 index 0000000..1c1d880 --- /dev/null +++ b/src/net/http/httptest/recorder.go @@ -0,0 +1,255 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" +) + +// ResponseRecorder is an implementation of http.ResponseWriter that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // It is an internal detail. + // + // Deprecated: HeaderMap exists for historical compatibility + // and should not be used. To access the headers returned by a handler, + // use the Response.Header map as returned by the Result method. + HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool + + result *http.Response // cache of Result's return value + snapHeader http.Header // snapshot of HeaderMap at first Write + wroteHeader bool +} + +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on ResponseRecorder. +const DefaultRemoteAddr = "1.2.3.4" + +// Header implements http.ResponseWriter. It returns the response +// headers to mutate within a handler. To test the headers that were +// written after a handler completes, use the Result method and see +// the returned Response value's Header. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + m.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + +// Write implements http.ResponseWriter. The data in buf is written to +// rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + rw.writeHeader(buf, "") + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteString implements io.StringWriter. The data in str is written +// to rw.Body, if not nil. +func (rw *ResponseRecorder) WriteString(str string) (int, error) { + rw.writeHeader(nil, str) + if rw.Body != nil { + rw.Body.WriteString(str) + } + return len(str), nil +} + +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at https://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// WriteHeader implements http.ResponseWriter. +func (rw *ResponseRecorder) WriteHeader(code int) { + if rw.wroteHeader { + return + } + + checkWriteHeaderCode(code) + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + rw.snapHeader = rw.HeaderMap.Clone() +} + +// Flush implements http.Flusher. To test whether Flush was +// called, see rw.Flushed. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} + +// Result returns the response generated by the handler. +// +// The returned Response will have at least its StatusCode, +// Header, Body, and optionally Trailer populated. +// More fields may be populated in the future, so callers should +// not DeepEqual the result in tests. +// +// The Response.Header is a snapshot of the headers at the time of the +// first write call, or at the time of this call, if the handler never +// did a write. +// +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than io.EOF. +// +// Result must only be called after the handler has finished running. +func (rw *ResponseRecorder) Result() *http.Response { + if rw.result != nil { + return rw.result + } + if rw.snapHeader == nil { + rw.snapHeader = rw.HeaderMap.Clone() + } + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rw.Code, + Header: rw.snapHeader, + } + rw.result = res + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + if rw.Body != nil { + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) + } else { + res.Body = http.NoBody + } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) + + if trailers, ok := rw.snapHeader["Trailer"]; ok { + res.Trailer = make(http.Header, len(trailers)) + for _, k := range trailers { + for _, k := range strings.Split(k, ",") { + k = http.CanonicalHeaderKey(textproto.TrimString(k)) + if !httpguts.ValidTrailerHeader(k) { + // Ignore since forbidden by RFC 7230, section 4.1.2. + continue + } + vv, ok := rw.HeaderMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + res.Trailer[k] = vv2 + } + } + } + for k, vv := range rw.HeaderMap { + if !strings.HasPrefix(k, http.TrailerPrefix) { + continue + } + if res.Trailer == nil { + res.Trailer = make(http.Header) + } + for _, v := range vv { + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) + } + } + return res +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. +func parseContentLength(cl string) int64 { + cl = textproto.TrimString(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return -1 + } + return int64(n) +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go new file mode 100644 index 0000000..4782ece --- /dev/null +++ b/src/net/http/httptest/recorder_test.go @@ -0,0 +1,371 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "fmt" + "io" + "net/http" + "testing" +) + +func TestRecorder(t *testing.T) { + type checkFunc func(*ResponseRecorder) error + check := func(fns ...checkFunc) []checkFunc { return fns } + + hasStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Code != wantCode { + return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode) + } + return nil + } + } + hasResultStatus := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().Status != want { + return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) + } + return nil + } + } + hasResultStatusCode := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().StatusCode != wantCode { + return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) + } + return nil + } + } + hasResultContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + contentBytes, err := io.ReadAll(rec.Result().Body) + if err != nil { + return err + } + contents := string(contentBytes) + if contents != want { + return fmt.Errorf("Result().Body = %s; want %s", contents, want) + } + return nil + } + } + hasContents := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Body.String() != want { + return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want) + } + return nil + } + } + hasFlush := func(want bool) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Flushed != want { + return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want) + } + return nil + } + } + hasOldHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.HeaderMap.Get(key); got != want { + return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Header.Get(key); got != want { + return fmt.Errorf("final header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotHeaders := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + for _, k := range keys { + v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected header %s with value %q", k, v) + } + } + return nil + } + } + hasTrailer := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Trailer.Get(key); got != want { + return fmt.Errorf("trailer %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotTrailers := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + trailers := rec.Result().Trailer + for _, k := range keys { + _, ok := trailers[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected trailer %s", k) + } + } + return nil + } + } + hasContentLength := func(length int64) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().ContentLength; got != length { + return fmt.Errorf("ContentLength = %d; want %d", got, length) + } + return nil + } + } + + for _, tt := range [...]struct { + name string + h func(w http.ResponseWriter, r *http.Request) + checks []checkFunc + }{ + { + "200 default", + func(w http.ResponseWriter, r *http.Request) {}, + check(hasStatus(200), hasContents("")), + }, + { + "first code only", + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(201) + w.WriteHeader(202) + w.Write([]byte("hi")) + }, + check(hasStatus(201), hasContents("hi")), + }, + { + "write sends 200", + func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi first")) + w.WriteHeader(201) + w.WriteHeader(202) + }, + check(hasStatus(200), hasContents("hi first"), hasFlush(false)), + }, + { + "write string", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "hi first") + }, + check( + hasStatus(200), + hasContents("hi first"), + hasFlush(false), + hasHeader("Content-Type", "text/plain; charset=utf-8"), + ), + }, + { + "flush", + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() // also sends a 200 + w.WriteHeader(201) + }, + check(hasStatus(200), hasFlush(true), hasContentLength(-1)), + }, + { + "Content-Type detection", + func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "no Content-Type detection with Transfer-Encoding", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Transfer-Encoding", "some encoding") + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "")), // no header + }, + { + "no Content-Type detection if set explicitly", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "some/type") + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "some/type")), + }, + { + "Content-Type detection doesn't crash if HeaderMap is nil", + func(w http.ResponseWriter, r *http.Request) { + // Act as if the user wrote new(httptest.ResponseRecorder) + // rather than using NewRecorder (which initializes + // HeaderMap) + w.(*ResponseRecorder).HeaderMap = nil + io.WriteString(w, "<html>") + }, + check(hasHeader("Content-Type", "text/html; charset=utf-8")), + }, + { + "Header is not changed after write", + func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("Key", "correct") + w.WriteHeader(200) + hdr.Set("Key", "incorrect") + }, + check(hasHeader("Key", "correct")), + }, + { + "Trailer headers are correctly recorded", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Non-Trailer", "correct") + w.Header().Set("Trailer", "Trailer-A, Trailer-B") + w.Header().Add("Trailer", "Trailer-C") + io.WriteString(w, "<html>") + w.Header().Set("Non-Trailer", "incorrect") + w.Header().Set("Trailer-A", "valuea") + w.Header().Set("Trailer-C", "valuec") + w.Header().Set("Trailer-NotDeclared", "should be omitted") + w.Header().Set("Trailer:Trailer-D", "with prefix") + }, + check( + hasStatus(200), + hasHeader("Content-Type", "text/html; charset=utf-8"), + hasHeader("Non-Trailer", "correct"), + hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), + hasTrailer("Trailer-A", "valuea"), + hasTrailer("Trailer-C", "valuec"), + hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + hasTrailer("Trailer-D", "with prefix"), + ), + }, + { + "Header set without any write", // Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Foo", "1") + + // Simulate somebody using + // new(ResponseRecorder) instead of + // using the constructor which sets + // this to 200 + w.(*ResponseRecorder).Code = 0 + }, + check( + hasOldHeader("X-Foo", "1"), + hasStatus(0), + hasHeader("X-Foo", "1"), + hasResultStatus("200 OK"), + hasResultStatusCode(200), + ), + }, + { + "HeaderMap vs FinalHeaders", // more for Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("X-Foo", "1") + w.Write([]byte("hi")) + h.Set("X-Foo", "2") + h.Set("X-Bar", "2") + }, + check( + hasOldHeader("X-Foo", "2"), + hasOldHeader("X-Bar", "2"), + hasHeader("X-Foo", "1"), + hasNotHeaders("X-Bar"), + ), + }, + { + "setting Content-Length header", + func(w http.ResponseWriter, r *http.Request) { + body := "Some body" + contentLength := fmt.Sprintf("%d", len(body)) + w.Header().Set("Content-Length", contentLength) + io.WriteString(w, body) + }, + check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), + }, + { + "nil ResponseRecorder.Body", // Issue 26642 + func(w http.ResponseWriter, r *http.Request) { + w.(*ResponseRecorder).Body = nil + io.WriteString(w, "hi") + }, + check(hasResultContents("")), // check we don't crash reading the body + + }, + } { + t.Run(tt.name, func(t *testing.T) { + r, _ := http.NewRequest("GET", "http://foo.com/", nil) + h := http.HandlerFunc(tt.h) + rec := NewRecorder() + h.ServeHTTP(rec, r) + for _, check := range tt.checks { + if err := check(rec); err != nil { + t.Error(err) + } + } + }) + } +} + +// issue 39017 - disallow Content-Length values such as "+3" +func TestParseContentLength(t *testing.T) { + tests := []struct { + cl string + want int64 + }{ + { + cl: "3", + want: 3, + }, + { + cl: "+3", + want: -1, + }, + { + cl: "-3", + want: -1, + }, + { + // max int64, for safe conversion before returning + cl: "9223372036854775807", + want: 9223372036854775807, + }, + { + cl: "9223372036854775808", + want: -1, + }, + } + + for _, tt := range tests { + if got := parseContentLength(tt.cl); got != tt.want { + t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want) + } + } +} + +// Ensure that httptest.Recorder panics when given a non-3 digit (XXX) +// status HTTP code. See https://golang.org/issues/45353 +func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) { + badCodes := []int{ + -100, 0, 99, 1000, 20000, + } + for _, badCode := range badCodes { + badCode := badCode + t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("Expected a panic") + } + }() + + handler := func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(badCode) + } + r, _ := http.NewRequest("GET", "http://example.org/", nil) + rw := NewRecorder() + handler(rw, r) + }) + } +} diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go new file mode 100644 index 0000000..f254a49 --- /dev/null +++ b/src/net/http/httptest/server.go @@ -0,0 +1,385 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Implementation of Server + +package httptest + +import ( + "crypto/tls" + "crypto/x509" + "flag" + "fmt" + "log" + "net" + "net/http" + "net/http/internal/testcert" + "os" + "strings" + "sync" + "time" +) + +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + + // EnableHTTP2 controls whether HTTP/2 is enabled + // on the server. It must be set between calling + // NewUnstartedServer and calling Server.StartTLS. + EnableHTTP2 bool + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server + + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup + + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client +} + +func newLocalListener() net.Listener { + if serveFlag != "" { + l, err := net.Listen("tcp", serveFlag) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err)) + } + return l + } + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// +// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000 +// +// to start the broken server so you can interact with it manually. +// We only register this flag if it looks like the caller knows about it +// and is trying to use it as we don't want to pollute flags and this +// isn't really part of our API. Don't depend on this. +var serveFlag string + +func init() { + if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") { + flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.") + } +} + +func strSliceContainsPrefix(v []string, pre string) bool { + for _, s := range v { + if strings.HasPrefix(s, pre) { + return true + } + } + return false +} + +// NewServer starts and returns a new Server. +// The caller should call Close when finished, to shut it down. +func NewServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } + s.URL = "http://" + s.Listener.Addr().String() + s.wrap() + s.goServe() + if serveFlag != "" { + fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) + select {} + } +} + +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + if existingConfig != nil { + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) + } + if s.TLS.NextProtos == nil { + nextProtos := []string{"http/1.1"} + if s.EnableHTTP2 { + nextProtos = []string{"h2"} + } + s.TLS.NextProtos = nextProtos + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + ForceAttemptHTTP2: s.EnableHTTP2, + } + s.Listener = tls.NewListener(s.Listener, s.TLS) + s.URL = "https://" + s.Listener.Addr().String() + s.wrap() + s.goServe() +} + +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() + return ts +} + +type closeIdleTransport interface { + CloseIdleConnections() +} + +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. +func (s *Server) Close() { + s.mu.Lock() + if !s.closed { + s.closed = true + s.Listener.Close() + s.Config.SetKeepAlivesEnabled(false) + for c, st := range s.conns { + // Force-close any idle connections (those between + // requests) and new connections (those which connected + // but never sent a request). StateNew connections are + // super rare and have only been seen (in + // previously-flaky tests) in the case of + // socket-late-binding races from the http Client + // dialing this server and then getting an idle + // connection before the dial completed. There is thus + // a connected connection in StateNew with no + // associated Request. We only close StateIdle and + // StateNew because they're not doing anything. It's + // possible StateNew is about to do something in a few + // milliseconds, but a previous CL to check again in a + // few milliseconds wasn't liked (early versions of + // https://golang.org/cl/15151) so now we just + // forcefully close StateNew. The docs for Server.Close say + // we wait for "outstanding requests", so we don't close things + // in StateActive. + if st == http.StateIdle || st == http.StateNew { + s.closeConn(c) + } + } + // If this server doesn't shut down in 5 seconds, tell the user why. + t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo) + defer t.Stop() + } + s.mu.Unlock() + + // Not part of httptest.Server's correctness, but assume most + // users of httptest.Server will be using the standard + // transport, so help them out and close any idle connections for them. + if t, ok := http.DefaultTransport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + + // Also close the client idle connections. + if s.client != nil { + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + } + + s.wg.Wait() +} + +func (s *Server) logCloseHangDebugInfo() { + s.mu.Lock() + defer s.mu.Unlock() + var buf strings.Builder + buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n") + for c, st := range s.conns { + fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st) + } + log.Print(buf.String()) +} + +// CloseClientConnections closes any open HTTP connections to the test Server. +func (s *Server) CloseClientConnections() { + s.mu.Lock() + nconn := len(s.conns) + ch := make(chan struct{}, nconn) + for c := range s.conns { + go s.closeConnChan(c, ch) + } + s.mu.Unlock() + + // Wait for outstanding closes to finish. + // + // Out of paranoia for making a late change in Go 1.6, we + // bound how long this can wait, since golang.org/issue/14291 + // isn't fully understood yet. At least this should only be used + // in tests. + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + for i := 0; i < nconn; i++ { + select { + case <-ch: + case <-timer.C: + // Too slow. Give up. + return + } + } +} + +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on Server.Close. +func (s *Server) Client() *http.Client { + return s.client +} + +func (s *Server) goServe() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.Config.Serve(s.Listener) + }() +} + +// wrap installs the connection state-tracking hook to know which +// connections are idle. +func (s *Server) wrap() { + oldHook := s.Config.ConnState + s.Config.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + + switch cs { + case http.StateNew: + if _, exists := s.conns[c]; exists { + panic("invalid state transition") + } + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + // Add c to the set of tracked conns and increment it to the + // waitgroup. + s.wg.Add(1) + s.conns[c] = cs + if s.closed { + // Probably just a socket-late-binding dial from + // the default transport that lost the race (and + // thus this connection is now idle and will + // never be used). + s.closeConn(c) + } + case http.StateActive: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateNew && oldState != http.StateIdle { + panic("invalid state transition") + } + s.conns[c] = cs + } + case http.StateIdle: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateActive { + panic("invalid state transition") + } + s.conns[c] = cs + } + if s.closed { + s.closeConn(c) + } + case http.StateHijacked, http.StateClosed: + // Remove c from the set of tracked conns and decrement it from the + // waitgroup, unless it was previously removed. + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + // Keep Close from returning until the user's ConnState hook + // (if any) finishes. + defer s.wg.Done() + } + } + if oldHook != nil { + oldHook(c, cs) + } + } +} + +// closeConn closes c. +// s.mu must be held. +func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } + +// closeConnChan is like closeConn, but takes an optional channel to receive a value +// when the goroutine closing c is done. +func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { + c.Close() + if done != nil { + done <- struct{}{} + } +} diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go new file mode 100644 index 0000000..5313f65 --- /dev/null +++ b/src/net/http/httptest/server_test.go @@ -0,0 +1,294 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httptest + +import ( + "bufio" + "io" + "net" + "net/http" + "sync" + "testing" +) + +type newServerFunc func(http.Handler) *Server + +var newServers = map[string]newServerFunc{ + "NewServer": NewServer, + "NewTLSServer": NewTLSServer, + + // The manual variants of newServer create a Server manually by only filling + // in the exported fields of Server. + "NewServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.Start() + return ts + }, + "NewTLSServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.StartTLS() + return ts + }, +} + +func TestServer(t *testing.T) { + for _, name := range []string{"NewServer", "NewServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) + t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) + t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) + t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) + t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) + }) + } + for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) + t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) + }) + } +} + +func testServer(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} + +// Issue 12781 +func testGetAfterClose(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Fatalf("got %q, want hello", string(got)) + } + + ts.Close() + + res, err = http.Get(ts.URL) + if err == nil { + body, _ := io.ReadAll(res.Body) + t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) + } +} + +func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + dial := func() net.Conn { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + return c + } + + // Keep one connection in StateNew (connected, but not sending anything) + cnew := dial() + defer cnew.Close() + + // Keep one connection in StateIdle (idle after a request) + cidle := dial() + defer cidle.Close() + cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) + _, err := http.ReadResponse(bufio.NewReader(cidle), nil) + if err != nil { + t.Fatal(err) + } + + ts.Close() // test we don't hang here forever. +} + +// Issue 14290 +func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { + var s *Server + s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.CloseClientConnections() + })) + defer s.Close() + res, err := http.Get(s.URL) + if err == nil { + res.Body.Close() + t.Fatalf("Unexpected response: %#v", res) + } +} + +// Tests that the Server.Client method works and returns an http.Client that can hit +// NewTLSServer without cert warnings. +func testServerClient(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} + +// Tests that the Server.Client.Transport interface is implemented +// by a *http.Transport. +func testServerClientTransportType(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +// Tests that the TLS Server.Client.Transport interface is implemented +// by a *http.Transport. +func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +type onlyCloseListener struct { + net.Listener +} + +func (onlyCloseListener) Close() error { return nil } + +// Issue 19729: panic in Server.Close for values created directly +// without a constructor (so the unexported client field is nil). +func TestServerZeroValueClose(t *testing.T) { + ts := &Server{ + Listener: onlyCloseListener{}, + Config: &http.Server{}, + } + + ts.Close() // tests that it doesn't panic +} + +// Issue 51799: test hijacking a connection and then closing it +// concurrently with closing the server. +func TestCloseHijackedConnection(t *testing.T) { + hijacked := make(chan net.Conn) + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(hijacked) + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("failed to hijack") + } + c, _, err := hj.Hijack() + if err != nil { + t.Fatal(err) + } + hijacked <- c + })) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Log(err) + } + // Use a client not associated with the Server. + var c http.Client + resp, err := c.Do(req) + if err != nil { + t.Log(err) + return + } + resp.Body.Close() + }() + + wg.Add(1) + conn := <-hijacked + go func(conn net.Conn) { + defer wg.Done() + // Close the connection and then inform the Server that + // we closed it. + conn.Close() + ts.Config.ConnState(conn, http.StateClosed) + }(conn) + + wg.Add(1) + go func() { + defer wg.Done() + ts.Close() + }() + wg.Wait() +} + +func TestTLSServerWithHTTP2(t *testing.T) { + modes := []struct { + name string + wantProto string + }{ + {"http1", "HTTP/1.1"}, + {"http2", "HTTP/2.0"}, + } + + for _, tt := range modes { + t.Run(tt.name, func(t *testing.T) { + cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Proto", r.Proto) + })) + + switch tt.name { + case "http2": + cst.EnableHTTP2 = true + cst.StartTLS() + default: + cst.Start() + } + + defer cst.Close() + + res, err := cst.Client().Get(cst.URL) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w { + t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w) + } + }) + } +} |