summaryrefslogtreecommitdiffstats
path: root/src/net/http/clientserver_test.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/net/http/clientserver_test.go1756
1 files changed, 1756 insertions, 0 deletions
diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go
new file mode 100644
index 0000000..da5671d
--- /dev/null
+++ b/src/net/http/clientserver_test.go
@@ -0,0 +1,1756 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
+
+package http_test
+
+import (
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/rand"
+ "crypto/sha1"
+ "crypto/tls"
+ "fmt"
+ "hash"
+ "io"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/httputil"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+type testMode string
+
+const (
+ http1Mode = testMode("h1") // HTTP/1.1
+ https1Mode = testMode("https1") // HTTPS/1.1
+ http2Mode = testMode("h2") // HTTP/2
+)
+
+type testNotParallelOpt struct{}
+
+var (
+ testNotParallel = testNotParallelOpt{}
+)
+
+type TBRun[T any] interface {
+ testing.TB
+ Run(string, func(T)) bool
+}
+
+// run runs a client/server test in a variety of test configurations.
+//
+// Tests execute in HTTP/1.1 and HTTP/2 modes by default.
+// To run in a different set of configurations, pass a []testMode option.
+//
+// Tests call t.Parallel() by default.
+// To disable parallel execution, pass the testNotParallel option.
+func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
+ t.Helper()
+ modes := []testMode{http1Mode, http2Mode}
+ parallel := true
+ for _, opt := range opts {
+ switch opt := opt.(type) {
+ case []testMode:
+ modes = opt
+ case testNotParallelOpt:
+ parallel = false
+ default:
+ t.Fatalf("unknown option type %T", opt)
+ }
+ }
+ if t, ok := any(t).(*testing.T); ok && parallel {
+ setParallel(t)
+ }
+ for _, mode := range modes {
+ t.Run(string(mode), func(t T) {
+ t.Helper()
+ if t, ok := any(t).(*testing.T); ok && parallel {
+ setParallel(t)
+ }
+ t.Cleanup(func() {
+ afterTest(t)
+ })
+ f(t, mode)
+ })
+ }
+}
+
+type clientServerTest struct {
+ t testing.TB
+ h2 bool
+ h Handler
+ ts *httptest.Server
+ tr *Transport
+ c *Client
+}
+
+func (t *clientServerTest) close() {
+ t.tr.CloseIdleConnections()
+ t.ts.Close()
+}
+
+func (t *clientServerTest) getURL(u string) string {
+ res, err := t.c.Get(u)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ return string(slurp)
+}
+
+func (t *clientServerTest) scheme() string {
+ if t.h2 {
+ return "https"
+ }
+ return "http"
+}
+
+var optQuietLog = func(ts *httptest.Server) {
+ ts.Config.ErrorLog = quietLog
+}
+
+func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
+ return func(ts *httptest.Server) {
+ ts.Config.ErrorLog = lg
+ }
+}
+
+// newClientServerTest creates and starts an httptest.Server.
+//
+// The mode parameter selects the implementation to test:
+// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use
+// the 'run' function, which will start a subtests for each tested mode.
+//
+// The vararg opts parameter can include functions to configure the
+// test server or transport.
+//
+// func(*httptest.Server) // run before starting the server
+// func(*http.Transport)
+func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
+ if mode == http2Mode {
+ CondSkipHTTP2(t)
+ }
+ cst := &clientServerTest{
+ t: t,
+ h2: mode == http2Mode,
+ h: h,
+ }
+ cst.ts = httptest.NewUnstartedServer(h)
+
+ var transportFuncs []func(*Transport)
+ for _, opt := range opts {
+ switch opt := opt.(type) {
+ case func(*Transport):
+ transportFuncs = append(transportFuncs, opt)
+ case func(*httptest.Server):
+ opt(cst.ts)
+ default:
+ t.Fatalf("unhandled option type %T", opt)
+ }
+ }
+
+ if cst.ts.Config.ErrorLog == nil {
+ cst.ts.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
+ }
+
+ switch mode {
+ case http1Mode:
+ cst.ts.Start()
+ case https1Mode:
+ cst.ts.StartTLS()
+ case http2Mode:
+ ExportHttp2ConfigureServer(cst.ts.Config, nil)
+ cst.ts.TLS = cst.ts.Config.TLSConfig
+ cst.ts.StartTLS()
+ default:
+ t.Fatalf("unknown test mode %v", mode)
+ }
+ cst.c = cst.ts.Client()
+ cst.tr = cst.c.Transport.(*Transport)
+ if mode == http2Mode {
+ if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
+ t.Fatal(err)
+ }
+ }
+ for _, f := range transportFuncs {
+ f(cst.tr)
+ }
+ t.Cleanup(func() {
+ cst.close()
+ })
+ return cst
+}
+
+type testLogWriter struct {
+ t testing.TB
+}
+
+func (w testLogWriter) Write(b []byte) (int, error) {
+ w.t.Logf("server log: %v", strings.TrimSpace(string(b)))
+ return len(b), nil
+}
+
+// Testing the newClientServerTest helper itself.
+func TestNewClientServerTest(t *testing.T) {
+ run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testNewClientServerTest(t *testing.T, mode testMode) {
+ var got struct {
+ sync.Mutex
+ proto string
+ hasTLS bool
+ }
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ got.Lock()
+ defer got.Unlock()
+ got.proto = r.Proto
+ got.hasTLS = r.TLS != nil
+ })
+ cst := newClientServerTest(t, mode, h)
+ if _, err := cst.c.Head(cst.ts.URL); err != nil {
+ t.Fatal(err)
+ }
+ var wantProto string
+ var wantTLS bool
+ switch mode {
+ case http1Mode:
+ wantProto = "HTTP/1.1"
+ wantTLS = false
+ case https1Mode:
+ wantProto = "HTTP/1.1"
+ wantTLS = true
+ case http2Mode:
+ wantProto = "HTTP/2.0"
+ wantTLS = true
+ }
+ if got.proto != wantProto {
+ t.Errorf("req.Proto = %q, want %q", got.proto, wantProto)
+ }
+ if got.hasTLS != wantTLS {
+ t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS)
+ }
+}
+
+func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) }
+func testChunkedResponseHeaders(t *testing.T, mode testMode) {
+ log.SetOutput(io.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
+ fmt.Fprintf(w, "I am a chunked response.")
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ defer res.Body.Close()
+ if g, e := res.ContentLength, int64(-1); g != e {
+ t.Errorf("expected ContentLength of %d; got %d", e, g)
+ }
+ wantTE := []string{"chunked"}
+ if mode == http2Mode {
+ wantTE = nil
+ }
+ if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
+ t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
+ }
+ if got, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length: %q", got)
+ }
+}
+
+type reqFunc func(c *Client, url string) (*Response, error)
+
+// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
+// against each other.
+type h12Compare struct {
+ Handler func(ResponseWriter, *Request) // required
+ ReqFunc reqFunc // optional
+ CheckResponse func(proto string, res *Response) // optional
+ EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize
+ Opts []any
+}
+
+func (tt h12Compare) reqFunc() reqFunc {
+ if tt.ReqFunc == nil {
+ return (*Client).Get
+ }
+ return tt.ReqFunc
+}
+
+func (tt h12Compare) run(t *testing.T) {
+ setParallel(t)
+ cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst1.close()
+ cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst2.close()
+
+ res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/1 request: %v", err)
+ return
+ }
+ res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/2 request: %v", err)
+ return
+ }
+
+ if fn := tt.EarlyCheckResponse; fn != nil {
+ fn("HTTP/1.1", res1)
+ fn("HTTP/2.0", res2)
+ }
+
+ tt.normalizeRes(t, res1, "HTTP/1.1")
+ tt.normalizeRes(t, res2, "HTTP/2.0")
+ res1body, res2body := res1.Body, res2.Body
+
+ eres1 := mostlyCopy(res1)
+ eres2 := mostlyCopy(res2)
+ if !reflect.DeepEqual(eres1, eres2) {
+ t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
+ cst1.ts.URL, eres1, cst2.ts.URL, eres2)
+ }
+ if !reflect.DeepEqual(res1body, res2body) {
+ t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
+ }
+ if fn := tt.CheckResponse; fn != nil {
+ res1.Body, res2.Body = res1body, res2body
+ fn("HTTP/1.1", res1)
+ fn("HTTP/2.0", res2)
+ }
+}
+
+func mostlyCopy(r *Response) *Response {
+ c := *r
+ c.Body = nil
+ c.TransferEncoding = nil
+ c.TLS = nil
+ c.Request = nil
+ return &c
+}
+
+type slurpResult struct {
+ io.ReadCloser
+ body []byte
+ err error
+}
+
+func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
+
+func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
+ if res.Proto == wantProto || res.Proto == "HTTP/IGNORE" {
+ res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
+ } else {
+ t.Errorf("got %q response; want %q", res.Proto, wantProto)
+ }
+ slurp, err := io.ReadAll(res.Body)
+
+ res.Body.Close()
+ res.Body = slurpResult{
+ ReadCloser: io.NopCloser(bytes.NewReader(slurp)),
+ body: slurp,
+ err: err,
+ }
+ for i, v := range res.Header["Date"] {
+ res.Header["Date"][i] = strings.Repeat("x", len(v))
+ }
+ if res.Request == nil {
+ t.Errorf("for %s, no request", wantProto)
+ }
+ if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
+ t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
+ }
+}
+
+// Issue 13532
+func TestH12_HeadContentLengthNoBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthSmallBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small")
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthLargeBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ chunk := strings.Repeat("x", 512<<10)
+ for i := 0; i < 10; i++ {
+ io.WriteString(w, chunk)
+ }
+ },
+ }.run(t)
+}
+
+func TestH12_200NoBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
+}
+
+func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
+func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
+func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
+
+func testH12_noBody(t *testing.T, status int) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.WriteHeader(status)
+ }}.run(t)
+}
+
+func TestH12_SmallBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small body")
+ }}.run(t)
+}
+
+func TestH12_ExplicitContentLength(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushBeforeBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushMidBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "foo")
+ w.(Flusher).Flush()
+ io.WriteString(w, "bar")
+ }}.run(t)
+}
+
+func TestH12_Head_ExplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ w.Header().Set("Content-Length", "1235")
+ },
+ }.run(t)
+}
+
+func TestH12_Head_ImplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ io.WriteString(w, "foo")
+ },
+ }.run(t)
+}
+
+func TestH12_HandlerWritesTooLittle(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "12") // one byte short
+ },
+ CheckResponse: func(proto string, res *Response) {
+ sr, ok := res.Body.(slurpResult)
+ if !ok {
+ t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
+ return
+ }
+ if sr.err != io.ErrUnexpectedEOF {
+ t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
+ }
+ if string(sr.body) != "12" {
+ t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
+ }
+ },
+ }.run(t)
+}
+
+// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
+// writing more than they declared. This test does not test whether
+// the transport deals with too much data, though, since the server
+// doesn't make it possible to send bogus data. For those tests, see
+// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
+// (for HTTP/2).
+func TestH12_HandlerWritesTooMuch(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ w.(Flusher).Flush()
+ io.WriteString(w, "123")
+ w.(Flusher).Flush()
+ n, err := io.WriteString(w, "x") // too many
+ if n > 0 || err == nil {
+ t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
+ }
+ },
+ }.run(t)
+}
+
+// Verify that both our HTTP/1 and HTTP/2 request and auto-decompress gzip.
+// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
+func TestH12_AutoGzip(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
+ t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
+ }
+ w.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(w)
+ io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
+ gz.Close()
+ },
+ }.run(t)
+}
+
+func TestH12_AutoGzip_Disabled(t *testing.T) {
+ h12Compare{
+ Opts: []any{
+ func(tr *Transport) { tr.DisableCompression = true },
+ },
+ Handler: func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
+ if ae := r.Header.Get("Accept-Encoding"); ae != "" {
+ t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
+ }
+ },
+ }.run(t)
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses(t *testing.T) { run(t, test304Responses) }
+func test304Responses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNotModified)
+ _, err := w.Write([]byte("illegal body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestH12_ServerEmptyContentLength(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Type"] = []string{""}
+ io.WriteString(w, "<html><body>hi</body></html>")
+ },
+ }.run(t)
+}
+
+func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
+}
+
+func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return nil }, 0)
+}
+
+func TestH12_RequestContentLength_Unknown(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
+}
+
+func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
+ fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
+ },
+ ReqFunc: func(c *Client, url string) (*Response, error) {
+ return c.Post(url, "text/plain", bodyfn())
+ },
+ CheckResponse: func(proto string, res *Response) {
+ if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
+ t.Errorf("Proto %q got length %q; want %q", proto, got, want)
+ }
+ },
+ }.run(t)
+}
+
+// Tests that closing the Request.Cancel channel also while still
+// reading the response body. Issue 13159.
+func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) }
+func testCancelRequestMidBody(t *testing.T, mode testMode) {
+ unblock := make(chan bool)
+ didFlush := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "Hello")
+ w.(Flusher).Flush()
+ didFlush <- true
+ <-unblock
+ io.WriteString(w, ", world.")
+ }))
+ defer close(unblock)
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ <-didFlush
+
+ // Read a bit before we cancel. (Issue 13626)
+ // We should have "Hello" at least sitting there.
+ firstRead := make([]byte, 10)
+ n, err := res.Body.Read(firstRead)
+ if err != nil {
+ t.Fatal(err)
+ }
+ firstRead = firstRead[:n]
+
+ close(cancel)
+
+ rest, err := io.ReadAll(res.Body)
+ all := string(firstRead) + string(rest)
+ if all != "Hello" {
+ t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
+ }
+ if err != ExportErrRequestCanceled {
+ t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
+ }
+}
+
+// Tests that clients can send trailers to a server and that the server can read them.
+func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) }
+func testTrailersClientToServer(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var decl []string
+ for k := range r.Trailer {
+ decl = append(decl, k)
+ }
+ sort.Strings(decl)
+
+ slurp, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Server reading request body: %v", err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Server read request body %q; want foo", slurp)
+ }
+ if r.Trailer == nil {
+ io.WriteString(w, "nil Trailer")
+ } else {
+ fmt.Fprintf(w, "decl: %v, vals: %s, %s",
+ decl,
+ r.Trailer.Get("Client-Trailer-A"),
+ r.Trailer.Get("Client-Trailer-B"))
+ }
+ }))
+
+ var req *Request
+ req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-A"] = []string{"valuea"}
+ }),
+ strings.NewReader("foo"),
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-B"] = []string{"valueb"}
+ }),
+ ))
+ req.Trailer = Header{
+ "Client-Trailer-A": nil, // to be set later
+ "Client-Trailer-B": nil, // to be set later
+ }
+ req.ContentLength = -1
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
+ t.Error(err)
+ }
+}
+
+// Tests that servers send trailers to a client and that the client can read them.
+func TestTrailersServerToClient(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTrailersServerToClient(t, mode, false)
+ })
+}
+func TestTrailersServerToClientFlush(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTrailersServerToClient(t, mode, true)
+ })
+}
+
+func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) {
+ const body = "Some body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
+ w.Header().Add("Trailer", "Server-Trailer-C")
+
+ io.WriteString(w, body)
+ if flush {
+ w.(Flusher).Flush()
+ }
+
+ // How handlers set Trailers: declare it ahead of time
+ // with the Trailer header, and then mutate the
+ // Header() of those values later, after the response
+ // has been written (we wrote to w above).
+ w.Header().Set("Server-Trailer-A", "valuea")
+ w.Header().Set("Server-Trailer-C", "valuec") // skipping B
+ w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
+ }))
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wantHeader := Header{
+ "Content-Type": {"text/plain; charset=utf-8"},
+ }
+ wantLen := -1
+ if mode == http2Mode && !flush {
+ // In HTTP/1.1, any use of trailers forces HTTP/1.1
+ // chunking and a flush at the first write. That's
+ // unnecessary with HTTP/2's framing, so the server
+ // is able to calculate the length while still sending
+ // trailers afterwards.
+ wantLen = len(body)
+ wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
+ }
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
+ }
+
+ delete(res.Header, "Date") // irrelevant for test
+ if !reflect.DeepEqual(res.Header, wantHeader) {
+ t.Errorf("Header = %v; want %v", res.Header, wantHeader)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": nil,
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": nil,
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer before body read = %v; want %v", got, want)
+ }
+
+ if err := wantBody(res, nil, body); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": {"valuea"},
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": {"valuec"},
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer after body read = %v; want %v", got, want)
+ }
+}
+
+// Don't allow a Body.Read after Body.Close. Issue 13648.
+func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) }
+func testResponseBodyReadAfterClose(t *testing.T, mode testMode) {
+ const body = "Some body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, body)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ data, err := io.ReadAll(res.Body)
+ if len(data) != 0 || err == nil {
+ t.Fatalf("ReadAll returned %q, %v; want error", data, err)
+ }
+}
+
+func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) }
+func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) {
+ const reqBody = "some request body"
+ const resBody = "some response body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var wg sync.WaitGroup
+ wg.Add(2)
+ didRead := make(chan bool, 1)
+ // Read in one goroutine.
+ go func() {
+ defer wg.Done()
+ data, err := io.ReadAll(r.Body)
+ if string(data) != reqBody {
+ t.Errorf("Handler read %q; want %q", data, reqBody)
+ }
+ if err != nil {
+ t.Errorf("Handler Read: %v", err)
+ }
+ didRead <- true
+ }()
+ // Write in another goroutine.
+ go func() {
+ defer wg.Done()
+ if mode != http2Mode {
+ // our HTTP/1 implementation intentionally
+ // doesn't permit writes during read (mostly
+ // due to it being undefined); if that is ever
+ // relaxed, change this.
+ <-didRead
+ }
+ io.WriteString(w, resBody)
+ }()
+ wg.Wait()
+ }))
+ req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
+ req.Header.Add("Expect", "100-continue") // just to complicate things
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ data, err := io.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(data) != resBody {
+ t.Errorf("read %q; want %q", data, resBody)
+ }
+}
+
+func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) }
+func testConnectRequest(t *testing.T, mode testMode) {
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotc <- r
+ }))
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ req *Request
+ want string
+ }{
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ },
+ want: u.Host,
+ },
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ Host: "example.com:123",
+ },
+ want: "example.com:123",
+ },
+ }
+
+ for i, tt := range tests {
+ res, err := cst.c.Do(tt.req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ res.Body.Close()
+ req := <-gotc
+ if req.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", req.Method)
+ }
+ if req.Host != tt.want {
+ t.Errorf("Host = %q; want %q", req.Host, tt.want)
+ }
+ if req.URL.Host != tt.want {
+ t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
+ }
+ }
+}
+
+func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) }
+func testTransportUserAgent(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["User-Agent"])
+ }))
+
+ either := func(a, b string) string {
+ if mode == http2Mode {
+ return b
+ }
+ return a
+ }
+
+ tests := []struct {
+ setup func(*Request)
+ want string
+ }{
+ {
+ func(r *Request) {},
+ either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
+ `["foo/1.2.3"]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
+ `["single"]`,
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "") },
+ `[]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = nil },
+ `[]`,
+ },
+ }
+ for i, tt := range tests {
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ tt.setup(req)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("%d. read body = %v", i, err)
+ continue
+ }
+ if string(slurp) != tt.want {
+ t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
+ }
+ }
+}
+
+func TestStarRequestMethod(t *testing.T) {
+ for _, method := range []string{"FOO", "OPTIONS"} {
+ t.Run(method, func(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testStarRequest(t, method, mode)
+ })
+ })
+ }
+}
+func testStarRequest(t *testing.T, method string, mode testMode) {
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo", "bar")
+ gotc <- r
+ w.(Flusher).Flush()
+ }))
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ u.Path = "*"
+
+ req := &Request{
+ Method: method,
+ Header: Header{},
+ URL: u,
+ }
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("RoundTrip = %v", err)
+ }
+ res.Body.Close()
+
+ wantFoo := "bar"
+ wantLen := int64(-1)
+ if method == "OPTIONS" {
+ wantFoo = ""
+ wantLen = 0
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("status code = %v; want %d", res.Status, 200)
+ }
+ if res.ContentLength != wantLen {
+ t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
+ }
+ if got := res.Header.Get("foo"); got != wantFoo {
+ t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
+ }
+ select {
+ case req = <-gotc:
+ default:
+ req = nil
+ }
+ if req == nil {
+ if method != "OPTIONS" {
+ t.Fatalf("handler never got request")
+ }
+ return
+ }
+ if req.Method != method {
+ t.Errorf("method = %q; want %q", req.Method, method)
+ }
+ if req.URL.Path != "*" {
+ t.Errorf("URL.Path = %q; want *", req.URL.Path)
+ }
+ if req.RequestURI != "*" {
+ t.Errorf("RequestURI = %q; want *", req.RequestURI)
+ }
+}
+
+// Issue 13957
+func TestTransportDiscardsUnneededConns(t *testing.T) {
+ run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode})
+}
+func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ var numOpen, numClose int32 // atomic
+
+ tlsConfig := &tls.Config{InsecureSkipVerify: true}
+ tr := &Transport{
+ TLSClientConfig: tlsConfig,
+ DialTLS: func(_, addr string) (net.Conn, error) {
+ time.Sleep(10 * time.Millisecond)
+ rc, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ atomic.AddInt32(&numOpen, 1)
+ c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
+ return tls.Client(c, tlsConfig), nil
+ },
+ }
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatal(err)
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+
+ const N = 10
+ gotBody := make(chan string, N)
+ var wg sync.WaitGroup
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resp, err := c.Get(cst.ts.URL)
+ if err != nil {
+ // Try to work around spurious connection reset on loaded system.
+ // See golang.org/issue/33585 and golang.org/issue/36797.
+ time.Sleep(10 * time.Millisecond)
+ resp, err = c.Get(cst.ts.URL)
+ if err != nil {
+ t.Errorf("Get: %v", err)
+ return
+ }
+ }
+ defer resp.Body.Close()
+ slurp, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody <- string(slurp)
+ }()
+ }
+ wg.Wait()
+ close(gotBody)
+
+ var last string
+ for got := range gotBody {
+ if last == "" {
+ last = got
+ continue
+ }
+ if got != last {
+ t.Errorf("Response body changed: %q -> %q", last, got)
+ }
+ }
+
+ var open, close int32
+ for i := 0; i < 150; i++ {
+ open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
+ if open < 1 {
+ t.Fatalf("open = %d; want at least", open)
+ }
+ if close == open-1 {
+ // Success
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
+}
+
+// tests that Transport doesn't retain a pointer to the provided request.
+func TestTransportGCRequest(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) })
+ t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) })
+ })
+}
+func testTransportGCRequest(t *testing.T, mode testMode, body bool) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.ReadAll(r.Body)
+ if body {
+ io.WriteString(w, "Hello.")
+ }
+ }))
+
+ didGC := make(chan struct{})
+ (func() {
+ body := strings.NewReader("some body")
+ req, _ := NewRequest("POST", cst.ts.URL, body)
+ runtime.SetFinalizer(req, func(*Request) { close(didGC) })
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.ReadAll(res.Body); err != nil {
+ t.Fatal(err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatal(err)
+ }
+ })()
+ timeout := time.NewTimer(5 * time.Second)
+ defer timeout.Stop()
+ for {
+ select {
+ case <-didGC:
+ return
+ case <-time.After(100 * time.Millisecond):
+ runtime.GC()
+ case <-timeout.C:
+ t.Fatal("never saw GC of request")
+ }
+ }
+}
+
+func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) }
+func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
+ }), optQuietLog)
+ cst.tr.DisableKeepAlives = true
+
+ tests := []struct {
+ key, val string
+ ok bool
+ }{
+ {"Foo", "capital-key", true}, // verify h2 allows capital keys
+ {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed
+ {"Foo", "two\nlines", false}, // \n byte in value not allowed
+ {"bogus\nkey", "v", false}, // \n byte also not allowed in key
+ {"A space", "v", false}, // spaces in keys not allowed
+ {"имя", "v", false}, // key must be ascii
+ {"name", "валю", true}, // value may be non-ascii
+ {"", "v", false}, // key must be non-empty
+ {"k", "", true}, // value may be empty
+ }
+ for _, tt := range tests {
+ dialedc := make(chan bool, 1)
+ cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
+ dialedc <- true
+ return net.Dial(netw, addr)
+ }
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Header[tt.key] = []string{tt.val}
+ res, err := cst.c.Do(req)
+ var body []byte
+ if err == nil {
+ body, _ = io.ReadAll(res.Body)
+ res.Body.Close()
+ }
+ var dialed bool
+ select {
+ case <-dialedc:
+ dialed = true
+ default:
+ }
+
+ if !tt.ok && dialed {
+ t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body)
+ } else if (err == nil) != tt.ok {
+ t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok)
+ }
+ }
+}
+
+func TestInterruptWithPanic(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") })
+ t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) })
+ t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) })
+ })
+}
+func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) {
+ const msg = "hello"
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ var errorLog lockedBytesBuffer
+ gotHeaders := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, msg)
+ w.(Flusher).Flush()
+
+ select {
+ case <-gotHeaders:
+ case <-testDone:
+ }
+ panic(panicValue)
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(&errorLog, "", 0)
+ })
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotHeaders <- true
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if string(slurp) != msg {
+ t.Errorf("client read %q; want %q", slurp, msg)
+ }
+ if err == nil {
+ t.Errorf("client read all successfully; want some error")
+ }
+ logOutput := func() string {
+ errorLog.Lock()
+ defer errorLog.Unlock()
+ return errorLog.String()
+ }
+ wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
+
+ if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
+ gotLog := logOutput()
+ if !wantStackLogged {
+ if gotLog == "" {
+ return nil
+ }
+ return fmt.Errorf("want no log output; got: %s", gotLog)
+ }
+ if gotLog == "" {
+ return fmt.Errorf("wanted a stack trace logged; got nothing")
+ }
+ if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
+ return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
+ }
+ return nil
+ }); err != nil {
+ t.Fatal(err)
+ }
+}
+
+type lockedBytesBuffer struct {
+ sync.Mutex
+ bytes.Buffer
+}
+
+func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
+ b.Lock()
+ defer b.Unlock()
+ return b.Buffer.Write(p)
+}
+
+// Issue 15366
+func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ h := w.Header()
+ h.Set("Content-Encoding", "gzip")
+ h.Set("Content-Length", "23")
+ io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
+ },
+ EarlyCheckResponse: func(proto string, res *Response) {
+ if !res.Uncompressed {
+ t.Errorf("%s: expected Uncompressed to be set", proto)
+ }
+ dump, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ t.Errorf("%s: DumpResponse: %v", proto, err)
+ return
+ }
+ if strings.Contains(string(dump), "Connection: close") {
+ t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump)
+ }
+ if !strings.Contains(string(dump), "FOO") {
+ t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump)
+ }
+ },
+ }.run(t)
+}
+
+// Issue 14607
+func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) }
+func testCloseIdleConnections(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ }))
+ get := func() string {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ v := res.Header.Get("X-Addr")
+ if v == "" {
+ t.Fatal("didn't get X-Addr")
+ }
+ return v
+ }
+ a1 := get()
+ cst.tr.CloseIdleConnections()
+ a2 := get()
+ if a1 == a2 {
+ t.Errorf("didn't close connection")
+ }
+}
+
+type noteCloseConn struct {
+ net.Conn
+ closeFunc func()
+}
+
+func (x noteCloseConn) Close() error {
+ x.closeFunc()
+ return x.Conn.Close()
+}
+
+type testErrorReader struct{ t *testing.T }
+
+func (r testErrorReader) Read(p []byte) (n int, err error) {
+ r.t.Error("unexpected Read call")
+ return 0, io.EOF
+}
+
+func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) }
+func testNoSniffExpectRequestBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusUnauthorized)
+ }))
+
+ // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it.
+ cst.tr.ExpectContinueTimeout = 10 * time.Second
+
+ req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0 // so transport is tempted to sniff it
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusUnauthorized {
+ t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
+ }
+}
+
+func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) }
+func testServerUndeclaredTrailers(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Foo", "Bar")
+ w.Header().Set("Trailer:Foo", "Baz")
+ w.(Flusher).Flush()
+ w.Header().Add("Trailer:Foo", "Baz2")
+ w.Header().Set("Trailer:Bar", "Quux")
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ delete(res.Header, "Date")
+ delete(res.Header, "Content-Type")
+
+ if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
+ t.Errorf("Header = %#v; want %#v", res.Header, want)
+ }
+ if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
+ }
+}
+
+func TestBadResponseAfterReadingBody(t *testing.T) {
+ run(t, testBadResponseAfterReadingBody, []testMode{http1Mode})
+}
+func testBadResponseAfterReadingBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := io.Copy(io.Discard, r.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ fmt.Fprintln(c, "some bogus crap")
+ }))
+
+ closes := 0
+ res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("expected an error to be returned from Post")
+ }
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+}
+
+func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) }
+func testWriteHeader0(t *testing.T, mode testMode) {
+ gotpanic := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(gotpanic)
+ defer func() {
+ if e := recover(); e != nil {
+ got := fmt.Sprintf("%T, %v", e, e)
+ want := "string, invalid WriteHeader code 0"
+ if got != want {
+ t.Errorf("unexpected panic value:\n got: %v\nwant: %v\n", got, want)
+ }
+ gotpanic <- true
+
+ // Set an explicit 503. This also tests that the WriteHeader call panics
+ // before it recorded that an explicit value was set and that bogus
+ // value wasn't stuck.
+ w.WriteHeader(503)
+ }
+ }()
+ w.WriteHeader(0)
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 503 {
+ t.Errorf("Response: %v %q; want 503", res.StatusCode, res.Status)
+ }
+ if !<-gotpanic {
+ t.Error("expected panic in handler")
+ }
+}
+
+// Issue 23010: don't be super strict checking WriteHeader's code if
+// it's not even valid to call WriteHeader then anyway.
+func TestWriteHeaderNoCodeCheck(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testWriteHeaderAfterWrite(t, mode, false)
+ })
+}
+func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) {
+ testWriteHeaderAfterWrite(t, http1Mode, true)
+}
+func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) {
+ var errorLog lockedBytesBuffer
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if hijack {
+ conn, _, _ := w.(Hijacker).Hijack()
+ defer conn.Close()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo"))
+ w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
+ conn.Write([]byte("bar"))
+ return
+ }
+ io.WriteString(w, "foo")
+ w.(Flusher).Flush()
+ w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010
+ io.WriteString(w, "bar")
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(&errorLog, "", 0)
+ })
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(body), "foobar"; got != want {
+ t.Errorf("got = %q; want %q", got, want)
+ }
+
+ // Also check the stderr output:
+ if mode == http2Mode {
+ // TODO: also emit this log message for HTTP/2?
+ // We historically haven't, so don't check.
+ return
+ }
+ gotLog := strings.TrimSpace(errorLog.String())
+ wantLog := "http: superfluous response.WriteHeader call from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
+ if hijack {
+ wantLog = "http: response.WriteHeader on hijacked connection from net/http_test.testWriteHeaderAfterWrite.func1 (clientserver_test.go:"
+ }
+ if !strings.HasPrefix(gotLog, wantLog) {
+ t.Errorf("stderr output = %q; want %q", gotLog, wantLog)
+ }
+}
+
+func TestBidiStreamReverseProxy(t *testing.T) {
+ run(t, testBidiStreamReverseProxy, []testMode{http2Mode})
+}
+func testBidiStreamReverseProxy(t *testing.T, mode testMode) {
+ backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if _, err := io.Copy(w, r.Body); err != nil {
+ log.Printf("bidi backend copy: %v", err)
+ }
+ }))
+
+ backURL, err := url.Parse(backend.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ rp := httputil.NewSingleHostReverseProxy(backURL)
+ rp.Transport = backend.tr
+ proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ rp.ServeHTTP(w, r)
+ }))
+
+ bodyRes := make(chan any, 1) // error or hash.Hash
+ pr, pw := io.Pipe()
+ req, _ := NewRequest("PUT", proxy.ts.URL, pr)
+ const size = 4 << 20
+ go func() {
+ h := sha1.New()
+ _, err := io.CopyN(io.MultiWriter(h, pw), rand.Reader, size)
+ go pw.Close()
+ if err != nil {
+ bodyRes <- err
+ } else {
+ bodyRes <- h
+ }
+ }()
+ res, err := backend.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ hgot := sha1.New()
+ n, err := io.Copy(hgot, res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != size {
+ t.Fatalf("got %d bytes; want %d", n, size)
+ }
+ select {
+ case v := <-bodyRes:
+ switch v := v.(type) {
+ default:
+ t.Fatalf("body copy: %v", err)
+ case hash.Hash:
+ if !bytes.Equal(v.Sum(nil), hgot.Sum(nil)) {
+ t.Errorf("written bytes didn't match received bytes")
+ }
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("timeout")
+ }
+
+}
+
+// Always use HTTP/1.1 for WebSocket upgrades.
+func TestH12_WebSocketUpgrade(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ h := w.Header()
+ h.Set("Foo", "bar")
+ },
+ ReqFunc: func(c *Client, url string) (*Response, error) {
+ req, _ := NewRequest("GET", url, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "WebSocket")
+ return c.Do(req)
+ },
+ EarlyCheckResponse: func(proto string, res *Response) {
+ if res.Proto != "HTTP/1.1" {
+ t.Errorf("%s: expected HTTP/1.1, got %q", proto, res.Proto)
+ }
+ res.Proto = "HTTP/IGNORE" // skip later checks that Proto must be 1.1 vs 2.0
+ },
+ }.run(t)
+}
+
+func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) }
+func testIdentityTransferEncoding(t *testing.T, mode testMode) {
+ const body = "body"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotBody, _ := io.ReadAll(r.Body)
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got request body = %q; want %q", got, want)
+ }
+ w.Header().Set("Transfer-Encoding", "identity")
+ w.WriteHeader(StatusOK)
+ w.(Flusher).Flush()
+ io.WriteString(w, body)
+ }))
+ req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ gotBody, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := string(gotBody), body; got != want {
+ t.Errorf("got response body = %q; want %q", got, want)
+ }
+}
+
+func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) }
+func testEarlyHintsRequest(t *testing.T, mode testMode) {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ h := w.Header()
+
+ h.Add("Content-Length", "123") // must be ignored
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ wg.Wait()
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(StatusEarlyHints)
+
+ w.Write([]byte("Hello"))
+ }))
+
+ checkLinkHeaders := func(t *testing.T, expected, got []string) {
+ t.Helper()
+
+ if len(expected) != len(got) {
+ t.Errorf("got %d expected %d", len(got), len(expected))
+ }
+
+ for i := range expected {
+ if expected[i] != got[i] {
+ t.Errorf("got %q expected %q", got[i], expected[i])
+ }
+ }
+ }
+
+ checkExcludedHeaders := func(t *testing.T, header textproto.MIMEHeader) {
+ t.Helper()
+
+ for _, h := range []string{"Content-Length", "Transfer-Encoding"} {
+ if v, ok := header[h]; ok {
+ t.Errorf("%s is %q; must not be sent", h, v)
+ }
+ }
+ }
+
+ var respCounter uint8
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ switch respCounter {
+ case 0:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
+ checkExcludedHeaders(t, header)
+
+ wg.Done()
+ case 1:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
+ checkExcludedHeaders(t, header)
+
+ default:
+ t.Error("Unexpected 1xx response")
+ }
+
+ respCounter++
+
+ return nil
+ },
+ }
+ req, _ := NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", cst.ts.URL, nil)
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
+ if cl := res.Header.Get("Content-Length"); cl != "123" {
+ t.Errorf("Content-Length is %q; want 123", cl)
+ }
+
+ body, _ := io.ReadAll(res.Body)
+ if string(body) != "Hello" {
+ t.Errorf("Read body %q; want Hello", body)
+ }
+}