summaryrefslogtreecommitdiffstats
path: root/src/net/http/httptest
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
commit43a123c1ae6613b3efeed291fa552ecd909d3acf (patch)
treefd92518b7024bc74031f78a1cf9e454b65e73665 /src/net/http/httptest
parentInitial commit. (diff)
downloadgolang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.tar.xz
golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.zip
Adding upstream version 1.20.14.upstream/1.20.14upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/net/http/httptest/example_test.go99
-rw-r--r--src/net/http/httptest/httptest.go90
-rw-r--r--src/net/http/httptest/httptest_test.go179
-rw-r--r--src/net/http/httptest/recorder.go255
-rw-r--r--src/net/http/httptest/recorder_test.go371
-rw-r--r--src/net/http/httptest/server.go385
-rw-r--r--src/net/http/httptest/server_test.go294
7 files changed, 1673 insertions, 0 deletions
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go
new file mode 100644
index 0000000..a673843
--- /dev/null
+++ b/src/net/http/httptest/example_test.go
@@ -0,0 +1,99 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleResponseRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html><body>Hello World!</body></html>")
+ }
+
+ req := httptest.NewRequest("GET", "http://example.com/foo", nil)
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ resp := w.Result()
+ body, _ := io.ReadAll(resp.Body)
+
+ fmt.Println(resp.StatusCode)
+ fmt.Println(resp.Header.Get("Content-Type"))
+ fmt.Println(string(body))
+
+ // Output:
+ // 200
+ // text/html; charset=utf-8
+ // <html><body>Hello World!</body></html>
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
+
+func ExampleServer_hTTP2() {
+ ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, "Hello, %s", r.Proto)
+ }))
+ ts.EnableHTTP2 = true
+ ts.StartTLS()
+ defer ts.Close()
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s", greeting)
+
+ // Output: Hello, HTTP/2.0
+}
+
+func ExampleNewTLSServer() {
+ ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ greeting, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go
new file mode 100644
index 0000000..9bedefd
--- /dev/null
+++ b/src/net/http/httptest/httptest.go
@@ -0,0 +1,90 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptest provides utilities for HTTP testing.
+package httptest
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/tls"
+ "io"
+ "net/http"
+ "strings"
+)
+
+// NewRequest returns a new incoming server Request, suitable
+// for passing to an http.Handler for testing.
+//
+// The target is the RFC 7230 "request-target": it may be either a
+// path or an absolute URL. If target is an absolute URL, the host name
+// from the URL is used. Otherwise, "example.com" is used.
+//
+// The TLS field is set to a non-nil dummy value if target has scheme
+// "https".
+//
+// The Request.Proto is always HTTP/1.1.
+//
+// An empty method means "GET".
+//
+// The provided body may be nil. If the body is of type *bytes.Reader,
+// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is
+// set.
+//
+// NewRequest panics on error for ease of use in testing, where a
+// panic is acceptable.
+//
+// To generate a client HTTP request instead of a server request, see
+// the NewRequest function in the net/http package.
+func NewRequest(method, target string, body io.Reader) *http.Request {
+ if method == "" {
+ method = "GET"
+ }
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n")))
+ if err != nil {
+ panic("invalid NewRequest arguments; " + err.Error())
+ }
+
+ // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here.
+ req.Proto = "HTTP/1.1"
+ req.ProtoMinor = 1
+ req.Close = false
+
+ if body != nil {
+ switch v := body.(type) {
+ case *bytes.Buffer:
+ req.ContentLength = int64(v.Len())
+ case *bytes.Reader:
+ req.ContentLength = int64(v.Len())
+ case *strings.Reader:
+ req.ContentLength = int64(v.Len())
+ default:
+ req.ContentLength = -1
+ }
+ if rc, ok := body.(io.ReadCloser); ok {
+ req.Body = rc
+ } else {
+ req.Body = io.NopCloser(body)
+ }
+ }
+
+ // 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in
+ // documentation and example source code and should not be
+ // used publicly.
+ req.RemoteAddr = "192.0.2.1:1234"
+
+ if req.Host == "" {
+ req.Host = "example.com"
+ }
+
+ if strings.HasPrefix(target, "https://") {
+ req.TLS = &tls.ConnectionState{
+ Version: tls.VersionTLS12,
+ HandshakeComplete: true,
+ ServerName: req.Host,
+ }
+ }
+
+ return req
+}
diff --git a/src/net/http/httptest/httptest_test.go b/src/net/http/httptest/httptest_test.go
new file mode 100644
index 0000000..071add6
--- /dev/null
+++ b/src/net/http/httptest/httptest_test.go
@@ -0,0 +1,179 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "crypto/tls"
+ "io"
+ "net/http"
+ "net/url"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestNewRequest(t *testing.T) {
+ for _, tt := range [...]struct {
+ name string
+
+ method, uri string
+ body io.Reader
+
+ want *http.Request
+ wantBody string
+ }{
+ {
+ name: "Empty method means GET",
+ method: "",
+ uri: "/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "GET with full URL",
+ method: "GET",
+ uri: "http://foo.com/path/%2f/bar/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "foo.com",
+ URL: &url.URL{
+ Scheme: "http",
+ Path: "/path///bar/",
+ RawPath: "/path/%2f/bar/",
+ Host: "foo.com",
+ },
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "http://foo.com/path/%2f/bar/",
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "GET with full https URL",
+ method: "GET",
+ uri: "https://foo.com/path/",
+ body: nil,
+ want: &http.Request{
+ Method: "GET",
+ Host: "foo.com",
+ URL: &url.URL{
+ Scheme: "https",
+ Path: "/path/",
+ Host: "foo.com",
+ },
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "https://foo.com/path/",
+ TLS: &tls.ConnectionState{
+ Version: tls.VersionTLS12,
+ HandshakeComplete: true,
+ ServerName: "foo.com",
+ },
+ },
+ wantBody: "",
+ },
+
+ {
+ name: "Post with known length",
+ method: "POST",
+ uri: "/",
+ body: strings.NewReader("foo"),
+ want: &http.Request{
+ Method: "POST",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ContentLength: 3,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "foo",
+ },
+
+ {
+ name: "Post with unknown length",
+ method: "POST",
+ uri: "/",
+ body: struct{ io.Reader }{strings.NewReader("foo")},
+ want: &http.Request{
+ Method: "POST",
+ Host: "example.com",
+ URL: &url.URL{Path: "/"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ContentLength: -1,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "/",
+ },
+ wantBody: "foo",
+ },
+
+ {
+ name: "OPTIONS *",
+ method: "OPTIONS",
+ uri: "*",
+ want: &http.Request{
+ Method: "OPTIONS",
+ Host: "example.com",
+ URL: &url.URL{Path: "*"},
+ Header: http.Header{},
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ RemoteAddr: "192.0.2.1:1234",
+ RequestURI: "*",
+ },
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewRequest(tt.method, tt.uri, tt.body)
+ slurp, err := io.ReadAll(got.Body)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ if string(slurp) != tt.wantBody {
+ t.Errorf("Body = %q; want %q", slurp, tt.wantBody)
+ }
+ got.Body = nil // before DeepEqual
+ if !reflect.DeepEqual(got.URL, tt.want.URL) {
+ t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL)
+ }
+ if !reflect.DeepEqual(got.Header, tt.want.Header) {
+ t.Errorf("Request.Header mismatch:\n got: %#v\nwant: %#v", got.Header, tt.want.Header)
+ }
+ if !reflect.DeepEqual(got.TLS, tt.want.TLS) {
+ t.Errorf("Request.TLS mismatch:\n got: %#v\nwant: %#v", got.TLS, tt.want.TLS)
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go
new file mode 100644
index 0000000..1c1d880
--- /dev/null
+++ b/src/net/http/httptest/recorder.go
@@ -0,0 +1,255 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "net/textproto"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+ // Code is the HTTP response code set by WriteHeader.
+ //
+ // Note that if a Handler never calls WriteHeader or Write,
+ // this might end up being 0, rather than the implicit
+ // http.StatusOK. To get the implicit value, use the Result
+ // method.
+ Code int
+
+ // HeaderMap contains the headers explicitly set by the Handler.
+ // It is an internal detail.
+ //
+ // Deprecated: HeaderMap exists for historical compatibility
+ // and should not be used. To access the headers returned by a handler,
+ // use the Response.Header map as returned by the Result method.
+ HeaderMap http.Header
+
+ // Body is the buffer to which the Handler's Write calls are sent.
+ // If nil, the Writes are silently discarded.
+ Body *bytes.Buffer
+
+ // Flushed is whether the Handler called Flush.
+ Flushed bool
+
+ result *http.Response // cache of Result's return value
+ snapHeader http.Header // snapshot of HeaderMap at first Write
+ wroteHeader bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+ return &ResponseRecorder{
+ HeaderMap: make(http.Header),
+ Body: new(bytes.Buffer),
+ Code: 200,
+ }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header implements http.ResponseWriter. It returns the response
+// headers to mutate within a handler. To test the headers that were
+// written after a handler completes, use the Result method and see
+// the returned Response value's Header.
+func (rw *ResponseRecorder) Header() http.Header {
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
+}
+
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// bytes or str are the beginning of the response body.
+// We pass both to avoid unnecessarily generate garbage
+// in rw.WriteString which was created for performance reasons.
+// Non-nil bytes win.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+ if rw.wroteHeader {
+ return
+ }
+ if len(str) > 512 {
+ str = str[:512]
+ }
+
+ m := rw.Header()
+
+ _, hasType := m["Content-Type"]
+ hasTE := m.Get("Transfer-Encoding") != ""
+ if !hasType && !hasTE {
+ if b == nil {
+ b = []byte(str)
+ }
+ m.Set("Content-Type", http.DetectContentType(b))
+ }
+
+ rw.WriteHeader(200)
+}
+
+// Write implements http.ResponseWriter. The data in buf is written to
+// rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ rw.writeHeader(buf, "")
+ if rw.Body != nil {
+ rw.Body.Write(buf)
+ }
+ return len(buf), nil
+}
+
+// WriteString implements io.StringWriter. The data in str is written
+// to rw.Body, if not nil.
+func (rw *ResponseRecorder) WriteString(str string) (int, error) {
+ rw.writeHeader(nil, str)
+ if rw.Body != nil {
+ rw.Body.WriteString(str)
+ }
+ return len(str), nil
+}
+
+func checkWriteHeaderCode(code int) {
+ // Issue 22880: require valid WriteHeader status codes.
+ // For now we only enforce that it's three digits.
+ // In the future we might block things over 599 (600 and above aren't defined
+ // at https://httpwg.org/specs/rfc7231.html#status.codes)
+ // and we might block under 200 (once we have more mature 1xx support).
+ // But for now any three digits.
+ //
+ // We used to send "HTTP/1.1 000 0" on the wire in responses but there's
+ // no equivalent bogus thing we can realistically send in HTTP/2,
+ // so we'll consistently panic instead and help people find their bugs
+ // early. (We can't return an error from WriteHeader even if we wanted to.)
+ if code < 100 || code > 999 {
+ panic(fmt.Sprintf("invalid WriteHeader code %v", code))
+ }
+}
+
+// WriteHeader implements http.ResponseWriter.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+ if rw.wroteHeader {
+ return
+ }
+
+ checkWriteHeaderCode(code)
+ rw.Code = code
+ rw.wroteHeader = true
+ if rw.HeaderMap == nil {
+ rw.HeaderMap = make(http.Header)
+ }
+ rw.snapHeader = rw.HeaderMap.Clone()
+}
+
+// Flush implements http.Flusher. To test whether Flush was
+// called, see rw.Flushed.
+func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ rw.Flushed = true
+}
+
+// Result returns the response generated by the handler.
+//
+// The returned Response will have at least its StatusCode,
+// Header, Body, and optionally Trailer populated.
+// More fields may be populated in the future, so callers should
+// not DeepEqual the result in tests.
+//
+// The Response.Header is a snapshot of the headers at the time of the
+// first write call, or at the time of this call, if the handler never
+// did a write.
+//
+// The Response.Body is guaranteed to be non-nil and Body.Read call is
+// guaranteed to not return any error other than io.EOF.
+//
+// Result must only be called after the handler has finished running.
+func (rw *ResponseRecorder) Result() *http.Response {
+ if rw.result != nil {
+ return rw.result
+ }
+ if rw.snapHeader == nil {
+ rw.snapHeader = rw.HeaderMap.Clone()
+ }
+ res := &http.Response{
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ StatusCode: rw.Code,
+ Header: rw.snapHeader,
+ }
+ rw.result = res
+ if res.StatusCode == 0 {
+ res.StatusCode = 200
+ }
+ res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
+ if rw.Body != nil {
+ res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+ } else {
+ res.Body = http.NoBody
+ }
+ res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
+
+ if trailers, ok := rw.snapHeader["Trailer"]; ok {
+ res.Trailer = make(http.Header, len(trailers))
+ for _, k := range trailers {
+ for _, k := range strings.Split(k, ",") {
+ k = http.CanonicalHeaderKey(textproto.TrimString(k))
+ if !httpguts.ValidTrailerHeader(k) {
+ // Ignore since forbidden by RFC 7230, section 4.1.2.
+ continue
+ }
+ vv, ok := rw.HeaderMap[k]
+ if !ok {
+ continue
+ }
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ res.Trailer[k] = vv2
+ }
+ }
+ }
+ for k, vv := range rw.HeaderMap {
+ if !strings.HasPrefix(k, http.TrailerPrefix) {
+ continue
+ }
+ if res.Trailer == nil {
+ res.Trailer = make(http.Header)
+ }
+ for _, v := range vv {
+ res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
+ }
+ }
+ return res
+}
+
+// parseContentLength trims whitespace from s and returns -1 if no value
+// is set, or the value if it's >= 0.
+//
+// This a modified version of same function found in net/http/transfer.go. This
+// one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+ cl = textproto.TrimString(cl)
+ if cl == "" {
+ return -1
+ }
+ n, err := strconv.ParseUint(cl, 10, 63)
+ if err != nil {
+ return -1
+ }
+ return int64(n)
+}
diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go
new file mode 100644
index 0000000..4782ece
--- /dev/null
+++ b/src/net/http/httptest/recorder_test.go
@@ -0,0 +1,371 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasResultStatus := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Result().Status != want {
+ return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
+ }
+ return nil
+ }
+ }
+ hasResultStatusCode := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Result().StatusCode != wantCode {
+ return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
+ }
+ return nil
+ }
+ }
+ hasResultContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ contentBytes, err := io.ReadAll(rec.Result().Body)
+ if err != nil {
+ return err
+ }
+ contents := string(contentBytes)
+ if contents != want {
+ return fmt.Errorf("Result().Body = %s; want %s", contents, want)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+ hasOldHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.HeaderMap.Get(key); got != want {
+ return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().Header.Get(key); got != want {
+ return fmt.Errorf("final header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasNotHeaders := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ for _, k := range keys {
+ v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected header %s with value %q", k, v)
+ }
+ }
+ return nil
+ }
+ }
+ hasTrailer := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().Trailer.Get(key); got != want {
+ return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
+ hasNotTrailers := func(keys ...string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ trailers := rec.Result().Trailer
+ for _, k := range keys {
+ _, ok := trailers[http.CanonicalHeaderKey(k)]
+ if ok {
+ return fmt.Errorf("unexpected trailer %s", k)
+ }
+ }
+ return nil
+ }
+ }
+ hasContentLength := func(length int64) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().ContentLength; got != length {
+ return fmt.Errorf("ContentLength = %d; want %d", got, length)
+ }
+ return nil
+ }
+ }
+
+ for _, tt := range [...]struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "write string",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "hi first")
+ },
+ check(
+ hasStatus(200),
+ hasContents("hi first"),
+ hasFlush(false),
+ hasHeader("Content-Type", "text/plain; charset=utf-8"),
+ ),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
+ },
+ {
+ "Content-Type detection",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "no Content-Type detection with Transfer-Encoding",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Transfer-Encoding", "some encoding")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "")), // no header
+ },
+ {
+ "no Content-Type detection if set explicitly",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "some/type")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "some/type")),
+ },
+ {
+ "Content-Type detection doesn't crash if HeaderMap is nil",
+ func(w http.ResponseWriter, r *http.Request) {
+ // Act as if the user wrote new(httptest.ResponseRecorder)
+ // rather than using NewRecorder (which initializes
+ // HeaderMap)
+ w.(*ResponseRecorder).HeaderMap = nil
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "Header is not changed after write",
+ func(w http.ResponseWriter, r *http.Request) {
+ hdr := w.Header()
+ hdr.Set("Key", "correct")
+ w.WriteHeader(200)
+ hdr.Set("Key", "incorrect")
+ },
+ check(hasHeader("Key", "correct")),
+ },
+ {
+ "Trailer headers are correctly recorded",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Non-Trailer", "correct")
+ w.Header().Set("Trailer", "Trailer-A, Trailer-B")
+ w.Header().Add("Trailer", "Trailer-C")
+ io.WriteString(w, "<html>")
+ w.Header().Set("Non-Trailer", "incorrect")
+ w.Header().Set("Trailer-A", "valuea")
+ w.Header().Set("Trailer-C", "valuec")
+ w.Header().Set("Trailer-NotDeclared", "should be omitted")
+ w.Header().Set("Trailer:Trailer-D", "with prefix")
+ },
+ check(
+ hasStatus(200),
+ hasHeader("Content-Type", "text/html; charset=utf-8"),
+ hasHeader("Non-Trailer", "correct"),
+ hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-A", "valuea"),
+ hasTrailer("Trailer-C", "valuec"),
+ hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-D", "with prefix"),
+ ),
+ },
+ {
+ "Header set without any write", // Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Foo", "1")
+
+ // Simulate somebody using
+ // new(ResponseRecorder) instead of
+ // using the constructor which sets
+ // this to 200
+ w.(*ResponseRecorder).Code = 0
+ },
+ check(
+ hasOldHeader("X-Foo", "1"),
+ hasStatus(0),
+ hasHeader("X-Foo", "1"),
+ hasResultStatus("200 OK"),
+ hasResultStatusCode(200),
+ ),
+ },
+ {
+ "HeaderMap vs FinalHeaders", // more for Issue 15560
+ func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Set("X-Foo", "1")
+ w.Write([]byte("hi"))
+ h.Set("X-Foo", "2")
+ h.Set("X-Bar", "2")
+ },
+ check(
+ hasOldHeader("X-Foo", "2"),
+ hasOldHeader("X-Bar", "2"),
+ hasHeader("X-Foo", "1"),
+ hasNotHeaders("X-Bar"),
+ ),
+ },
+ {
+ "setting Content-Length header",
+ func(w http.ResponseWriter, r *http.Request) {
+ body := "Some body"
+ contentLength := fmt.Sprintf("%d", len(body))
+ w.Header().Set("Content-Length", contentLength)
+ io.WriteString(w, body)
+ },
+ check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
+ },
+ {
+ "nil ResponseRecorder.Body", // Issue 26642
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(*ResponseRecorder).Body = nil
+ io.WriteString(w, "hi")
+ },
+ check(hasResultContents("")), // check we don't crash reading the body
+
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Error(err)
+ }
+ }
+ })
+ }
+}
+
+// issue 39017 - disallow Content-Length values such as "+3"
+func TestParseContentLength(t *testing.T) {
+ tests := []struct {
+ cl string
+ want int64
+ }{
+ {
+ cl: "3",
+ want: 3,
+ },
+ {
+ cl: "+3",
+ want: -1,
+ },
+ {
+ cl: "-3",
+ want: -1,
+ },
+ {
+ // max int64, for safe conversion before returning
+ cl: "9223372036854775807",
+ want: 9223372036854775807,
+ },
+ {
+ cl: "9223372036854775808",
+ want: -1,
+ },
+ }
+
+ for _, tt := range tests {
+ if got := parseContentLength(tt.cl); got != tt.want {
+ t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
+ }
+ }
+}
+
+// Ensure that httptest.Recorder panics when given a non-3 digit (XXX)
+// status HTTP code. See https://golang.org/issues/45353
+func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
+ badCodes := []int{
+ -100, 0, 99, 1000, 20000,
+ }
+ for _, badCode := range badCodes {
+ badCode := badCode
+ t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Fatal("Expected a panic")
+ }
+ }()
+
+ handler := func(rw http.ResponseWriter, _ *http.Request) {
+ rw.WriteHeader(badCode)
+ }
+ r, _ := http.NewRequest("GET", "http://example.org/", nil)
+ rw := NewRecorder()
+ handler(rw, r)
+ })
+ }
+}
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
new file mode 100644
index 0000000..f254a49
--- /dev/null
+++ b/src/net/http/httptest/server.go
@@ -0,0 +1,385 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "net/http"
+ "net/http/internal/testcert"
+ "os"
+ "strings"
+ "sync"
+ "time"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+ URL string // base URL of form http://ipaddr:port with no trailing slash
+ Listener net.Listener
+
+ // EnableHTTP2 controls whether HTTP/2 is enabled
+ // on the server. It must be set between calling
+ // NewUnstartedServer and calling Server.StartTLS.
+ EnableHTTP2 bool
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
+
+ // Config may be changed after calling NewUnstartedServer and
+ // before Start or StartTLS.
+ Config *http.Server
+
+ // certificate is a parsed version of the TLS config certificate, if present.
+ certificate *x509.Certificate
+
+ // wg counts the number of outstanding HTTP requests on this server.
+ // Close blocks until all requests are finished.
+ wg sync.WaitGroup
+
+ mu sync.Mutex // guards closed and conns
+ closed bool
+ conns map[net.Conn]http.ConnState // except terminal states
+
+ // client is configured for use with the server.
+ // Its transport is automatically closed when Close is called.
+ client *http.Client
+}
+
+func newLocalListener() net.Listener {
+ if serveFlag != "" {
+ l, err := net.Listen("tcp", serveFlag)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err))
+ }
+ return l
+ }
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+ }
+ }
+ return l
+}
+
+// When debugging a particular http server-based test,
+// this flag lets you run
+//
+// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
+//
+// to start the broken server so you can interact with it manually.
+// We only register this flag if it looks like the caller knows about it
+// and is trying to use it as we don't want to pollute flags and this
+// isn't really part of our API. Don't depend on this.
+var serveFlag string
+
+func init() {
+ if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") {
+ flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.")
+ }
+}
+
+func strSliceContainsPrefix(v []string, pre string) bool {
+ for _, s := range v {
+ if strings.HasPrefix(s, pre) {
+ return true
+ }
+ }
+ return false
+}
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.Start()
+ return ts
+}
+
+// NewUnstartedServer returns a new Server but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedServer(handler http.Handler) *Server {
+ return &Server{
+ Listener: newLocalListener(),
+ Config: &http.Server{Handler: handler},
+ }
+}
+
+// Start starts a server from NewUnstartedServer.
+func (s *Server) Start() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ if s.client == nil {
+ s.client = &http.Client{Transport: &http.Transport{}}
+ }
+ s.URL = "http://" + s.Listener.Addr().String()
+ s.wrap()
+ s.goServe()
+ if serveFlag != "" {
+ fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
+ select {}
+ }
+}
+
+// StartTLS starts TLS on a server from NewUnstartedServer.
+func (s *Server) StartTLS() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ if s.client == nil {
+ s.client = &http.Client{Transport: &http.Transport{}}
+ }
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+
+ existingConfig := s.TLS
+ if existingConfig != nil {
+ s.TLS = existingConfig.Clone()
+ } else {
+ s.TLS = new(tls.Config)
+ }
+ if s.TLS.NextProtos == nil {
+ nextProtos := []string{"http/1.1"}
+ if s.EnableHTTP2 {
+ nextProtos = []string{"h2"}
+ }
+ s.TLS.NextProtos = nextProtos
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
+ }
+ s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+ certpool := x509.NewCertPool()
+ certpool.AddCert(s.certificate)
+ s.client.Transport = &http.Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certpool,
+ },
+ ForceAttemptHTTP2: s.EnableHTTP2,
+ }
+ s.Listener = tls.NewListener(s.Listener, s.TLS)
+ s.URL = "https://" + s.Listener.Addr().String()
+ s.wrap()
+ s.goServe()
+}
+
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.StartTLS()
+ return ts
+}
+
+type closeIdleTransport interface {
+ CloseIdleConnections()
+}
+
+// Close shuts down the server and blocks until all outstanding
+// requests on this server have completed.
+func (s *Server) Close() {
+ s.mu.Lock()
+ if !s.closed {
+ s.closed = true
+ s.Listener.Close()
+ s.Config.SetKeepAlivesEnabled(false)
+ for c, st := range s.conns {
+ // Force-close any idle connections (those between
+ // requests) and new connections (those which connected
+ // but never sent a request). StateNew connections are
+ // super rare and have only been seen (in
+ // previously-flaky tests) in the case of
+ // socket-late-binding races from the http Client
+ // dialing this server and then getting an idle
+ // connection before the dial completed. There is thus
+ // a connected connection in StateNew with no
+ // associated Request. We only close StateIdle and
+ // StateNew because they're not doing anything. It's
+ // possible StateNew is about to do something in a few
+ // milliseconds, but a previous CL to check again in a
+ // few milliseconds wasn't liked (early versions of
+ // https://golang.org/cl/15151) so now we just
+ // forcefully close StateNew. The docs for Server.Close say
+ // we wait for "outstanding requests", so we don't close things
+ // in StateActive.
+ if st == http.StateIdle || st == http.StateNew {
+ s.closeConn(c)
+ }
+ }
+ // If this server doesn't shut down in 5 seconds, tell the user why.
+ t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
+ defer t.Stop()
+ }
+ s.mu.Unlock()
+
+ // Not part of httptest.Server's correctness, but assume most
+ // users of httptest.Server will be using the standard
+ // transport, so help them out and close any idle connections for them.
+ if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
+ t.CloseIdleConnections()
+ }
+
+ // Also close the client idle connections.
+ if s.client != nil {
+ if t, ok := s.client.Transport.(closeIdleTransport); ok {
+ t.CloseIdleConnections()
+ }
+ }
+
+ s.wg.Wait()
+}
+
+func (s *Server) logCloseHangDebugInfo() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var buf strings.Builder
+ buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
+ for c, st := range s.conns {
+ fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
+ }
+ log.Print(buf.String())
+}
+
+// CloseClientConnections closes any open HTTP connections to the test Server.
+func (s *Server) CloseClientConnections() {
+ s.mu.Lock()
+ nconn := len(s.conns)
+ ch := make(chan struct{}, nconn)
+ for c := range s.conns {
+ go s.closeConnChan(c, ch)
+ }
+ s.mu.Unlock()
+
+ // Wait for outstanding closes to finish.
+ //
+ // Out of paranoia for making a late change in Go 1.6, we
+ // bound how long this can wait, since golang.org/issue/14291
+ // isn't fully understood yet. At least this should only be used
+ // in tests.
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ for i := 0; i < nconn; i++ {
+ select {
+ case <-ch:
+ case <-timer.C:
+ // Too slow. Give up.
+ return
+ }
+ }
+}
+
+// Certificate returns the certificate used by the server, or nil if
+// the server doesn't use TLS.
+func (s *Server) Certificate() *x509.Certificate {
+ return s.certificate
+}
+
+// Client returns an HTTP client configured for making requests to the server.
+// It is configured to trust the server's TLS test certificate and will
+// close its idle connections on Server.Close.
+func (s *Server) Client() *http.Client {
+ return s.client
+}
+
+func (s *Server) goServe() {
+ s.wg.Add(1)
+ go func() {
+ defer s.wg.Done()
+ s.Config.Serve(s.Listener)
+ }()
+}
+
+// wrap installs the connection state-tracking hook to know which
+// connections are idle.
+func (s *Server) wrap() {
+ oldHook := s.Config.ConnState
+ s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ switch cs {
+ case http.StateNew:
+ if _, exists := s.conns[c]; exists {
+ panic("invalid state transition")
+ }
+ if s.conns == nil {
+ s.conns = make(map[net.Conn]http.ConnState)
+ }
+ // Add c to the set of tracked conns and increment it to the
+ // waitgroup.
+ s.wg.Add(1)
+ s.conns[c] = cs
+ if s.closed {
+ // Probably just a socket-late-binding dial from
+ // the default transport that lost the race (and
+ // thus this connection is now idle and will
+ // never be used).
+ s.closeConn(c)
+ }
+ case http.StateActive:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateNew && oldState != http.StateIdle {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ case http.StateIdle:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateActive {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ if s.closed {
+ s.closeConn(c)
+ }
+ case http.StateHijacked, http.StateClosed:
+ // Remove c from the set of tracked conns and decrement it from the
+ // waitgroup, unless it was previously removed.
+ if _, ok := s.conns[c]; ok {
+ delete(s.conns, c)
+ // Keep Close from returning until the user's ConnState hook
+ // (if any) finishes.
+ defer s.wg.Done()
+ }
+ }
+ if oldHook != nil {
+ oldHook(c, cs)
+ }
+ }
+}
+
+// closeConn closes c.
+// s.mu must be held.
+func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
+
+// closeConnChan is like closeConn, but takes an optional channel to receive a value
+// when the goroutine closing c is done.
+func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
+ c.Close()
+ if done != nil {
+ done <- struct{}{}
+ }
+}
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
new file mode 100644
index 0000000..5313f65
--- /dev/null
+++ b/src/net/http/httptest/server_test.go
@@ -0,0 +1,294 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "bufio"
+ "io"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+)
+
+type newServerFunc func(http.Handler) *Server
+
+var newServers = map[string]newServerFunc{
+ "NewServer": NewServer,
+ "NewTLSServer": NewTLSServer,
+
+ // The manual variants of newServer create a Server manually by only filling
+ // in the exported fields of Server.
+ "NewServerManual": func(h http.Handler) *Server {
+ ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
+ ts.Start()
+ return ts
+ },
+ "NewTLSServerManual": func(h http.Handler) *Server {
+ ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
+ ts.StartTLS()
+ return ts
+ },
+}
+
+func TestServer(t *testing.T) {
+ for _, name := range []string{"NewServer", "NewServerManual"} {
+ t.Run(name, func(t *testing.T) {
+ newServer := newServers[name]
+ t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
+ t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
+ t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
+ t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
+ t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
+ })
+ }
+ for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
+ t.Run(name, func(t *testing.T) {
+ newServer := newServers[name]
+ t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
+ t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
+ })
+ }
+}
+
+func testServer(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}
+
+// Issue 12781
+func testGetAfterClose(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Fatalf("got %q, want hello", string(got))
+ }
+
+ ts.Close()
+
+ res, err = http.Get(ts.URL)
+ if err == nil {
+ body, _ := io.ReadAll(res.Body)
+ t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
+ }
+}
+
+func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ dial := func() net.Conn {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c
+ }
+
+ // Keep one connection in StateNew (connected, but not sending anything)
+ cnew := dial()
+ defer cnew.Close()
+
+ // Keep one connection in StateIdle (idle after a request)
+ cidle := dial()
+ defer cidle.Close()
+ cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ts.Close() // test we don't hang here forever.
+}
+
+// Issue 14290
+func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
+ var s *Server
+ s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ s.CloseClientConnections()
+ }))
+ defer s.Close()
+ res, err := http.Get(s.URL)
+ if err == nil {
+ res.Body.Close()
+ t.Fatalf("Unexpected response: %#v", res)
+ }
+}
+
+// Tests that the Server.Client method works and returns an http.Client that can hit
+// NewTLSServer without cert warnings.
+func testServerClient(t *testing.T, newTLSServer newServerFunc) {
+ ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}
+
+// Tests that the Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
+ ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ if _, ok := client.Transport.(*http.Transport); !ok {
+ t.Errorf("got %T, want *http.Transport", client.Transport)
+ }
+}
+
+// Tests that the TLS Server.Client.Transport interface is implemented
+// by a *http.Transport.
+func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
+ ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ if _, ok := client.Transport.(*http.Transport); !ok {
+ t.Errorf("got %T, want *http.Transport", client.Transport)
+ }
+}
+
+type onlyCloseListener struct {
+ net.Listener
+}
+
+func (onlyCloseListener) Close() error { return nil }
+
+// Issue 19729: panic in Server.Close for values created directly
+// without a constructor (so the unexported client field is nil).
+func TestServerZeroValueClose(t *testing.T) {
+ ts := &Server{
+ Listener: onlyCloseListener{},
+ Config: &http.Server{},
+ }
+
+ ts.Close() // tests that it doesn't panic
+}
+
+// Issue 51799: test hijacking a connection and then closing it
+// concurrently with closing the server.
+func TestCloseHijackedConnection(t *testing.T) {
+ hijacked := make(chan net.Conn)
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer close(hijacked)
+ hj, ok := w.(http.Hijacker)
+ if !ok {
+ t.Fatal("failed to hijack")
+ }
+ c, _, err := hj.Hijack()
+ if err != nil {
+ t.Fatal(err)
+ }
+ hijacked <- c
+ }))
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ req, err := http.NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Log(err)
+ }
+ // Use a client not associated with the Server.
+ var c http.Client
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Log(err)
+ return
+ }
+ resp.Body.Close()
+ }()
+
+ wg.Add(1)
+ conn := <-hijacked
+ go func(conn net.Conn) {
+ defer wg.Done()
+ // Close the connection and then inform the Server that
+ // we closed it.
+ conn.Close()
+ ts.Config.ConnState(conn, http.StateClosed)
+ }(conn)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ ts.Close()
+ }()
+ wg.Wait()
+}
+
+func TestTLSServerWithHTTP2(t *testing.T) {
+ modes := []struct {
+ name string
+ wantProto string
+ }{
+ {"http1", "HTTP/1.1"},
+ {"http2", "HTTP/2.0"},
+ }
+
+ for _, tt := range modes {
+ t.Run(tt.name, func(t *testing.T) {
+ cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Proto", r.Proto)
+ }))
+
+ switch tt.name {
+ case "http2":
+ cst.EnableHTTP2 = true
+ cst.StartTLS()
+ default:
+ cst.Start()
+ }
+
+ defer cst.Close()
+
+ res, err := cst.Client().Get(cst.URL)
+ if err != nil {
+ t.Fatalf("Failed to make request: %v", err)
+ }
+ if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
+ t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
+ }
+ })
+ }
+}