diff options
Diffstat (limited to 'src/net/http/transport_test.go')
-rw-r--r-- | src/net/http/transport_test.go | 6675 |
1 files changed, 6675 insertions, 0 deletions
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go new file mode 100644 index 0000000..f4896c5 --- /dev/null +++ b/src/net/http/transport_test.go @@ -0,0 +1,6675 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tests for transport.go. +// +// More tests are in clientserver_test.go (for things testing both client & server for both +// HTTP/1 and HTTP/2). This + +package http_test + +import ( + "bufio" + "bytes" + "compress/gzip" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/binary" + "errors" + "fmt" + "go/token" + "internal/nettrace" + "io" + "log" + mrand "math/rand" + "net" + . "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/http/httputil" + "net/http/internal/testcert" + "net/textproto" + "net/url" + "os" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/iotest" + "time" + + "golang.org/x/net/http/httpguts" +) + +// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close +// and then verify that the final 2 responses get errors back. + +// hostPortHandler writes back the client's "host:port". +var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { + if r.FormValue("close") == "true" { + w.Header().Set("Connection", "close") + } + w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close)) + w.Write([]byte(r.RemoteAddr)) + + // Include the address of the net.Conn in addition to the RemoteAddr, + // in case kernels reuse source ports quickly (see Issue 52450) + if c, ok := ResponseWriterConnForTesting(w); ok { + fmt.Fprintf(w, ", %T %p", c, c) + } +}) + +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + mu sync.Mutex // guards closed and list + closed map[net.Conn]bool + list []net.Conn // in order created +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mu.Lock() + defer tcs.mu.Unlock() + for i := 4; i >= 0; i-- { + for i, c := range tcs.list { + if tcs.closed[c] { + continue + } + if i != 0 { + tcs.mu.Unlock() + time.Sleep(50 * time.Millisecond) + tcs.mu.Lock() + continue + } + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + +func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } +func testReuseRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("{}")) + })).ts + + c := ts.Client() + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } + + res, err = c.Do(req) + if err != nil { + t.Fatal(err) + } + err = res.Body.Close() + if err != nil { + t.Fatal(err) + } +} + +// Two subsequent requests and verify their response is the same. +// The response from the server is our own IP:port +func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } +func testTransportKeepAlives(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + + c := ts.Client() + for _, disableKeepAlive := range []bool{false, true} { + c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive + fetch := func(n int) string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + bodiesDiffer := body1 != body2 + if bodiesDiffer != disableKeepAlive { + t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + disableKeepAlive, bodiesDiffer, body1, body2) + } + } +} + +func TestTransportConnectionCloseOnResponse(t *testing.T) { + run(t, testTransportConnectionCloseOnResponse) +} +func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + + connSet, testDial := makeTestDial(t) + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial + + for _, connectionClose := range []bool{false, true} { + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) + } + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + bodiesDiffer := body1 != body2 + if bodiesDiffer != connectionClose { + t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", + connectionClose, bodiesDiffer, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse +// an underlying TCP connection after making an http.Request with Request.Close set. +// +// It tests the behavior by making an HTTP request to a server which +// describes the source source connection it got (remote port number + +// address of its net.Conn). +func TestTransportConnectionCloseOnRequest(t *testing.T) { + run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + + connSet, testDial := makeTestDial(t) + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial + for _, reqClose := range []bool{false, true} { + fetch := func(n int) string { + req := new(Request) + var err error + req.URL, err = url.Parse(ts.URL) + if err != nil { + t.Fatalf("URL parse error: %v", err) + } + req.Method = "GET" + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = reqClose + + res, err := c.Do(req) + if err != nil { + t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err) + } + if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want { + t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v", + reqClose, got, !reqClose) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err) + } + return string(body) + } + + body1 := fetch(1) + body2 := fetch(2) + + got := 1 + if body1 != body2 { + got++ + } + want := 1 + if reqClose { + want = 2 + } + if got != want { + t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q", + reqClose, got, want, body1, body2) + } + + tr.CloseIdleConnections() + } + + connSet.check(t) +} + +// if the Transport's DisableKeepAlives is set, all requests should +// send Connection: close. +// HTTP/1-only (Connection: close doesn't exist in h2) +func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { + run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + + c := ts.Client() + c.Transport.(*Transport).DisableKeepAlives = true + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.Header.Get("X-Saw-Close") != "true" { + t.Errorf("handler didn't see Connection: close ") + } +} + +// Test that Transport only sends one "Connection: close", regardless of +// how "close" was indicated. +func TestTransportRespectRequestWantsClose(t *testing.T) { + run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) +} +func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { + tests := []struct { + disableKeepAlives bool + close bool + }{ + {disableKeepAlives: false, close: false}, + {disableKeepAlives: false, close: true}, + {disableKeepAlives: true, close: false}, + {disableKeepAlives: true, close: true}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), + func(t *testing.T) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + + c := ts.Client() + c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + count := 0 + trace := &httptrace.ClientTrace{ + WroteHeaderField: func(key string, field []string) { + if key != "Connection" { + return + } + if httpguts.HeaderValuesContainsToken(field, "close") { + count += 1 + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + req.Close = tc.close + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want { + t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count) + } + }) + } + +} + +func TestTransportIdleCacheKeys(t *testing.T) { + run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) +} +func testTransportIdleCacheKeys(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + c := ts.Client() + tr := c.Transport.(*Transport) + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } + + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } + io.ReadAll(resp.Body) + + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g) + } + + if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e { + t.Errorf("Expected idle cache key %q; got %q", e, keys[0]) + } + + tr.CloseIdleConnections() + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) + } +} + +// Tests that the HTTP transport re-uses connections when a client +// reads to the end of a response Body without closing it. +func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } +func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { + const msg = "foobar" + + var addrSeen map[string]int + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + addrSeen[r.RemoteAddr]++ + if r.URL.Path == "/chunked/" { + w.WriteHeader(200) + w.(Flusher).Flush() + } else { + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) + w.WriteHeader(200) + } + w.Write([]byte(msg)) + })).ts + + for pi, path := range []string{"/content-length/", "/chunked/"} { + wantLen := []int{len(msg), -1}[pi] + addrSeen = make(map[string]int) + for i := 0; i < 3; i++ { + res, err := ts.Client().Get(ts.URL + path) + if err != nil { + t.Errorf("Get %s: %v", path, err) + continue + } + // We want to close this body eventually (before the + // defer afterTest at top runs), but not before the + // len(addrSeen) check at the bottom of this test, + // since Closing this early in the loop would risk + // making connections be re-used for the wrong reason. + defer res.Body.Close() + + if res.ContentLength != int64(wantLen) { + t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) + } + got, err := io.ReadAll(res.Body) + if string(got) != msg || err != nil { + t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) + } + } + if len(addrSeen) != 1 { + t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen)) + } + } +} + +func TestTransportMaxPerHostIdleConns(t *testing.T) { + run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) +} +func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { + stop := make(chan struct{}) // stop marks the exit of main Test goroutine + defer close(stop) + + resch := make(chan string) + gotReq := make(chan bool) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + gotReq <- true + var msg string + select { + case <-stop: + return + case msg = <-resch: + } + _, err := w.Write([]byte(msg)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + })).ts + + c := ts.Client() + tr := c.Transport.(*Transport) + maxIdleConnsPerHost := 2 + tr.MaxIdleConnsPerHost = maxIdleConnsPerHost + + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we write to resch, though. + donech := make(chan bool) + doReq := func() { + defer func() { + select { + case <-stop: + return + case donech <- t.Failed(): + } + }() + resp, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if _, err := io.ReadAll(resp.Body); err != nil { + t.Errorf("ReadAll: %v", err) + return + } + } + go doReq() + <-gotReq + go doReq() + <-gotReq + go doReq() + <-gotReq + + if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { + t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) + } + + resch <- "res1" + <-donech + keys := tr.IdleConnKeysForTesting() + if e, g := 1, len(keys); e != g { + t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) + } + addr := ts.Listener.Addr().String() + cacheKey := "|http|" + addr + if keys[0] != cacheKey { + t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) + } + if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { + t.Errorf("after first response, expected %d idle conns; got %d", e, g) + } + + resch <- "res2" + <-donech + if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { + t.Errorf("after second response, idle conns = %d; want %d", g, w) + } + + resch <- "res3" + <-donech + if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { + t.Errorf("after third response, idle conns = %d; want %d", g, w) + } +} + +func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { + run(t, testTransportMaxConnsPerHostIncludeDialInProgress) +} +func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + })).ts + c := ts.Client() + tr := c.Transport.(*Transport) + dialStarted := make(chan struct{}) + stallDial := make(chan struct{}) + tr.Dial = func(network, addr string) (net.Conn, error) { + dialStarted <- struct{}{} + <-stallDial + return net.Dial(network, addr) + } + + tr.DisableKeepAlives = true + tr.MaxConnsPerHost = 1 + + preDial := make(chan struct{}) + reqComplete := make(chan struct{}) + doReq := func(reqId string) { + req, _ := NewRequest("GET", ts.URL, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + preDial <- struct{}{} + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + resp, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + reqComplete <- struct{}{} + } + // get req1 to dial-in-progress + go doReq("req1") + <-preDial + <-dialStarted + + // get req2 to waiting on conns per host to go down below max + go doReq("req2") + <-preDial + select { + case <-dialStarted: + t.Error("req2 dial started while req1 dial in progress") + return + default: + } + + // let req1 complete + stallDial <- struct{}{} + <-reqComplete + + // let req2 complete + <-dialStarted + stallDial <- struct{}{} + <-reqComplete +} + +func TestTransportMaxConnsPerHost(t *testing.T) { + run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { + CondSkipHTTP2(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + ts := newClientServerTest(t, mode, h).ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + mu := sync.Mutex{} + var conns []net.Conn + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + mu.Lock() + defer mu.Unlock() + conns = append(conns, c) + return c, err + } + + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, + } + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } + } + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) + } + if gotConnCnt != expected { + t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } + + if t.Failed() { + t.FailNow() + } + + mu.Lock() + for _, c := range conns { + c.Close() + } + conns = nil + mu.Unlock() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("round 2: too many dials: %d", dialCnt) + } + if gotConnCnt != expected { + t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } +} + +func TestTransportRemovesDeadIdleConnections(t *testing.T) { + run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) +} +func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, r.RemoteAddr) + })).ts + + c := ts.Client() + tr := c.Transport.(*Transport) + + doReq := func(name string) string { + // Do a POST instead of a GET to prevent the Transport's + // idempotent request retry logic from kicking in... + res, err := c.Post(ts.URL, "", nil) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + if res.StatusCode != 200 { + t.Fatalf("%s: %v", name, res.Status) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + return string(slurp) + } + + first := doReq("first") + keys1 := tr.IdleConnKeysForTesting() + + ts.CloseClientConnections() + + var keys2 []string + if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { + keys2 = tr.IdleConnKeysForTesting() + return len(keys2) == 0 + }) { + t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) + } + + second := doReq("second") + if first == second { + t.Errorf("expected a different connection between requests. got %q both times", first) + } +} + +// Test that the Transport notices when a server hangs up on its +// unexpectedly (a keep-alive connection is closed). +func TestTransportServerClosingUnexpectedly(t *testing.T) { + run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) +} +func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts + c := ts.Client() + + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...any) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(time.Second / time.Duration(retries)) + } + for retries >= 0 { + retries-- + res, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := io.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) + } + panic("unreachable") + } + + body1 := fetch(1, 0) + body2 := fetch(2, 0) + + // Close all the idle connections in a way that's similar to + // the server hanging up on us. We don't use + // httptest.Server.CloseClientConnections because it's + // best-effort and stops blocking after 5 seconds. On a loaded + // machine running many tests concurrently it's possible for + // that method to be async and cause the body3 fetch below to + // run on an old connection. This function is synchronous. + ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) + + body3 := fetch(3, 5) + + if body1 != body2 { + t.Errorf("expected body1 and body2 to be equal") + } + if body2 == body3 { + t.Errorf("expected body2 and body3 to be different") + } +} + +// Test for https://golang.org/issue/2616 (appropriate issue number) +// This fails pretty reliably with GOMAXPROCS=100 or something high. +func TestStressSurpriseServerCloses(t *testing.T) { + run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) +} +func testStressSurpriseServerCloses(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping test in short mode") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "5") + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Hello")) + w.(Flusher).Flush() + conn, buf, _ := w.(Hijacker).Hijack() + buf.Flush() + conn.Close() + })).ts + c := ts.Client() + + // Do a bunch of traffic from different goroutines. Send to activityc + // after each request completes, regardless of whether it failed. + // If these are too high, OS X exhausts its ephemeral ports + // and hangs waiting for them to transition TCP states. That's + // not what we want to test. TODO(bradfitz): use an io.Pipe + // dialer for this test instead? + const ( + numClients = 20 + reqsPerClient = 25 + ) + activityc := make(chan bool) + for i := 0; i < numClients; i++ { + go func() { + for i := 0; i < reqsPerClient; i++ { + res, err := c.Get(ts.URL) + if err == nil { + // We expect errors since the server is + // hanging up on us after telling us to + // send more requests, so we don't + // actually care what the error is. + // But we want to close the body in cases + // where we won the race. + res.Body.Close() + } + if !<-activityc { // Receives false when close(activityc) is executed + return + } + } + }() + } + + // Make sure all the request come back, one way or another. + for i := 0; i < numClients*reqsPerClient; i++ { + select { + case activityc <- true: + case <-time.After(5 * time.Second): + close(activityc) + t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") + } + } +} + +// TestTransportHeadResponses verifies that we deal with Content-Lengths +// with no bodies properly +func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } +func testTransportHeadResponses(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Content-Length", "123") + w.WriteHeader(200) + })).ts + c := ts.Client() + + for i := 0; i < 2; i++ { + res, err := c.Head(ts.URL) + if err != nil { + t.Errorf("error on loop %d: %v", i, err) + continue + } + if e, g := "123", res.Header.Get("Content-Length"); e != g { + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) + } + if e, g := int64(123), res.ContentLength; e != g { + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) + } + if all, err := io.ReadAll(res.Body); err != nil { + t.Errorf("loop %d: Body ReadAll: %v", i, err) + } else if len(all) != 0 { + t.Errorf("Bogus body %q", all) + } + } +} + +// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding +// on responses to HEAD requests. +func TestTransportHeadChunkedResponse(t *testing.T) { + run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) +} +func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "HEAD" { + panic("expected HEAD; got " + r.Method) + } + w.Header().Set("Transfer-Encoding", "chunked") // client should ignore + w.Header().Set("x-client-ipport", r.RemoteAddr) + w.WriteHeader(200) + })).ts + c := ts.Client() + + // Ensure that we wait for the readLoop to complete before + // calling Head again + didRead := make(chan bool) + SetReadLoopBeforeNextReadHook(func() { didRead <- true }) + defer SetReadLoopBeforeNextReadHook(nil) + + res1, err := c.Head(ts.URL) + <-didRead + + if err != nil { + t.Fatalf("request 1 error: %v", err) + } + + res2, err := c.Head(ts.URL) + <-didRead + + if err != nil { + t.Fatalf("request 2 error: %v", err) + } + if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 { + t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2) + } +} + +var roundTripTests = []struct { + accept string + expectAccept string + compressed bool +}{ + // Requests with no accept-encoding header use transparent compression + {"", "gzip", false}, + // Requests with other accept-encoding should pass through unmodified + {"foo", "foo", false}, + // Requests with accept-encoding == gzip should be passed through + {"gzip", "gzip", true}, +} + +// Test that the modification made to the Request by the RoundTripper is cleaned up +func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } +func testRoundTripGzip(t *testing.T, mode testMode) { + const responseBody = "test response body" + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + accept := req.Header.Get("Accept-Encoding") + if expect := req.FormValue("expect_accept"); accept != expect { + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) + } + if accept == "gzip" { + rw.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(rw) + gz.Write([]byte(responseBody)) + gz.Close() + } else { + rw.Header().Set("Content-Encoding", accept) + rw.Write([]byte(responseBody)) + } + })).ts + tr := ts.Client().Transport.(*Transport) + + for i, test := range roundTripTests { + // Test basic request (no accept-encoding) + req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("%d. RoundTrip: %v", i, err) + continue + } + var body []byte + if test.compressed { + var r *gzip.Reader + r, err = gzip.NewReader(res.Body) + if err != nil { + t.Errorf("%d. gzip NewReader: %v", i, err) + continue + } + body, err = io.ReadAll(r) + res.Body.Close() + } else { + body, err = io.ReadAll(res.Body) + } + if err != nil { + t.Errorf("%d. Error: %q", i, err) + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) + } + } + +} + +func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } +func testTransportGzip(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56020") + } + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + if req.Method == "HEAD" { + if g := req.Header.Get("Accept-Encoding"); g != "" { + t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) + } + return + } + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.CopyN(gz, rand.Reader, nRandBytes) + } + gz.Close() + })).ts + c := ts.Client() + + for _, chunked := range []string{"1", "0"} { + // First fetch something large, but only read some of it. + res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, err = c.Get(ts.URL + "/?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } + + // And a HEAD request too, because they're always weird. + res, err := c.Head(ts.URL) + if err != nil { + t.Fatalf("Head: %v", err) + } + if res.StatusCode != 200 { + t.Errorf("Head status=%d; want=200", res.StatusCode) + } +} + +// If a request has Expect:100-continue header, the request blocks sending body until the first response. +// Premature consumption of the request body should not be occurred. +func TestTransportExpect100Continue(t *testing.T) { + run(t, testTransportExpect100Continue, []testMode{http1Mode}) +} +func testTransportExpect100Continue(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + switch req.URL.Path { + case "/100": + // This endpoint implicitly responds 100 Continue and reads body. + if _, err := io.Copy(io.Discard, req.Body); err != nil { + t.Error("Failed to read Body", err) + } + rw.WriteHeader(StatusOK) + case "/200": + // Go 1.5 adds Connection: close header if the client expect + // continue but not entire request body is consumed. + rw.WriteHeader(StatusOK) + case "/500": + rw.WriteHeader(StatusInternalServerError) + case "/keepalive": + // This hijacked endpoint responds error without Connection:close. + _, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n") + bufrw.WriteString("Content-Length: 0\r\n\r\n") + bufrw.Flush() + case "/timeout": + // This endpoint tries to read body without 100 (Continue) response. + // After ExpectContinueTimeout, the reading will be started. + conn, bufrw, err := rw.(Hijacker).Hijack() + if err != nil { + log.Fatal(err) + } + if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { + t.Error("Failed to read Body", err) + } + bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") + bufrw.Flush() + conn.Close() + } + + })).ts + + tests := []struct { + path string + body []byte + sent int + status int + }{ + {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. + {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. + {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. + {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. + {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. + } + + c := ts.Client() + for i, v := range tests { + tr := &Transport{ + ExpectContinueTimeout: 2 * time.Second, + } + defer tr.CloseIdleConnections() + c.Transport = tr + body := bytes.NewReader(v.body) + req, err := NewRequest("PUT", ts.URL+v.path, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Expect", "100-continue") + req.ContentLength = int64(len(v.body)) + + resp, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + sent := len(v.body) - body.Len() + if v.status != resp.StatusCode { + t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path) + } + if v.sent != sent { + t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path) + } + } +} + +func TestSOCKS5Proxy(t *testing.T) { + run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testSOCKS5Proxy(t *testing.T, mode testMode) { + ch := make(chan string, 1) + l := newLocalListener(t) + defer l.Close() + defer close(ch) + proxy := func(t *testing.T) { + s, err := l.Accept() + if err != nil { + t.Errorf("socks5 proxy Accept(): %v", err) + return + } + defer s.Close() + var buf [22]byte + if _, err := io.ReadFull(s, buf[:3]); err != nil { + t.Errorf("socks5 proxy initial read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) + return + } + if _, err := s.Write([]byte{5, 0}); err != nil { + t.Errorf("socks5 proxy initial write: %v", err) + return + } + if _, err := io.ReadFull(s, buf[:4]); err != nil { + t.Errorf("socks5 proxy second read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) + return + } + var ipLen int + switch buf[3] { + case 1: + ipLen = net.IPv4len + case 4: + ipLen = net.IPv6len + default: + t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4]) + return + } + if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { + t.Errorf("socks5 proxy address read: %v", err) + return + } + ip := net.IP(buf[4 : ipLen+4]) + port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) + copy(buf[:3], []byte{5, 0, 0}) + if _, err := s.Write(buf[:ipLen+6]); err != nil { + t.Errorf("socks5 proxy connect write: %v", err) + return + } + ch <- fmt.Sprintf("proxy for %s:%d", ip, port) + + // Implement proxying. + targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port))) + targetConn, err := net.Dial("tcp", targetHost) + if err != nil { + t.Errorf("net.Dial failed") + return + } + go io.Copy(targetConn, s) + io.Copy(s, targetConn) // Wait for the client to close the socket. + targetConn.Close() + } + + pu, err := url.Parse("socks5://" + l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + sentinelHeader := "X-Sentinel" + sentinelValue := "12345" + h := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set(sentinelHeader, sentinelValue) + }) + for _, useTLS := range []bool{false, true} { + t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { + ts := newClientServerTest(t, mode, h).ts + go proxy(t) + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) + r, err := c.Head(ts.URL) + if err != nil { + t.Fatal(err) + } + if r.Header.Get(sentinelHeader) != sentinelValue { + t.Errorf("Failed to retrieve sentinel value") + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to socks5 proxy") + } + ts.Close() + tsu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "proxy for " + tsu.Host + if got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + } +} + +func TestTransportProxy(t *testing.T) { + defer afterTest(t) + testCases := []struct{ siteMode, proxyMode testMode }{ + {http1Mode, http1Mode}, + {http1Mode, https1Mode}, + {https1Mode, http1Mode}, + {https1Mode, https1Mode}, + } + for _, testCase := range testCases { + siteMode := testCase.siteMode + proxyMode := testCase.proxyMode + t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { + siteCh := make(chan *Request, 1) + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + siteCh <- r + }) + proxyCh := make(chan *Request, 1) + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + proxyCh <- r + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + ts := newClientServerTest(t, siteMode, h1).ts + proxy := newClientServerTest(t, proxyMode, h2).ts + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + // If neither server is HTTPS or both are, then c may be derived from either. + // If only one server is HTTPS, c must be derived from that server in order + // to ensure that it is configured to use the fake root CA from testcert.go. + c := proxy.Client() + if siteMode == https1Mode { + c = ts.Client() + } + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got *Request + select { + case got = <-proxyCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to http proxy") + } + c.Transport.(*Transport).CloseIdleConnections() + ts.Close() + proxy.Close() + if siteMode == https1Mode { + // First message should be a CONNECT, asking for a socket to the real server, + if got.Method != "CONNECT" { + t.Errorf("Wrong method for secure proxying: %q", got.Method) + } + gotHost := got.URL.Host + pu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal("Invalid site URL") + } + if wantHost := pu.Host; gotHost != wantHost { + t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost) + } + + // The next message on the channel should be from the site's server. + next := <-siteCh + if next.Method != "HEAD" { + t.Errorf("Wrong method at destination: %s", next.Method) + } + if nextURL := next.URL.String(); nextURL != "/" { + t.Errorf("Wrong URL at destination: %s", nextURL) + } + } else { + if got.Method != "HEAD" { + t.Errorf("Wrong method for destination: %q", got.Method) + } + gotURL := got.URL.String() + wantURL := ts.URL + "/" + if gotURL != wantURL { + t.Errorf("Got URL %q, want %q", gotURL, wantURL) + } + } + }) + } +} + +func TestOnProxyConnectResponse(t *testing.T) { + + var tcases = []struct { + proxyStatusCode int + err error + }{ + { + StatusOK, + nil, + }, + { + StatusForbidden, + errors.New("403"), + }, + } + for _, tcase := range tcases { + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + + }) + + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + if tcase.proxyStatusCode != StatusOK { + w.WriteHeader(tcase.proxyStatusCode) + return + } + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + ts := newClientServerTest(t, https1Mode, h1).ts + proxy := newClientServerTest(t, https1Mode, h2).ts + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + c := proxy.Client() + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + if proxyURL.String() != pu.String() { + t.Errorf("proxy url got %s, want %s", proxyURL, pu) + } + + if "https://"+connectReq.URL.String() != ts.URL { + t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL) + } + return tcase.err + } + if _, err := c.Head(ts.URL); err != nil { + if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { + t.Errorf("got %v, want %v", err, tcase.err) + } + } + } +} + +// Issue 28012: verify that the Transport closes its TCP connection to http proxies +// when they're slow to reply to HTTPS CONNECT responses. +func TestTransportProxyHTTPSConnectLeak(t *testing.T) { + setParallel(t) + defer afterTest(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln := newLocalListener(t) + defer ln.Close() + listenerDone := make(chan struct{}) + go func() { + defer close(listenerDone) + c, err := ln.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + return + } + defer c.Close() + // Read the CONNECT request + br := bufio.NewReader(c) + cr, err := ReadRequest(br) + if err != nil { + t.Errorf("proxy server failed to read CONNECT request") + return + } + if cr.Method != "CONNECT" { + t.Errorf("unexpected method %q", cr.Method) + return + } + + // Now hang and never write a response; instead, cancel the request and wait + // for the client to close. + // (Prior to Issue 28012 being fixed, we never closed.) + cancel() + var buf [1]byte + _, err = br.Read(buf[:]) + if err != io.EOF { + t.Errorf("proxy server Read err = %v; want EOF", err) + } + return + }() + + c := &Client{ + Transport: &Transport{ + Proxy: func(*Request) (*url.URL, error) { + return url.Parse("http://" + ln.Addr().String()) + }, + }, + } + req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Errorf("unexpected Get success") + } + + // Wait unconditionally for the listener goroutine to exit: this should never + // hang, so if it does we want a full goroutine dump — and that's exactly what + // the testing package will give us when the test run times out. + <-listenerDone +} + +// Issue 16997: test transport dial preserves typed errors +func TestTransportDialPreservesNetOpProxyError(t *testing.T) { + defer afterTest(t) + + var errDial = errors.New("some dial error") + + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { + return url.Parse("http://proxy.fake.tld/") + }, + Dial: func(string, string) (net.Conn, error) { + return nil, errDial + }, + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + req, _ := NewRequest("GET", "http://fake.tld", nil) + res, err := c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("wanted a non-nil error") + } + + uerr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %T, want *url.Error", err) + } + oe, ok := uerr.Err.(*net.OpError) + if !ok { + t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) + } + want := &net.OpError{ + Op: "proxyconnect", + Net: "tcp", + Err: errDial, // original error, unwrapped. + } + if !reflect.DeepEqual(oe, want) { + t.Errorf("Got error %#v; want %#v", oe, want) + } +} + +// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader. +// +// (A bug caused dialConn to instead write the per-request Proxy-Authorization +// header through to the shared Header instance, introducing a data race.) +func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { + run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) +} +func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { + proxy := newClientServerTest(t, mode, NotFoundHandler()).ts + defer proxy.Close() + c := proxy.Client() + + tr := c.Transport.(*Transport) + tr.Proxy = func(*Request) (*url.URL, error) { + u, _ := url.Parse(proxy.URL) + u.User = url.UserPassword("aladdin", "opensesame") + return u, nil + } + h := tr.ProxyConnectHeader + if h == nil { + h = make(Header) + } + tr.ProxyConnectHeader = h.Clone() + + req, err := NewRequest("GET", "https://golang.fake.tld/", nil) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Errorf("unexpected Get success") + } + + if !reflect.DeepEqual(tr.ProxyConnectHeader, h) { + t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h) + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } +func testTransportGzipRecursive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })).ts + + c := ts.Client() + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// golang.org/issue/7750: request fails when server replies with +// a short gzip body +func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } +func testTransportGzipShort(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write([]byte{0x1f, 0x8b}) + })).ts + + c := ts.Client() + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + _, err = io.ReadAll(res.Body) + if err == nil { + t.Fatal("Expect an error from reading a body.") + } + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err) + } +} + +// Wait until number of goroutines is no greater than nmax, or time out. +func waitNumGoroutine(nmax int) int { + nfinal := runtime.NumGoroutine() + for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { + time.Sleep(50 * time.Millisecond) + runtime.GC() + nfinal = runtime.NumGoroutine() + } + return nfinal +} + +// tests that persistent goroutine connections shut down when no longer desired. +func TestTransportPersistConnLeak(t *testing.T) { + run(t, testTransportPersistConnLeak, testNotParallel) +} +func testTransportPersistConnLeak(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("flaky in HTTP/2") + } + // Not parallel: counts goroutines + + const numReq = 25 + gotReqCh := make(chan bool, numReq) + unblockCh := make(chan bool, numReq) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + gotReqCh <- true + <-unblockCh + w.Header().Set("Content-Length", "0") + w.WriteHeader(204) + })).ts + c := ts.Client() + tr := c.Transport.(*Transport) + + n0 := runtime.NumGoroutine() + + didReqCh := make(chan bool, numReq) + failed := make(chan bool, numReq) + for i := 0; i < numReq; i++ { + go func() { + res, err := c.Get(ts.URL) + didReqCh <- true + if err != nil { + t.Logf("client fetch error: %v", err) + failed <- true + return + } + res.Body.Close() + }() + } + + // Wait for all goroutines to be stuck in the Handler. + for i := 0; i < numReq; i++ { + select { + case <-gotReqCh: + // ok + case <-failed: + // Not great but not what we are testing: + // sometimes an overloaded system will fail to make all the connections. + } + } + + nhigh := runtime.NumGoroutine() + + // Tell all handlers to unblock and reply. + close(unblockCh) + + // Wait for all HTTP clients to be done. + for i := 0; i < numReq; i++ { + <-didReqCh + } + + tr.CloseIdleConnections() + nfinal := waitNumGoroutine(n0 + 5) + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + if int(growth) > 5 { + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + t.Error("too many new goroutines") + } +} + +// golang.org/issue/4531: Transport leaks goroutines when +// request.ContentLength is explicitly short +func TestTransportPersistConnLeakShortBody(t *testing.T) { + run(t, testTransportPersistConnLeakShortBody, testNotParallel) +} +func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("flaky in HTTP/2") + } + + // Not parallel: measures goroutines. + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts + c := ts.Client() + tr := c.Transport.(*Transport) + + n0 := runtime.NumGoroutine() + body := []byte("Hello") + for i := 0; i < 20; i++ { + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + req.ContentLength = int64(len(body) - 2) // explicitly short + _, err = c.Do(req) + if err == nil { + t.Fatal("Expect an error from writing too long of a body.") + } + } + nhigh := runtime.NumGoroutine() + tr.CloseIdleConnections() + nfinal := waitNumGoroutine(n0 + 5) + + growth := nfinal - n0 + + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // Previously we were leaking one per numReq. + t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) + if int(growth) > 5 { + t.Error("too many new goroutines") + } +} + +// A countedConn is a net.Conn that decrements an atomic counter when finalized. +type countedConn struct { + net.Conn +} + +// A countingDialer dials connections and counts the number that remain reachable. +type countingDialer struct { + dialer net.Dialer + mu sync.Mutex + total, live int64 +} + +func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + + counted := new(countedConn) + counted.Conn = conn + + d.mu.Lock() + defer d.mu.Unlock() + d.total++ + d.live++ + + runtime.SetFinalizer(counted, d.decrement) + return counted, nil +} + +func (d *countingDialer) decrement(*countedConn) { + d.mu.Lock() + defer d.mu.Unlock() + d.live-- +} + +func (d *countingDialer) Read() (total, live int64) { + d.mu.Lock() + defer d.mu.Unlock() + return d.total, d.live +} + +func TestTransportPersistConnLeakNeverIdle(t *testing.T) { + run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) +} +func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // Close every connection so that it cannot be kept alive. + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack failed unexpectedly: %v", err) + return + } + conn.Close() + })).ts + + var d countingDialer + c := ts.Client() + c.Transport.(*Transport).DialContext = d.DialContext + + body := []byte("Hello") + for i := 0; ; i++ { + total, live := d.Read() + if live < total { + break + } + if i >= 1<<12 { + t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i) + } + + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + _, err = c.Do(req) + if err == nil { + t.Fatal("expected broken connection") + } + + runtime.GC() + } +} + +type countedContext struct { + context.Context +} + +type contextCounter struct { + mu sync.Mutex + live int64 +} + +func (cc *contextCounter) Track(ctx context.Context) context.Context { + counted := new(countedContext) + counted.Context = ctx + cc.mu.Lock() + defer cc.mu.Unlock() + cc.live++ + runtime.SetFinalizer(counted, cc.decrement) + return counted +} + +func (cc *contextCounter) decrement(*countedContext) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.live-- +} + +func (cc *contextCounter) Read() (live int64) { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.live +} + +func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { + run(t, testTransportPersistConnContextLeakMaxConnsPerHost) +} +func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56021") + } + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + runtime.Gosched() + w.WriteHeader(StatusOK) + })).ts + + c := ts.Client() + c.Transport.(*Transport).MaxConnsPerHost = 1 + + ctx := context.Background() + body := []byte("Hello") + doPosts := func(cc *contextCounter) { + var wg sync.WaitGroup + for n := 64; n > 0; n-- { + wg.Add(1) + go func() { + defer wg.Done() + + ctx := cc.Track(ctx) + req, err := NewRequest("POST", ts.URL, bytes.NewReader(body)) + if err != nil { + t.Error(err) + } + + _, err = c.Do(req.WithContext(ctx)) + if err != nil { + t.Errorf("Do failed with error: %v", err) + } + }() + } + wg.Wait() + } + + var initialCC contextCounter + doPosts(&initialCC) + + // flushCC exists only to put pressure on the GC to finalize the initialCC + // contexts: the flushCC allocations should eventually displace the initialCC + // allocations. + var flushCC contextCounter + for i := 0; ; i++ { + live := initialCC.Read() + if live == 0 { + break + } + if i >= 100 { + t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i) + } + doPosts(&flushCC) + runtime.GC() + } +} + +// This used to crash; https://golang.org/issue/3266 +func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } +func testTransportIdleConnCrash(t *testing.T, mode testMode) { + var tr *Transport + + unblockCh := make(chan bool, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockCh + tr.CloseIdleConnections() + })).ts + c := ts.Client() + tr = c.Transport.(*Transport) + + didreq := make(chan bool) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + t.Error(err) + } else { + res.Body.Close() // returns idle conn + } + didreq <- true + }() + unblockCh <- true + <-didreq +} + +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { run(t, testIssue3644) } +func testIssue3644(t *testing.T, mode testMode) { + const numFoos = 5000 + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })).ts + c := ts.Client() + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + +// Test that a client receives a server's reply, even if the server doesn't read +// the entire request body. +func TestIssue3595(t *testing.T) { run(t, testIssue3595) } +func testIssue3595(t *testing.T, mode testMode) { + const deniedMsg = "sorry, denied." + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + Error(w, deniedMsg, StatusUnauthorized) + })).ts + c := ts.Client() + res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) + if err != nil { + t.Errorf("Post: %v", err) + return + } + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body ReadAll: %v", err) + } + if !strings.Contains(string(got), deniedMsg) { + t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg) + } +} + +// From https://golang.org/issue/4454 , +// "client fails to handle requests with no body and chunked encoding" +func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) } +func testChunkedNoContent(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })).ts + + c := ts.Client() + for _, closeBody := range []bool{true, false} { + const n = 4 + for i := 1; i <= n; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err) + } else { + if closeBody { + res.Body.Close() + } + } + } + } +} + +func TestTransportConcurrency(t *testing.T) { + run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode}) +} +func testTransportConcurrency(t *testing.T, mode testMode) { + // Not parallel: uses global test hooks. + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.FormValue("echo")) + })).ts + + var wg sync.WaitGroup + wg.Add(numReqs) + + // Due to the Transport's "socket late binding" (see + // idleConnCh in transport.go), the numReqs HTTP requests + // below can finish with a dial still outstanding. To keep + // the leak checker happy, keep track of pending dials and + // wait for them to finish (and be closed or returned to the + // idle pool) before we close idle connections. + SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) + defer SetPendingDialHooks(nil, nil) + + c := ts.Client() + reqs := make(chan string) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for req := range reqs { + res, err := c.Get(ts.URL + "/?echo=" + req) + if err != nil { + if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") { + // https://go.dev/issue/52168: this test was observed to fail with + // ECONNRESET errors in Dial on various netbsd builders. + t.Logf("error on req %s: %v", req, err) + t.Logf("(see https://go.dev/issue/52168)") + } else { + t.Errorf("error on req %s: %v", req, err) + } + wg.Done() + continue + } + all, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("read error on req %s: %v", req, err) + } else if string(all) != req { + t.Errorf("body of req %s = %q; want %q", req, all, req) + } + res.Body.Close() + wg.Done() + } + }() + } + for i := 0; i < numReqs; i++ { + reqs <- fmt.Sprintf("request-%d", i) + } + wg.Wait() +} + +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) } +func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) { + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + ts := newClientServerTest(t, mode, mux).ts + + connc := make(chan net.Conn, 1) + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + select { + case connc <- conn: + default: + } + return conn, nil + } + + res, err := c.Get(ts.URL + "/get") + if err != nil { + t.Fatalf("Error issuing GET: %v", err) + } + defer res.Body.Close() + + conn := <-connc + conn.SetDeadline(time.Now().Add(1 * time.Millisecond)) + _, err = io.Copy(io.Discard, res.Body) + if err == nil { + t.Errorf("Unexpected successful copy") + } +} + +func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode}) +} +func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) { + const debug = false + mux := NewServeMux() + mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { + io.Copy(w, neverEnding('a')) + }) + mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { + defer r.Body.Close() + io.Copy(io.Discard, r.Body) + }) + ts := newClientServerTest(t, mode, mux).ts + timeout := 100 * time.Millisecond + + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil + } + + getFailed := false + nRuns := 5 + if testing.Short() { + nRuns = 1 + } + for i := 0; i < nRuns; i++ { + if debug { + println("run", i+1, "of", nRuns) + } + sres, err := c.Get(ts.URL + "/get") + if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } + t.Errorf("Error issuing GET: %v", err) + break + } + req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) + _, err = c.Do(req) + if err == nil { + sres.Body.Close() + t.Errorf("Unexpected successful PUT") + break + } + sres.Body.Close() + } + if debug { + println("tests complete; waiting for handlers to finish") + } + ts.Close() +} + +func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) } +func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + inHandler := make(chan bool, 1) + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) { + inHandler <- true + }) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + inHandler <- true + time.Sleep(2 * time.Second) + }) + ts := newClientServerTest(t, mode, mux).ts + + c := ts.Client() + c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + req, _ := NewRequest("GET", ts.URL+tt.path, nil) + req = req.WithT(t) + res, err := c.Do(req) + select { + case <-inHandler: + case <-time.After(5 * time.Second): + t.Errorf("never entered handler for test index %d, %s", i, tt.path) + continue + } + if err != nil { + uerr, ok := err.(*url.Error) + if !ok { + t.Errorf("error is not an url.Error; got: %#v", err) + continue + } + nerr, ok := uerr.Err.(net.Error) + if !ok { + t.Errorf("error does not satisfy net.Error interface; got: %#v", err) + continue + } + if !nerr.Timeout() { + t.Errorf("want timeout error; got: %q", nerr) + continue + } + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + run(t, testTransportCancelRequest, []testMode{http1Mode}) +} +func testTransportCancelRequest(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := io.ReadAll(res.Body) + d := time.Since(t0) + + if err != ExportErrRequestCanceled { + t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 5; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + donec := make(chan bool) + req, _ := NewRequest("GET", ts.URL, body) + go func() { + defer close(donec) + c.Do(req) + }() + start := time.Now() + timeout := 10 * time.Second + for time.Since(start) < timeout { + time.Sleep(100 * time.Millisecond) + tr.CancelRequest(req) + select { + case <-donec: + return + default: + } + } + t.Errorf("Do of canceled request has not returned after %v", timeout) +} + +func TestTransportCancelRequestInDo(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, nil) + }, []testMode{http1Mode}) +} + +func TestTransportCancelRequestWithBodyInDo(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) + }, []testMode{http1Mode}) +} + +func TestTransportCancelRequestInDial(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + var logbuf strings.Builder + eventLog := log.New(&logbuf, "", 0) + + unblockDial := make(chan bool) + defer close(unblockDial) + + inDial := make(chan bool) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + eventLog.Println("dial: blocking") + if !<-inDial { + return nil, errors.New("main Test goroutine exited") + } + <-unblockDial + return nil, errors.New("nope") + }, + } + cl := &Client{Transport: tr} + gotres := make(chan bool) + req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + go func() { + _, err := cl.Do(req) + eventLog.Printf("Get = %v", err) + gotres <- true + }() + + select { + case inDial <- true: + case <-time.After(5 * time.Second): + close(inDial) + t.Fatal("timeout; never saw blocking dial") + } + + eventLog.Printf("canceling") + tr.CancelRequest(req) + tr.CancelRequest(req) // used to panic on second call + + select { + case <-gotres: + case <-time.After(5 * time.Second): + panic("hang. events are: " + logbuf.String()) + } + + got := logbuf.String() + want := `dial: blocking +canceling +Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection +` + if got != want { + t.Errorf("Got events:\n%s\nWant:\n%s", got, want) + } +} + +func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } +func testCancelRequestWithChannel(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + req, _ := NewRequest("GET", ts.URL, nil) + ch := make(chan struct{}) + req.Cancel = ch + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + close(ch) + }() + t0 := time.Now() + body, err := io.ReadAll(res.Body) + d := time.Since(t0) + + if err != ExportErrRequestCanceled { + t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 5; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, false) + }) +} +func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, true) + }) +} +func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { + unblockc := make(chan bool) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + + req, _ := NewRequest("GET", ts.URL, nil) + if withCtx { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req = req.WithContext(ctx) + } else { + ch := make(chan struct{}) + req.Cancel = ch + close(ch) + } + + _, err := c.Do(req) + if ue, ok := err.(*url.Error); ok { + err = ue.Err + } + if withCtx { + if err != context.Canceled { + t.Errorf("Do error = %v; want %v", err, context.Canceled) + } + } else { + if err == nil || !strings.Contains(err.Error(), "canceled") { + t.Errorf("Do error = %v; want cancellation", err) + } + } +} + +// Issue 11020. The returned error message should be errRequestCanceled +func TestTransportCancelBeforeResponseHeaders(t *testing.T) { + defer afterTest(t) + + serverConnCh := make(chan net.Conn, 1) + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + cc, sc := net.Pipe() + serverConnCh <- sc + return cc, nil + }, + } + defer tr.CloseIdleConnections() + errc := make(chan error, 1) + req, _ := NewRequest("GET", "http://example.com/", nil) + go func() { + _, err := tr.RoundTrip(req) + errc <- err + }() + + sc := <-serverConnCh + verb := make([]byte, 3) + if _, err := io.ReadFull(sc, verb); err != nil { + t.Errorf("Error reading HTTP verb from server: %v", err) + } + if string(verb) != "GET" { + t.Errorf("server received %q; want GET", verb) + } + defer sc.Close() + + tr.CancelRequest(req) + + err := <-errc + if err == nil { + t.Fatalf("unexpected success from RoundTrip") + } + if err != ExportErrRequestCanceled { + t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) + } +} + +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } +func testTransportCloseResponseBody(t *testing.T, mode testMode) { + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })).ts + + c := ts.Client() + tr := c.Transport.(*Transport) + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Fatalf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + +type fooProto struct{} + +func (fooProto) RoundTrip(req *Request) (*Response, error) { + res := &Response{ + Status: "200 OK", + StatusCode: 200, + Header: make(Header), + Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + } + return res, nil +} + +func TestTransportAltProto(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &Client{Transport: tr} + tr.RegisterProtocol("foo", fooProto{}) + res, err := c.Get("foo://bar.com/path") + if err != nil { + t.Fatal(err) + } + bodyb, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + body := string(bodyb) + if e := "You wanted foo://bar.com/path"; body != e { + t.Errorf("got response %q, want %q", body, e) + } +} + +func TestTransportNoHost(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + _, err := tr.RoundTrip(&Request{ + Header: make(Header), + URL: &url.URL{ + Scheme: "http", + }, + }) + want := "http: no Host in request URL" + if got := fmt.Sprint(err); got != want { + t.Errorf("error = %v; want %q", err, want) + } +} + +// Issue 13311 +func TestTransportEmptyMethod(t *testing.T) { + req, _ := NewRequest("GET", "http://foo.com/", nil) + req.Method = "" // docs say "For client requests an empty string means GET" + got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(got), "GET ") { + t.Fatalf("expected substring 'GET '; got: %s", got) + } +} + +func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } +func testTransportSocketLateBinding(t *testing.T, mode testMode) { + mux := NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := newClientServerTest(t, mode, mux).ts + + dialGate := make(chan bool, 1) + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(io.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- false +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := io.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + testResponse := func(req *Request, name string, wantCode int) { + t.Helper() + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = io.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } +} + +// Issue 17739: the HTTP client must ignore any unknown 1xx +// informational responses before the actual response. +func TestTransportIgnore1xxResponses(t *testing.T) { + run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) +} +func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) + buf.Flush() + conn.Close() + })) + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + var got strings.Builder + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) + return nil + }, + })) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + res.Write(&got) + want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" + if got.String() != want { + t.Errorf(" got: %q\nwant: %q\n", got.String(), want) + } +} + +func TestTransportLimits1xxResponses(t *testing.T) { + run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) +} +func testTransportLimits1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + for i := 0; i < 10; i++ { + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) + } + buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) + buf.Flush() + conn.Close() + })) + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + res, err := cst.c.Get(cst.ts.URL) + if res != nil { + defer res.Body.Close() + } + got := fmt.Sprint(err) + wantSub := "too many 1xx informational responses" + if !strings.Contains(got, wantSub) { + t.Errorf("Get error = %v; want substring %q", err, wantSub) + } +} + +// Issue 26161: the HTTP client must treat 101 responses +// as the final response. +func TestTransportTreat101Terminal(t *testing.T) { + run(t, testTransportTreat101Terminal, []testMode{http1Mode}) +} +func testTransportTreat101Terminal(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) + buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) + buf.Flush() + conn.Close() + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusSwitchingProtocols { + t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode) + } +} + +type proxyFromEnvTest struct { + req string // URL to fetch; blank means "http://example.com" + + env string // HTTP_PROXY + httpsenv string // HTTPS_PROXY + noenv string // NO_PROXY + reqmeth string // REQUEST_METHOD + + want string + wanterr error +} + +func (t proxyFromEnvTest) String() string { + var buf strings.Builder + space := func() { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + } + if t.env != "" { + fmt.Fprintf(&buf, "http_proxy=%q", t.env) + } + if t.httpsenv != "" { + space() + fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv) + } + if t.noenv != "" { + space() + fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) + } + if t.reqmeth != "" { + space() + fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) + } + req := "http://example.com" + if t.req != "" { + req = t.req + } + space() + fmt.Fprintf(&buf, "req=%q", req) + return strings.TrimSpace(buf.String()) +} + +var proxyFromEnvTests = []proxyFromEnvTest{ + {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"}, + {env: "cache.corp.example.com", want: "http://cache.corp.example.com"}, + {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, + {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, + {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, + + // Don't use secure for http + {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, + // Use secure for https. + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, + {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, + + // Issue 16405: don't use HTTP_PROXY in a CGI environment, + // where HTTP_PROXY can be attacker-controlled. + {env: "http://10.1.2.3:8080", reqmeth: "POST", + want: "<nil>", + wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, + + {want: "<nil>"}, + + {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, + {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"}, + {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"}, +} + +func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) { + t.Helper() + reqURL := tt.req + if reqURL == "" { + reqURL = "http://example.com" + } + req, _ := NewRequest("GET", reqURL, nil) + url, err := proxyForRequest(req) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e { + t.Errorf("%v: got error = %q, want %q", tt, g, e) + return + } + if got := fmt.Sprintf("%s", url); got != tt.want { + t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want) + } +} + +func TestProxyFromEnvironment(t *testing.T) { + ResetProxyEnv() + defer ResetProxyEnv() + for _, tt := range proxyFromEnvTests { + testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { + os.Setenv("HTTP_PROXY", tt.env) + os.Setenv("HTTPS_PROXY", tt.httpsenv) + os.Setenv("NO_PROXY", tt.noenv) + os.Setenv("REQUEST_METHOD", tt.reqmeth) + ResetCachedEnvironment() + return ProxyFromEnvironment(req) + }) + } +} + +func TestProxyFromEnvironmentLowerCase(t *testing.T) { + ResetProxyEnv() + defer ResetProxyEnv() + for _, tt := range proxyFromEnvTests { + testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) { + os.Setenv("http_proxy", tt.env) + os.Setenv("https_proxy", tt.httpsenv) + os.Setenv("no_proxy", tt.noenv) + os.Setenv("REQUEST_METHOD", tt.reqmeth) + ResetCachedEnvironment() + return ProxyFromEnvironment(req) + }) + } +} + +func TestIdleConnChannelLeak(t *testing.T) { + run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel) +} +func testIdleConnChannelLeak(t *testing.T, mode testMode) { + // Not parallel: uses global test hooks. + var mu sync.Mutex + var n int + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + n++ + mu.Unlock() + })).ts + + const nReqs = 5 + didRead := make(chan bool, nReqs) + SetReadLoopBeforeNextReadHook(func() { didRead <- true }) + defer SetReadLoopBeforeNextReadHook(nil) + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) + } + + // First, without keep-alives. + for _, disableKeep := range []bool{true, false} { + tr.DisableKeepAlives = disableKeep + for i := 0; i < nReqs; i++ { + _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i)) + if err != nil { + t.Fatal(err) + } + // Note: no res.Body.Close is needed here, since the + // response Content-Length is zero. Perhaps the test + // should be more explicit and use a HEAD, but tests + // elsewhere guarantee that zero byte responses generate + // a "Content-Length: 0" instead of chunking. + } + + // At this point, each of the 5 Transport.readLoop goroutines + // are scheduling noting that there are no response bodies (see + // earlier comment), and are then calling putIdleConn, which + // decrements this count. Usually that happens quickly, which is + // why this test has seemed to work for ages. But it's still + // racey: we have wait for them to finish first. See Issue 10427 + for i := 0; i < nReqs; i++ { + <-didRead + } + + if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 { + t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got) + } + } +} + +// Verify the status quo: that the Client.Post function coerces its +// body into a ReadCloser if it's a Closer, and that the Transport +// then closes it. +func TestTransportClosesRequestBody(t *testing.T) { + run(t, testTransportClosesRequestBody, []testMode{http1Mode}) +} +func testTransportClosesRequestBody(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(io.Discard, r.Body) + })).ts + + c := ts.Client() + + closes := 0 + + res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} + +func TestTransportTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping in short mode") + } + ln := newLocalListener(t) + defer ln.Close() + testdonec := make(chan struct{}) + defer close(testdonec) + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + <-testdonec + c.Close() + }() + + getdonec := make(chan struct{}) + go func() { + defer close(getdonec) + tr := &Transport{ + Dial: func(_, _ string) (net.Conn, error) { + return net.Dial("tcp", ln.Addr().String()) + }, + TLSHandshakeTimeout: 250 * time.Millisecond, + } + cl := &Client{Transport: tr} + _, err := cl.Get("https://dummy.tld/") + if err == nil { + t.Error("expected error") + return + } + ue, ok := err.(*url.Error) + if !ok { + t.Errorf("expected url.Error; got %#v", err) + return + } + ne, ok := ue.Err.(net.Error) + if !ok { + t.Errorf("expected net.Error; got %#v", err) + return + } + if !ne.Timeout() { + t.Errorf("expected timeout error; got %v", err) + } + if !strings.Contains(err.Error(), "handshake timeout") { + t.Errorf("expected 'handshake timeout' in error; got %v", err) + } + }() + select { + case <-getdonec: + case <-time.After(5 * time.Second): + t.Error("test timeout; TLS handshake hung?") + } +} + +// Trying to repro golang.org/issue/3514 +func TestTLSServerClosesConnection(t *testing.T) { + run(t, testTLSServerClosesConnection, []testMode{https1Mode}) +} +func testTLSServerClosesConnection(t *testing.T, mode testMode) { + closedc := make(chan bool, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/keep-alive-then-die") { + conn, _, _ := w.(Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) + conn.Close() + closedc <- true + return + } + fmt.Fprintf(w, "hello") + })).ts + + c := ts.Client() + tr := c.Transport.(*Transport) + + var nSuccess = 0 + var errs []error + const trials = 20 + for i := 0; i < trials; i++ { + tr.CloseIdleConnections() + res, err := c.Get(ts.URL + "/keep-alive-then-die") + if err != nil { + t.Fatal(err) + } + <-closedc + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "foo" { + t.Errorf("Got %q, want foo", slurp) + } + + // Now try again and see if we successfully + // pick a new connection. + res, err = c.Get(ts.URL + "/") + if err != nil { + errs = append(errs, err) + continue + } + slurp, err = io.ReadAll(res.Body) + if err != nil { + errs = append(errs, err) + continue + } + nSuccess++ + } + if nSuccess > 0 { + t.Logf("successes = %d of %d", nSuccess, trials) + } else { + t.Errorf("All runs failed:") + } + for _, err := range errs { + t.Logf(" err: %v", err) + } +} + +// byteFromChanReader is an io.Reader that reads a single byte at a +// time from the channel. When the channel is closed, the reader +// returns io.EOF. +type byteFromChanReader chan byte + +func (c byteFromChanReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + b, ok := <-c + if !ok { + return 0, io.EOF + } + p[0] = b + return 1, nil +} + +// Verifies that the Transport doesn't reuse a connection in the case +// where the server replies before the request has been fully +// written. We still honor that reply (see TestIssue3595), but don't +// send future requests on the connection because it's then in a +// questionable state. +// golang.org/issue/7569 +func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}) +} +func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { + var sconn struct { + sync.Mutex + c net.Conn + } + var getOkay bool + closeConn := func() { + sconn.Lock() + defer sconn.Unlock() + if sconn.c != nil { + sconn.c.Close() + sconn.c = nil + if !getOkay { + t.Logf("Closed server connection") + } + } + } + defer closeConn() + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method == "GET" { + io.WriteString(w, "bar") + return + } + conn, _, _ := w.(Hijacker).Hijack() + sconn.Lock() + sconn.c = conn + sconn.Unlock() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive + go io.Copy(io.Discard, conn) + })).ts + c := ts.Client() + + const bodySize = 256 << 10 + finalBit := make(byteFromChanReader, 1) + req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) + req.ContentLength = bodySize + res, err := c.Do(req) + if err := wantBody(res, err, "foo"); err != nil { + t.Errorf("POST response: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + res, err = c.Get(ts.URL) + if err := wantBody(res, err, "bar"); err != nil { + t.Errorf("GET response: %v", err) + return + } + getOkay = true // suppress test noise + }() + time.AfterFunc(5*time.Second, closeConn) + select { + case <-donec: + finalBit <- 'x' // unblock the writeloop of the first Post + close(finalBit) + case <-time.After(7 * time.Second): + t.Fatal("timeout waiting for GET request to finish") + } +} + +// Tests that we don't leak Transport persistConn.readLoop goroutines +// when a server hangs up immediately after saying it would keep-alive. +func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) } +func testTransportIssue10457(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // Send a response with no body, keep-alive + // (implicit), and then lie and immediately close the + // connection. This forces the Transport's readLoop to + // immediately Peek an io.EOF and get to the point + // that used to hang. + conn, _, _ := w.(Hijacker).Hijack() + conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive + conn.Close() + })).ts + c := ts.Client() + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + // Just a sanity check that we at least get the response. The real + // test here is that the "defer afterTest" above doesn't find any + // leaked goroutines. + if got, want := res.Header.Get("Foo"), "Bar"; got != want { + t.Errorf("Foo header = %q; want %q", got, want) + } +} + +type closerFunc func() error + +func (f closerFunc) Close() error { return f() } + +type writerFuncConn struct { + net.Conn + write func(p []byte) (n int, err error) +} + +func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } + +// Issues 4677, 18241, and 17844. If we try to reuse a connection that the +// server is in the process of closing, we may end up successfully writing out +// our request (or a portion of our request) only to find a connection error +// when we try to read from (or finish writing to) the socket. +// +// NOTE: we resend a request only if: +// - we reused a keep-alive connection +// - we haven't yet received any header data +// - either we wrote no bytes to the server, or the request is idempotent +// +// This automatically prevents an infinite resend loop because we'll run out of +// the cached keep-alive connections eventually. +func TestRetryRequestsOnError(t *testing.T) { + run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) +} +func testRetryRequestsOnError(t *testing.T, mode testMode) { + newRequest := func(method, urlStr string, body io.Reader) *Request { + req, err := NewRequest(method, urlStr, body) + if err != nil { + t.Fatal(err) + } + return req + } + + testCases := []struct { + name string + failureN int + failureErr error + // Note that we can't just re-use the Request object across calls to c.Do + // because we need to rewind Body between calls. (GetBody is only used to + // rewind Body on failure and redirects, not just because it's done.) + req func() *Request + reqString string + }{ + { + name: "IdempotentNoBodySomeWritten", + // Believe that we've written some bytes to the server, so we know we're + // not just in the "retry when no bytes sent" case". + failureN: 1, + // Use the specific error that shouldRetryRequest looks for with idempotent requests. + failureErr: ExportErrServerClosedIdle, + req: func() *Request { + return newRequest("GET", "http://fake.golang", nil) + }, + reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, + }, + { + name: "IdempotentGetBodySomeWritten", + // Believe that we've written some bytes to the server, so we know we're + // not just in the "retry when no bytes sent" case". + failureN: 1, + // Use the specific error that shouldRetryRequest looks for with idempotent requests. + failureErr: ExportErrServerClosedIdle, + req: func() *Request { + return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) + }, + reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, + }, + { + name: "NothingWrittenNoBody", + // It's key that we return 0 here -- that's what enables Transport to know + // that nothing was written, even though this is a non-idempotent request. + failureN: 0, + failureErr: errors.New("second write fails"), + req: func() *Request { + return newRequest("DELETE", "http://fake.golang", nil) + }, + reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, + }, + { + name: "NothingWrittenGetBody", + // It's key that we return 0 here -- that's what enables Transport to know + // that nothing was written, even though this is a non-idempotent request. + failureN: 0, + failureErr: errors.New("second write fails"), + // Note that NewRequest will set up GetBody for strings.Reader, which is + // required for the retry to occur + req: func() *Request { + return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) + }, + reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var ( + mu sync.Mutex + logbuf strings.Builder + ) + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&logbuf, format, args...) + logbuf.WriteByte('\n') + } + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + logf("Handler") + w.Header().Set("X-Status", "ok") + })).ts + + var writeNumAtomic int32 + c := ts.Client() + c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { + logf("Dial") + c, err := net.Dial(network, ts.Listener.Addr().String()) + if err != nil { + logf("Dial error: %v", err) + return nil, err + } + return &writerFuncConn{ + Conn: c, + write: func(p []byte) (n int, err error) { + if atomic.AddInt32(&writeNumAtomic, 1) == 2 { + logf("intentional write failure") + return tc.failureN, tc.failureErr + } + logf("Write(%q)", p) + return c.Write(p) + }, + }, nil + } + + SetRoundTripRetried(func() { + logf("Retried.") + }) + defer SetRoundTripRetried(nil) + + for i := 0; i < 3; i++ { + t0 := time.Now() + req := tc.req() + res, err := c.Do(req) + if err != nil { + if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 { + mu.Lock() + got := logbuf.String() + mu.Unlock() + t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got) + } + t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse) + } + res.Body.Close() + if res.Request != req { + t.Errorf("Response.Request != original request; want identical Request") + } + } + + mu.Lock() + got := logbuf.String() + mu.Unlock() + want := fmt.Sprintf(`Dial +Write("%s") +Handler +intentional write failure +Retried. +Dial +Write("%s") +Handler +Write("%s") +Handler +`, tc.reqString, tc.reqString, tc.reqString) + if got != want { + t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want) + } + }) + } +} + +// Issue 6981 +func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } +func testTransportClosesBodyOnError(t *testing.T, mode testMode) { + readBody := make(chan error, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := io.ReadAll(r.Body) + readBody <- err + })).ts + c := ts.Client() + fakeErr := errors.New("fake error") + didClose := make(chan bool, 1) + req, _ := NewRequest("POST", ts.URL, struct { + io.Reader + io.Closer + }{ + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), + closerFunc(func() error { + select { + case didClose <- true: + default: + } + return nil + }), + }) + res, err := c.Do(req) + if res != nil { + defer res.Body.Close() + } + if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) { + t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error()) + } + select { + case err := <-readBody: + if err == nil { + t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'") + } + case <-time.After(5 * time.Second): + t.Error("timeout waiting for server handler to complete") + } + select { + case <-didClose: + default: + t.Errorf("didn't see Body.Close") + } +} + +func TestTransportDialTLS(t *testing.T) { + run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLS(t *testing.T, mode testMode) { + var mu sync.Mutex // guards following + var gotReq, didDial bool + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })).ts + c := ts.Client() + c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.Handshake() + } + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if !didDial { + t.Error("didn't use dial hook") + } +} + +func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } +func testTransportDialContext(t *testing.T, mode testMode) { + var mu sync.Mutex // guards following + var gotReq bool + var receivedContext context.Context + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })).ts + c := ts.Client() + c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + mu.Lock() + receivedContext = ctx + mu.Unlock() + return net.Dial(netw, addr) + } + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(context.Background(), "some-key", "some-value") + res, err := c.Do(req.WithContext(ctx)) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if receivedContext != ctx { + t.Error("didn't receive correct context") + } +} + +func TestTransportDialTLSContext(t *testing.T) { + run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLSContext(t *testing.T, mode testMode) { + var mu sync.Mutex // guards following + var gotReq bool + var receivedContext context.Context + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + mu.Lock() + gotReq = true + mu.Unlock() + })).ts + c := ts.Client() + c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { + mu.Lock() + receivedContext = ctx + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.HandshakeContext(ctx) + } + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(context.Background(), "some-key", "some-value") + res, err := c.Do(req.WithContext(ctx)) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + mu.Lock() + if !gotReq { + t.Error("didn't get request") + } + if receivedContext != ctx { + t.Error("didn't receive correct context") + } +} + +// Test for issue 8755 +// Ensure that if a proxy returns an error, it is exposed by RoundTrip +func TestRoundTripReturnsProxyError(t *testing.T) { + badProxy := func(*Request) (*url.URL, error) { + return nil, errors.New("errorMessage") + } + + tr := &Transport{Proxy: badProxy} + + req, _ := NewRequest("GET", "http://example.com", nil) + + _, err := tr.RoundTrip(req) + + if err == nil { + t.Error("Expected proxy error to be returned by RoundTrip") + } +} + +// tests that putting an idle conn after a call to CloseIdleConns does return it +func TestTransportCloseIdleConnsThenReturn(t *testing.T) { + tr := &Transport{} + wantIdle := func(when string, n int) bool { + got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn + if got == n { + return true + } + t.Errorf("%s: idle conns = %d; want %d", when, got, n) + return false + } + wantIdle("start", 0) + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("put failed") + } + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("second put failed") + } + wantIdle("after put", 2) + tr.CloseIdleConnections() + if !tr.IsIdleForTesting() { + t.Error("should be idle after CloseIdleConnections") + } + wantIdle("after close idle", 0) + if tr.PutIdleTestConn("http", "example.com") { + t.Fatal("put didn't fail") + } + wantIdle("after second put", 0) + + tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode + if tr.IsIdleForTesting() { + t.Error("shouldn't be idle after QueueForIdleConnForTesting") + } + if !tr.PutIdleTestConn("http", "example.com") { + t.Fatal("after re-activation") + } + wantIdle("after final put", 1) +} + +// Test for issue 34282 +// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn +func TestTransportTraceGotConnH2IdleConns(t *testing.T) { + tr := &Transport{} + wantIdle := func(when string, n int) bool { + got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2 + if got == n { + return true + } + t.Errorf("%s: idle conns = %d; want %d", when, got, n) + return false + } + wantIdle("start", 0) + alt := funcRoundTripper(func() {}) + if !tr.PutIdleTestConnH2("https", "example.com:443", alt) { + t.Fatal("put failed") + } + wantIdle("after put", 1) + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(httptrace.GotConnInfo) { + // tr.getConn should leave it for the HTTP/2 alt to call GotConn. + t.Error("GotConn called") + }, + }) + req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil) + _, err := tr.RoundTrip(req) + if err != errFakeRoundTrip { + t.Errorf("got error: %v; want %q", err, errFakeRoundTrip) + } + wantIdle("after round trip", 1) +} + +func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { + run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) +} +func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + trFunc := func(tr *Transport) { + tr.MaxConnsPerHost = 1 + tr.MaxIdleConnsPerHost = 1 + tr.IdleConnTimeout = 10 * time.Millisecond + } + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) + + if _, err := cst.c.Get(cst.ts.URL); err != nil { + t.Fatalf("got error: %s", err) + } + + time.Sleep(100 * time.Millisecond) + got := make(chan error) + go func() { + if _, err := cst.c.Get(cst.ts.URL); err != nil { + got <- err + } + close(got) + }() + + timeout := time.NewTimer(5 * time.Second) + defer timeout.Stop() + select { + case err := <-got: + if err != nil { + t.Fatalf("got error: %s", err) + } + case <-timeout.C: + t.Fatal("request never completed") + } +} + +// This tests that a client requesting a content range won't also +// implicitly ask for gzip support. If they want that, they need to do it +// on their own. +// golang.org/issue/8923 +func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } +func testTransportRangeAndGzip(t *testing.T, mode testMode) { + reqc := make(chan *Request, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + reqc <- r + })).ts + c := ts.Client() + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("Range", "bytes=7-11") + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + select { + case r := <-reqc: + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + t.Error("Transport advertised gzip support in the Accept header") + } + if r.Header.Get("Range") == "" { + t.Error("no Range in request") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout") + } + res.Body.Close() +} + +// Test for issue 10474 +func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } +func testTransportResponseCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // important that this response has a body. + var b [1024]byte + w.Write(b[:]) + })).ts + tr := ts.Client().Transport.(*Transport) + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + // If we do an early close, Transport just throws the connection away and + // doesn't reuse it. In order to trigger the bug, it has to reuse the connection + // so read the body + if _, err := io.Copy(io.Discard, res.Body); err != nil { + t.Fatal(err) + } + + req2, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + tr.CancelRequest(req) + res, err = tr.RoundTrip(req2) + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +// Test for issue 19248: Content-Encoding's value is case insensitive. +func TestTransportContentEncodingCaseInsensitive(t *testing.T) { + run(t, testTransportContentEncodingCaseInsensitive) +} +func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { + for _, ce := range []string{"gzip", "GZIP"} { + ce := ce + t.Run(ce, func(t *testing.T) { + const encodedString = "Hello Gopher" + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", ce) + gz := gzip.NewWriter(w) + gz.Write([]byte(encodedString)) + gz.Close() + })).ts + + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + + if string(body) != encodedString { + t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) + } + }) + } +} + +func TestTransportDialCancelRace(t *testing.T) { + run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) +} +func testTransportDialCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts + tr := ts.Client().Transport.(*Transport) + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + SetEnterRoundTripHook(func() { + tr.CancelRequest(req) + }) + defer SetEnterRoundTripHook(nil) + res, err := tr.RoundTrip(req) + if err != ExportErrRequestCanceled { + t.Errorf("expected canceled request error; got %v", err) + if err == nil { + res.Body.Close() + } + } +} + +// logWritesConn is a net.Conn that logs each Write call to writes +// and then proxies to w. +// It proxies Read calls to a reader it receives from rch. +type logWritesConn struct { + net.Conn // nil. crash on use. + + w io.Writer + + rch <-chan io.Reader + r io.Reader // nil until received by rch + + mu sync.Mutex + writes []string +} + +func (c *logWritesConn) Write(p []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, string(p)) + return c.w.Write(p) +} + +func (c *logWritesConn) Read(p []byte) (n int, err error) { + if c.r == nil { + c.r = <-c.rch + } + return c.r.Read(p) +} + +func (c *logWritesConn) Close() error { return nil } + +// Issue 6574 +func TestTransportFlushesBodyChunks(t *testing.T) { + defer afterTest(t) + resBody := make(chan io.Reader, 1) + connr, connw := io.Pipe() // connection pipe pair + lw := &logWritesConn{ + rch: resBody, + w: connw, + } + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return lw, nil + }, + } + bodyr, bodyw := io.Pipe() // body pipe pair + go func() { + defer bodyw.Close() + for i := 0; i < 3; i++ { + fmt.Fprintf(bodyw, "num%d\n", i) + } + }() + resc := make(chan *Response) + go func() { + req, _ := NewRequest("POST", "http://localhost:8080", bodyr) + req.Header.Set("User-Agent", "x") // known value for test + res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("RoundTrip: %v", err) + close(resc) + return + } + resc <- res + + }() + // Fully consume the request before checking the Write log vs. want. + req, err := ReadRequest(bufio.NewReader(connr)) + if err != nil { + t.Fatal(err) + } + io.Copy(io.Discard, req.Body) + + // Unblock the transport's roundTrip goroutine. + resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") + res, ok := <-resc + if !ok { + return + } + defer res.Body.Close() + + want := []string{ + "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n", + "5\r\nnum0\n\r\n", + "5\r\nnum1\n\r\n", + "5\r\nnum2\n\r\n", + "0\r\n\r\n", + } + if !reflect.DeepEqual(lw.writes, want) { + t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want) + } +} + +// Issue 22088: flush Transport request headers if we're not sure the body won't block on read. +func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } +func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { + gotReq := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + close(gotReq) + })) + + pr, pw := io.Pipe() + req, err := NewRequest("POST", cst.ts.URL, pr) + if err != nil { + t.Fatal(err) + } + gotRes := make(chan struct{}) + go func() { + defer close(gotRes) + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Error(err) + return + } + res.Body.Close() + }() + + select { + case <-gotReq: + pw.Close() + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for handler to get request") + } + <-gotRes +} + +// Issue 11745. +func TestTransportPrefersResponseOverWriteError(t *testing.T) { + run(t, testTransportPrefersResponseOverWriteError) +} +func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping in short mode") + } + const contentLengthLimit = 1024 * 1024 // 1MB + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.ContentLength >= contentLengthLimit { + w.WriteHeader(StatusBadRequest) + r.Body.Close() + return + } + w.WriteHeader(StatusOK) + })).ts + c := ts.Client() + + fail := 0 + count := 100 + bigBody := strings.Repeat("a", contentLengthLimit*2) + for i := 0; i < count; i++ { + req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody)) + if err != nil { + t.Fatal(err) + } + resp, err := c.Do(req) + if err != nil { + fail++ + t.Logf("%d = %#v", i, err) + if ue, ok := err.(*url.Error); ok { + t.Logf("urlErr = %#v", ue.Err) + if ne, ok := ue.Err.(*net.OpError); ok { + t.Logf("netOpError = %#v", ne.Err) + } + } + } else { + resp.Body.Close() + if resp.StatusCode != 400 { + t.Errorf("Expected status code 400, got %v", resp.Status) + } + } + } + if fail > 0 { + t.Errorf("Failed %v out of %v\n", fail, count) + } +} + +func TestTransportAutomaticHTTP2(t *testing.T) { + testTransportAutoHTTP(t, &Transport{}, true) +} + +func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: new(tls.Config), + }, true) +} + +// golang.org/issue/14391: also check DefaultTransport +func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { + testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) +} + +func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), + }, false) +} + +func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + TLSClientConfig: new(tls.Config), + }, false) +} + +func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + ExpectContinueTimeout: 1 * time.Second, + }, true) +} + +func TestTransportAutomaticHTTP2_Dial(t *testing.T) { + var d net.Dialer + testTransportAutoHTTP(t, &Transport{ + Dial: d.Dial, + }, false) +} + +func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { + var d net.Dialer + testTransportAutoHTTP(t, &Transport{ + DialContext: d.DialContext, + }, false) +} + +func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + DialTLS: func(network, addr string) (net.Conn, error) { + panic("unused") + }, + }, false) +} + +func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { + CondSkipHTTP2(t) + _, err := tr.RoundTrip(new(Request)) + if err == nil { + t.Error("expected error from RoundTrip") + } + if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 { + t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2) + } +} + +// Issue 13633: there was a race where we returned bodyless responses +// to callers before recycling the persistent connection, which meant +// a client doing two subsequent requests could end up on different +// connections. It's somewhat harmless but enough tests assume it's +// not true in order to test other things that it's worth fixing. +// Plus it's nice to be consistent and not have timing-dependent +// behavior. +func TestTransportReuseConnEmptyResponseBody(t *testing.T) { + run(t, testTransportReuseConnEmptyResponseBody) +} +func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + // Empty response body. + })) + n := 100 + if testing.Short() { + n = 10 + } + var firstAddr string + for i := 0; i < n; i++ { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + log.Fatal(err) + } + addr := res.Header.Get("X-Addr") + if i == 0 { + firstAddr = addr + } else if addr != firstAddr { + t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr) + } + res.Body.Close() + } +} + +// Issue 13839 +func TestNoCrashReturningTransportAltConn(t *testing.T) { + cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + var wg sync.WaitGroup + SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) + defer SetPendingDialHooks(nil, nil) + + testDone := make(chan struct{}) + defer close(testDone) + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + <-testDone + sc.Close() + }() + + addr := ln.Addr().String() + + req, _ := NewRequest("GET", "https://fake.tld/", nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + doReturned := make(chan bool, 1) + madeRoundTripper := make(chan bool, 1) + + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { + madeRoundTripper <- true + return funcRoundTripper(func() { + t.Error("foo RoundTripper should not be called") + }) + }, + }, + Dial: func(_, _ string) (net.Conn, error) { + panic("shouldn't be called") + }, + DialTLS: func(_, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := tc.Handshake(); err != nil { + return nil, err + } + close(cancel) + <-doReturned + return tc, nil + }, + } + c := &Client{Transport: tr} + + _, err = c.Do(req) + if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn { + t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err) + } + + doReturned <- true + <-madeRoundTripper + wg.Wait() +} + +func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, true) + }) +} + +func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, false) + }) +} + +// Make sure we re-use underlying TCP connection for gzipped responses too. +func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { + addr := make(chan string, 2) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + addr <- r.RemoteAddr + w.Header().Set("Content-Encoding", "gzip") + if chunked { + w.(Flusher).Flush() + } + w.Write(rgz) // arbitrary gzip response + })).ts + c := ts.Client() + + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) }, + GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) }, + PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) }, + ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) }, + ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) }, + } + ctx := httptrace.WithClientTrace(context.Background(), trace) + + for i := 0; i < 2; i++ { + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(ctx) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, len(rgz)) + if n, err := io.ReadFull(res.Body, buf); err != nil { + t.Errorf("%d. ReadFull = %v, %v", i, n, err) + } + // Note: no res.Body.Close call. It should work without it, + // since the flate.Reader's internal buffering will hit EOF + // and that should be sufficient. + } + a1, a2 := <-addr, <-addr + if a1 != a2 { + t.Fatalf("didn't reuse connection") + } +} + +func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } +func testTransportResponseHeaderLength(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/long" { + w.Header().Set("Long", strings.Repeat("a", 1<<20)) + } + })).ts + c := ts.Client() + c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 + + if res, err := c.Get(ts.URL); err != nil { + t.Fatal(err) + } else { + res.Body.Close() + } + + res, err := c.Get(ts.URL + "/long") + if err == nil { + defer res.Body.Close() + var n int64 + for k, vv := range res.Header { + for _, v := range vv { + n += int64(len(k)) + int64(len(v)) + } + } + t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n) + } + if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { + t.Errorf("got error: %v; want %q", err, want) + } +} + +func TestTransportEventTrace(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, false) + }, testNotParallel) +} + +// test a non-nil httptrace.ClientTrace but with all hooks set to zero. +func TestTransportEventTrace_NoHooks(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, true) + }, testNotParallel) +} + +func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) { + const resBody = "some body" + gotWroteReqEvent := make(chan struct{}, 500) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method == "GET" { + // Do nothing for the second request. + return + } + if _, err := io.ReadAll(r.Body); err != nil { + t.Error(err) + } + if !noHooks { + select { + case <-gotWroteReqEvent: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for WroteRequest event") + } + } + io.WriteString(w, resBody) + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) + defer cst.close() + + cst.tr.ExpectContinueTimeout = 1 * time.Second + + var mu sync.Mutex // guards buf + var buf strings.Builder + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + addrStr := cst.ts.Listener.Addr().String() + ip, port, err := net.SplitHostPort(addrStr) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. + ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { + if host != "dns-is-faked.golang" { + t.Errorf("unexpected DNS host lookup for %q/%q", network, host) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + body := "some body" + req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body)) + req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"} + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, + GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, + GotFirstResponseByte: func() { logf("first response byte") }, + PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, + ConnectDone: func(network, addr string, err error) { + if err != nil { + t.Errorf("ConnectDone: %v", err) + } + logf("ConnectDone: connected to %s %s = %v", network, addr, err) + }, + WroteHeaderField: func(key string, value []string) { + logf("WroteHeaderField: %s: %v", key, value) + }, + WroteHeaders: func() { + logf("WroteHeaders") + }, + Wait100Continue: func() { logf("Wait100Continue") }, + Got100Continue: func() { logf("Got100Continue") }, + WroteRequest: func(e httptrace.WroteRequestInfo) { + logf("WroteRequest: %+v", e) + gotWroteReqEvent <- struct{}{} + }, + } + if mode == http2Mode { + trace.TLSHandshakeStart = func() { logf("tls handshake start") } + trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { + logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) + } + } + if noHooks { + // zero out all func pointers, trying to get some path to crash + *trace = httptrace.ClientTrace{} + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + logf("got roundtrip.response") + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + logf("consumed body") + if string(slurp) != resBody || res.StatusCode != 200 { + t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) + } + res.Body.Close() + + if noHooks { + // Done at this point. Just testing a full HTTP + // requests can happen with a trace pointing to a zero + // ClientTrace, full of nil func pointers. + return + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantOnce := func(sub string) { + if strings.Count(got, sub) != 1 { + t.Errorf("expected substring %q exactly once in output.", sub) + } + } + wantOnceOrMore := func(sub string) { + if strings.Count(got, sub) == 0 { + t.Errorf("expected substring %q at least once in output.", sub) + } + } + wantOnce("Getting conn for dns-is-faked.golang:" + port) + wantOnce("DNS start: {Host:dns-is-faked.golang}") + wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}") + wantOnce("got conn: {") + wantOnceOrMore("Connecting to tcp " + addrStr) + wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") + wantOnce("Reused:false WasIdle:false IdleTime:0s") + wantOnce("first response byte") + if mode == http2Mode { + wantOnce("tls handshake start") + wantOnce("tls handshake done") + } else { + wantOnce("PutIdleConn = <nil>") + wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]") + // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the + // WroteHeaderField hook is not yet implemented in h2.) + wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port)) + wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body))) + wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]") + wantOnce("WroteHeaderField: Accept-Encoding: [gzip]") + } + wantOnce("WroteHeaders") + wantOnce("Wait100Continue") + wantOnce("Got100Continue") + wantOnce("WroteRequest: {Err:<nil>}") + if strings.Contains(got, " to udp ") { + t.Errorf("should not see UDP (DNS) connections") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } + + // And do a second request: + req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil) + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + res, err = cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Fatal(res.Status) + } + res.Body.Close() + + mu.Lock() + got = buf.String() + mu.Unlock() + + sub := "Getting conn for dns-is-faked.golang:" + if gotn, want := strings.Count(got, sub), 2; gotn != want { + t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got) + } + +} + +func TestTransportEventTraceTLSVerify(t *testing.T) { + run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) +} +func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { + var mu sync.Mutex + var buf strings.Builder + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + t.Error("Unexpected request") + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { + logf("%s", p) + return len(p), nil + }), "", 0) + }).ts + + certpool := x509.NewCertPool() + certpool.AddCert(ts.Certificate()) + + c := &Client{Transport: &Transport{ + TLSClientConfig: &tls.Config{ + ServerName: "dns-is-faked.golang", + RootCAs: certpool, + }, + }} + + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { logf("TLSHandshakeStart") }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err) + }, + } + + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) + _, err := c.Do(req) + if err == nil { + t.Error("Expected request to fail TLS verification") + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantOnce := func(sub string) { + if strings.Count(got, sub) != 1 { + t.Errorf("expected substring %q exactly once in output.", sub) + } + } + + wantOnce("TLSHandshakeStart") + wantOnce("TLSHandshakeDone") + wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com") + + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +var ( + isDNSHijackedOnce sync.Once + isDNSHijacked bool +) + +func skipIfDNSHijacked(t *testing.T) { + // Skip this test if the user is using a shady/ISP + // DNS server hijacking queries. + // See issues 16732, 16716. + isDNSHijackedOnce.Do(func() { + addrs, _ := net.LookupHost("dns-should-not-resolve.golang") + isDNSHijacked = len(addrs) != 0 + }) + if isDNSHijacked { + t.Skip("skipping; test requires non-hijacking DNS server") + } +} + +func TestTransportEventTraceRealDNS(t *testing.T) { + skipIfDNSHijacked(t) + defer afterTest(t) + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + var mu sync.Mutex // guards buf + var buf strings.Builder + logf := func(format string, args ...any) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) + trace := &httptrace.ClientTrace{ + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, + ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, + } + req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) + + resp, err := c.Do(req) + if err == nil { + resp.Body.Close() + t.Fatal("expected error during DNS lookup") + } + + mu.Lock() + got := buf.String() + mu.Unlock() + + wantSub := func(sub string) { + if !strings.Contains(got, sub) { + t.Errorf("expected substring %q in output.", sub) + } + } + wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") + wantSub("DNSDone: {Addrs:[] Err:") + if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { + t.Errorf("should not see Connect events") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +// Issue 14353: port can only contain digits. +func TestTransportRejectsAlphaPort(t *testing.T) { + res, err := Get("http://dummy.tld:123foo/bar") + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %#v; want *url.Error", err) + } + got := ue.Err.Error() + want := `invalid port ":123foo" after host` + if got != want { + t.Errorf("got error %q; want %q", got, want) + } +} + +// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 +// connections. The http2 test is done in TestTransportEventTrace_h2 +func TestTLSHandshakeTrace(t *testing.T) { + run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTrace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts + + var mu sync.Mutex + var start, done bool + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { + mu.Lock() + defer mu.Unlock() + start = true + }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + mu.Lock() + defer mu.Unlock() + done = true + if err != nil { + t.Fatal("Expected error to be nil but was:", err) + } + }, + } + + c := ts.Client() + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal("Unable to construct test request:", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + r, err := c.Do(req) + if err != nil { + t.Fatal("Unexpected error making request:", err) + } + r.Body.Close() + mu.Lock() + defer mu.Unlock() + if !start { + t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") + } + if !done { + t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") + } +} + +func TestTransportMaxIdleConns(t *testing.T) { + run(t, testTransportMaxIdleConns, []testMode{http1Mode}) +} +func testTransportMaxIdleConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // No body for convenience. + })).ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxIdleConns = 4 + + ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + hitHost := func(n int) { + req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) + req = req.WithContext(ctx) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + } + for i := 0; i < 4; i++ { + hitHost(i) + } + want := []string{ + "|http|host-0.dns-is-faked.golang:" + port, + "|http|host-1.dns-is-faked.golang:" + port, + "|http|host-2.dns-is-faked.golang:" + port, + "|http|host-3.dns-is-faked.golang:" + port, + } + if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { + t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) + } + + // Now hitting the 5th host should kick out the first host: + hitHost(4) + want = []string{ + "|http|host-1.dns-is-faked.golang:" + port, + "|http|host-2.dns-is-faked.golang:" + port, + "|http|host-3.dns-is-faked.golang:" + port, + "|http|host-4.dns-is-faked.golang:" + port, + } + if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { + t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) + } +} + +func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } +func testTransportIdleConnTimeout(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + const timeout = 1 * time.Second + + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // No body for convenience. + })) + tr := cst.tr + tr.IdleConnTimeout = timeout + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + idleConns := func() []string { + if mode == http2Mode { + return tr.IdleConnStrsForTesting_h2() + } else { + return tr.IdleConnStrsForTesting() + } + } + + var conn string + doReq := func(n int) { + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + PutIdleConn: func(err error) { + if err != nil { + t.Errorf("failed to keep idle conn: %v", err) + } + }, + })) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + conns := idleConns() + if len(conns) != 1 { + t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) + } + if conn == "" { + conn = conns[0] + } + if conn != conns[0] { + t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n) + } + } + for i := 0; i < 3; i++ { + doReq(i) + time.Sleep(timeout / 2) + } + time.Sleep(timeout * 3 / 2) + if got := idleConns(); len(got) != 0 { + t.Errorf("idle conns = %q; want none", got) + } +} + +// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an +// HTTP/2 connection was established but its caller no longer +// wanted it. (Assuming the connection cache was enabled, which it is +// by default) +// +// This test reproduced the crash by setting the IdleConnTimeout low +// (to make the test reasonable) and then making a request which is +// canceled by the DialTLS hook, which then also waits to return the +// real connection until after the RoundTrip saw the error. Then we +// know the successful tls.Dial from DialTLS will need to go into the +// idle pool. Then we give it a of time to explode. +func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) } +func testIdleConnH2Crash(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing + })) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sawDoErr := make(chan bool, 1) + testDone := make(chan struct{}) + defer close(testDone) + + cst.tr.IdleConnTimeout = 5 * time.Millisecond + cst.tr.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := tls.Dial(network, addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err != nil { + t.Error(err) + return nil, err + } + if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { + t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") + c.Close() + return nil, errors.New("bogus") + } + + cancel() + + failTimer := time.NewTimer(5 * time.Second) + defer failTimer.Stop() + select { + case <-sawDoErr: + case <-testDone: + case <-failTimer.C: + t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") + } + return c, nil + } + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(ctx) + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + sawDoErr <- true + + // Wait for the explosion. + time.Sleep(cst.tr.IdleConnTimeout * 10) +} + +type funcConn struct { + net.Conn + read func([]byte) (int, error) + write func([]byte) (int, error) +} + +func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } +func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } +func (c funcConn) Close() error { return nil } + +// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek +// back to the caller. +func TestTransportReturnsPeekError(t *testing.T) { + errValue := errors.New("specific error value") + + wrote := make(chan struct{}) + var wroteOnce sync.Once + + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + c := funcConn{ + read: func([]byte) (int, error) { + <-wrote + return 0, errValue + }, + write: func(p []byte) (int, error) { + wroteOnce.Do(func() { close(wrote) }) + return len(p), nil + }, + } + return c, nil + }, + } + _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) + if err != errValue { + t.Errorf("error = %#v; want %v", err, errValue) + } +} + +// Issue 13835: international domain names should work +func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) } +func testTransportIDNA(t *testing.T, mode testMode) { + const uniDomain = "гофер.го" + const punyDomain = "xn--c1ae0ajs.xn--c1aw" + + var port string + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + want := punyDomain + ":" + port + if r.Host != want { + t.Errorf("Host header = %q; want %q", r.Host, want) + } + if mode == http2Mode { + if r.TLS == nil { + t.Errorf("r.TLS == nil") + } else if r.TLS.ServerName != punyDomain { + t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) + } + } + w.Header().Set("Hit-Handler", "1") + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) + + ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. + ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) { + if host != punyDomain { + t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + want := net.JoinHostPort(punyDomain, port) + if hostPort != want { + t.Errorf("getting conn for %q; want %q", hostPort, want) + } + }, + DNSStart: func(e httptrace.DNSStartInfo) { + if e.Host != punyDomain { + t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.Header.Get("Hit-Handler") != "1" { + out, err := httputil.DumpResponse(res, true) + if err != nil { + t.Fatal(err) + } + t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out) + } +} + +// Issue 13290: send User-Agent in proxy CONNECT +func TestTransportProxyConnectHeader(t *testing.T) { + run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyConnectHeader(t *testing.T, mode testMode) { + reqc := make(chan *Request, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })).ts + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + +func TestTransportProxyGetConnectHeader(t *testing.T) { + run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { + reqc := make(chan *Request, 1) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })).ts + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + // These should be ignored: + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { + return Header{ + "User-Agent": {"foo2"}, + "Other": {"bar2"}, + }, nil + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + +var errFakeRoundTrip = errors.New("fake roundtrip") + +type funcRoundTripper func() + +func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) { + fn() + return nil, errFakeRoundTrip +} + +func wantBody(res *Response, err error, want string) error { + if err != nil { + return err + } + slurp, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading body: %v", err) + } + if string(slurp) != want { + return fmt.Errorf("body = %q; want %q", slurp, want) + } + if err := res.Body.Close(); err != nil { + return fmt.Errorf("body Close = %v", err) + } + return nil +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} + +type countCloseReader struct { + n *int + io.Reader +} + +func (cr countCloseReader) Close() error { + (*cr.n)++ + return nil +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} + +// Ensure that a missing status doesn't make the server panic +// See Issue https://golang.org/issues/21701 +func TestMissingStatusNoPanic(t *testing.T) { + t.Parallel() + + const want = "unknown status code" + + ln := newLocalListener(t) + addr := ln.Addr().String() + done := make(chan bool) + fullAddrURL := fmt.Sprintf("http://%s", addr) + raw := "HTTP/1.1 400\r\n" + + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + + "Content-Type: text/html; charset=utf-8\r\n" + + "Content-Length: 10\r\n" + + "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" + + "Vary: Accept-Encoding\r\n\r\n" + + "Aloha Olaa" + + go func() { + defer close(done) + + conn, _ := ln.Accept() + if conn != nil { + io.WriteString(conn, raw) + io.ReadAll(conn) + conn.Close() + } + }() + + proxyURL, err := url.Parse(fullAddrURL) + if err != nil { + t.Fatalf("proxyURL: %v", err) + } + + tr := &Transport{Proxy: ProxyURL(proxyURL)} + + req, _ := NewRequest("GET", "https://golang.org/", nil) + res, err, panicked := doFetchCheckPanic(tr, req) + if panicked { + t.Error("panicked, expecting an error") + } + if res != nil && res.Body != nil { + io.Copy(io.Discard, res.Body) + res.Body.Close() + } + + if err == nil || !strings.Contains(err.Error(), want) { + t.Errorf("got=%v want=%q", err, want) + } + + ln.Close() + <-done +} + +func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + res, err = tr.RoundTrip(req) + return +} + +// Issue 22330: do not allow the response body to be read when the status code +// forbids a response body. +func TestNoBodyOnChunked304Response(t *testing.T) { + run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) +} +func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) + buf.Flush() + conn.Close() + })) + + // Our test server above is sending back bogus data after the + // response (the "0\r\n\r\n" part), which causes the Transport + // code to log spam. Disable keep-alives so we never even try + // to reuse the connection. + cst.tr.DisableKeepAlives = true + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + + if res.Body != NoBody { + t.Errorf("Unexpected body on 304 response") + } +} + +type funcWriter func([]byte) (int, error) + +func (f funcWriter) Write(p []byte) (int, error) { return f(p) } + +type doneContext struct { + context.Context + err error +} + +func (doneContext) Done() <-chan struct{} { + c := make(chan struct{}) + close(c) + return c +} + +func (d doneContext) Err() error { return d.err } + +// Issue 25852: Transport should check whether Context is done early. +func TestTransportCheckContextDoneEarly(t *testing.T) { + tr := &Transport{} + req, _ := NewRequest("GET", "http://fake.example/", nil) + wantErr := errors.New("some error") + req = req.WithContext(doneContext{context.Background(), wantErr}) + _, err := tr.RoundTrip(req) + if err != wantErr { + t.Errorf("error = %v; want %v", err, wantErr) + } +} + +// Issue 23399: verify that if a client request times out, the Transport's +// conn is closed so that it's not reused. +// +// This is the test variant that times out before the server replies with +// any response headers. +func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { + run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { + inHandler := make(chan net.Conn, 1) + handlerReadReturned := make(chan bool, 1) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + inHandler <- conn + n, err := conn.Read([]byte{0}) + if n != 0 || err != io.EOF { + t.Errorf("unexpected Read result: %v, %v", n, err) + } + handlerReadReturned <- true + })) + + const timeout = 50 * time.Millisecond + cst.c.Timeout = timeout + + _, err := cst.c.Get(cst.ts.URL) + if err == nil { + t.Fatal("unexpected Get succeess") + } + + select { + case c := <-inHandler: + select { + case <-handlerReadReturned: + // Success. + return + case <-time.After(5 * time.Second): + t.Error("Handler's conn.Read seems to be stuck in Read") + c.Close() // close it to unblock Handler + } + case <-time.After(timeout * 10): + // If we didn't get into the Handler in 50ms, that probably means + // the builder was just slow and the Get failed in that time + // but never made it to the server. That's fine. We'll usually + // test the part above on faster machines. + t.Skip("skipping test on slow builder") + } +} + +// Issue 23399: verify that if a client request times out, the Transport's +// conn is closed so that it's not reused. +// +// This is the test variant that has the server send response headers +// first, and time out during the write of the response body. +func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { + run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { + inHandler := make(chan net.Conn, 1) + handlerResult := make(chan error, 1) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "100") + w.(Flusher).Flush() + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + conn.Write([]byte("foo")) + inHandler <- conn + n, err := conn.Read([]byte{0}) + // The error should be io.EOF or "read tcp + // 127.0.0.1:35827->127.0.0.1:40290: read: connection + // reset by peer" depending on timing. Really we just + // care that it returns at all. But if it returns with + // data, that's weird. + if n != 0 || err == nil { + handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err) + return + } + handlerResult <- nil + })) + + // Set Timeout to something very long but non-zero to exercise + // the codepaths that check for it. But rather than wait for it to fire + // (which would make the test slow), we send on the req.Cancel channel instead, + // which happens to exercise the same code paths. + cst.c.Timeout = time.Minute // just to be non-zero, not to hit it. + req, _ := NewRequest("GET", cst.ts.URL, nil) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := cst.c.Do(req) + if err != nil { + select { + case <-inHandler: + t.Fatalf("Get error: %v", err) + default: + // Failed before entering handler. Ignore result. + t.Skip("skipping test on slow builder") + } + } + + close(cancel) + got, err := io.ReadAll(res.Body) + if err == nil { + t.Fatalf("unexpected success; read %q, nil", got) + } + + select { + case c := <-inHandler: + select { + case err := <-handlerResult: + if err != nil { + t.Errorf("handler: %v", err) + } + return + case <-time.After(5 * time.Second): + t.Error("Handler's conn.Read seems to be stuck in Read") + c.Close() // close it to unblock Handler + } + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} + +func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { + run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) +} +func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { + done := make(chan struct{}) + defer close(done) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n") + bs := bufio.NewScanner(conn) + bs.Scan() + fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) + <-done + })) + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Header.Set("Upgrade", "foo") + req.Header.Set("Connection", "upgrade") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body) + } + defer rwc.Close() + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("expected readable input") + } + if got, want := bs.Text(), "Some buffered data"; got != want { + t.Errorf("read %q; want %q", got, want) + } + io.WriteString(rwc, "echo\n") + if !bs.Scan() { + t.Fatalf("expected another line") + } + if got, want := bs.Text(), "ECHO"; got != want { + t.Errorf("read %q; want %q", got, want) + } +} + +func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } +func testTransportCONNECTBidi(t *testing.T, mode testMode) { + const target = "backend:443" + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("unexpected method %q", r.Method) + w.WriteHeader(500) + return + } + if r.RequestURI != target { + t.Errorf("unexpected CONNECT target %q", r.RequestURI) + w.WriteHeader(500) + return + } + nc, brw, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer nc.Close() + nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) + // Switch to a little protocol that capitalize its input lines: + for { + line, err := brw.ReadString('\n') + if err != nil { + if err != io.EOF { + t.Error(err) + } + return + } + io.WriteString(brw, strings.ToUpper(line)) + brw.Flush() + } + })) + pr, pw := io.Pipe() + defer pw.Close() + req, err := NewRequest("CONNECT", cst.ts.URL, pr) + if err != nil { + t.Fatal(err) + } + req.URL.Opaque = target + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", res.StatusCode) + } + br := bufio.NewReader(res.Body) + for _, str := range []string{"foo", "bar", "baz"} { + fmt.Fprintf(pw, "%s\n", str) + got, err := br.ReadString('\n') + if err != nil { + t.Fatal(err) + } + got = strings.TrimSpace(got) + want := strings.ToUpper(str) + if got != want { + t.Fatalf("got %q; want %q", got, want) + } + } +} + +func TestTransportRequestReplayable(t *testing.T) { + someBody := io.NopCloser(strings.NewReader("")) + tests := []struct { + name string + req *Request + want bool + }{ + { + name: "GET", + req: &Request{Method: "GET"}, + want: true, + }, + { + name: "GET_http.NoBody", + req: &Request{Method: "GET", Body: NoBody}, + want: true, + }, + { + name: "GET_body", + req: &Request{Method: "GET", Body: someBody}, + want: false, + }, + { + name: "POST", + req: &Request{Method: "POST"}, + want: false, + }, + { + name: "POST_idempotency-key", + req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}}, + want: true, + }, + { + name: "POST_x-idempotency-key", + req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}}, + want: true, + }, + { + name: "POST_body", + req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.req.ExportIsReplayable() + if got != tt.want { + t.Errorf("replyable = %v; want %v", got, tt.want) + } + }) + } +} + +// testMockTCPConn is a mock TCP connection used to test that +// ReadFrom is called when sending the request body. +type testMockTCPConn struct { + *net.TCPConn + + ReadFromCalled bool +} + +func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { + c.ReadFromCalled = true + return c.TCPConn.ReadFrom(r) +} + +func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } +func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := os.CreateTemp("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + readerFunc func() (io.Reader, func(), error) + contentLength int64 + expectedReadFrom bool + }{ + { + name: "file, length", + readerFunc: newFileFunc, + contentLength: nBytes, + expectedReadFrom: true, + }, + { + name: "file, no length", + readerFunc: newFileFunc, + }, + { + name: "file, negative length", + readerFunc: newFileFunc, + contentLength: -1, + }, + { + name: "buffer", + contentLength: nBytes, + readerFunc: newBufferFunc, + }, + { + name: "buffer, no length", + readerFunc: newBufferFunc, + }, + { + name: "buffer, length -1", + contentLength: -1, + readerFunc: newBufferFunc, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, cleanup, err := tc.readerFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + tConn := &testMockTCPConn{} + trFunc := func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) + } + + tConn.TCPConn = tcpConn + return tConn, nil + } + } + + cst := newClientServerTest( + t, + mode, + HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(io.Discard, r.Body) + r.Body.Close() + w.WriteHeader(200) + }), + trFunc, + ) + + req, err := NewRequest("PUT", cst.ts.URL, r) + if err != nil { + t.Fatal(err) + } + req.ContentLength = tc.contentLength + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", resp.StatusCode) + } + + expectedReadFrom := tc.expectedReadFrom + if mode != http1Mode { + expectedReadFrom = false + } + if !tConn.ReadFromCalled && expectedReadFrom { + t.Fatalf("did not call ReadFrom") + } + + if tConn.ReadFromCalled && !expectedReadFrom { + t.Fatalf("ReadFrom was unexpectedly invoked") + } + }) + } +} + +func TestTransportClone(t *testing.T) { + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { panic("") }, + OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + return nil + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + Dial: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, + MaxResponseHeaderBytes: 1, + ForceAttemptHTTP2: true, + TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, + }, + ReadBufferSize: 1, + WriteBufferSize: 1, + } + tr2 := tr.Clone() + rv := reflect.ValueOf(tr2).Elem() + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !token.IsExported(sf.Name) { + continue + } + if rv.Field(i).IsZero() { + t.Errorf("cloned field t2.%s is zero", sf.Name) + } + } + + if _, ok := tr2.TLSNextProto["foo"]; !ok { + t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") + } + + // But test that a nil TLSNextProto is kept nil: + tr = new(Transport) + tr2 = tr.Clone() + if tr2.TLSNextProto != nil { + t.Errorf("Transport.TLSNextProto unexpected non-nil") + } +} + +func TestIs408(t *testing.T) { + tests := []struct { + in string + want bool + }{ + {"HTTP/1.0 408", true}, + {"HTTP/1.1 408", true}, + {"HTTP/1.8 408", true}, + {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. + {"HTTP/1.1 408 ", true}, + {"HTTP/1.1 40", false}, + {"http/1.0 408", false}, + {"HTTP/1-1 408", false}, + } + for _, tt := range tests { + if got := Export_is408Message([]byte(tt.in)); got != tt.want { + t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) + } + } +} + +func TestTransportIgnores408(t *testing.T) { + run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel) +} +func testTransportIgnores408(t *testing.T, mode testMode) { + // Not parallel. Relies on mutating the log package's global Output. + defer log.SetOutput(log.Writer()) + + var logout strings.Builder + log.SetOutput(&logout) + + const target = "backend:443" + + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + nc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer nc.Close() + nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) + nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail + })) + req, err := NewRequest("GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("got %q; want ok", slurp) + } + + t0 := time.Now() + for i := 0; i < 50; i++ { + time.Sleep(time.Duration(i) * 5 * time.Millisecond) + if cst.tr.IdleConnKeyCountForTesting() == 0 { + if got := logout.String(); got != "" { + t.Fatalf("expected no log output; got: %s", got) + } + return + } + } + t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) +} + +func TestInvalidHeaderResponse(t *testing.T) { + run(t, testInvalidHeaderResponse, []testMode{http1Mode}) +} +func testInvalidHeaderResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 200 OK\r\n" + + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + + "Content-Type: text/html; charset=utf-8\r\n" + + "Content-Length: 0\r\n" + + "Foo : bar\r\n\r\n")) + buf.Flush() + conn.Close() + })) + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if v := res.Header.Get("Foo"); v != "" { + t.Errorf(`unexpected "Foo" header: %q`, v) + } + if v := res.Header.Get("Foo "); v != "bar" { + t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar") + } +} + +type bodyCloser bool + +func (bc *bodyCloser) Close() error { + *bc = true + return nil +} +func (bc *bodyCloser) Read(b []byte) (n int, err error) { + return 0, io.EOF +} + +// Issue 35015: ensure that Transport closes the body on any error +// with an invalid request, as promised by Client.Do docs. +func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { + run(t, testTransportClosesBodyOnInvalidRequests) +} +func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + t.Errorf("Should not have been invoked") + })).ts + + u, _ := url.Parse(cst.URL) + + tests := []struct { + name string + req *Request + wantErr string + }{ + { + name: "invalid method", + req: &Request{ + Method: " ", + URL: u, + }, + wantErr: `invalid method " "`, + }, + { + name: "nil URL", + req: &Request{ + Method: "GET", + }, + wantErr: `nil Request.URL`, + }, + { + name: "invalid header key", + req: &Request{ + Method: "GET", + Header: Header{"💡": {"emoji"}}, + URL: u, + }, + wantErr: `invalid header field name "💡"`, + }, + { + name: "invalid header value", + req: &Request{ + Method: "POST", + Header: Header{"key": {"\x19"}}, + URL: u, + }, + wantErr: `invalid header field value for "key"`, + }, + { + name: "non HTTP(s) scheme", + req: &Request{ + Method: "POST", + URL: &url.URL{Scheme: "faux"}, + }, + wantErr: `unsupported protocol scheme "faux"`, + }, + { + name: "no Host in URL", + req: &Request{ + Method: "POST", + URL: &url.URL{Scheme: "http"}, + }, + wantErr: `no Host in request URL`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var bc bodyCloser + req := tt.req + req.Body = &bc + _, err := cst.Client().Do(tt.req) + if err == nil { + t.Fatal("Expected an error") + } + if !bc { + t.Fatal("Expected body to have been closed") + } + if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) { + t.Fatalf("Error mismatch: %q does not end with %q", g, w) + } + }) + } +} + +// breakableConn is a net.Conn wrapper with a Write method +// that will fail when its brokenState is true. +type breakableConn struct { + net.Conn + *brokenState +} + +type brokenState struct { + sync.Mutex + broken bool +} + +func (w *breakableConn) Write(b []byte) (n int, err error) { + w.Lock() + defer w.Unlock() + if w.broken { + return 0, errors.New("some write error") + } + return w.Conn.Write(b) +} + +// Issue 34978: don't cache a broken HTTP/2 connection +func TestDontCacheBrokenHTTP2Conn(t *testing.T) { + run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) +} +func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) + + var brokenState brokenState + + const numReqs = 5 + var numDials, gotConns uint32 // atomic + + cst.tr.Dial = func(netw, addr string) (net.Conn, error) { + atomic.AddUint32(&numDials, 1) + c, err := net.Dial(netw, addr) + if err != nil { + t.Errorf("unexpected Dial error: %v", err) + return nil, err + } + return &breakableConn{c, &brokenState}, err + } + + for i := 1; i <= numReqs; i++ { + brokenState.Lock() + brokenState.broken = false + brokenState.Unlock() + + // doBreak controls whether we break the TCP connection after the TLS + // handshake (before the HTTP/2 handshake). We test a few failures + // in a row followed by a final success. + doBreak := i != numReqs + + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime) + atomic.AddUint32(&gotConns, 1) + }, + TLSHandshakeDone: func(cfg tls.ConnectionState, err error) { + brokenState.Lock() + defer brokenState.Unlock() + if doBreak { + brokenState.broken = true + } + }, + }) + req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + _, err = cst.c.Do(req) + if doBreak != (err != nil) { + t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err) + } + } + if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want { + t.Errorf("GotConn calls = %v; want %v", got, want) + } + if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want { + t.Errorf("Dials = %v; want %v", got, want) + } +} + +// Issue 34941 +// When the client has too many concurrent requests on a single connection, +// http.http2noCachedConnError is reported on multiple requests. There should +// only be one decrement regardless of the number of failures. +func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { + run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) +} +func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { + CondSkipHTTP2(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + ts := newClientServerTest(t, mode, h).ts + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + + errCh := make(chan error, 300) + doReq := func() { + resp, err := c.Get(ts.URL) + if err != nil { + errCh <- fmt.Errorf("request failed: %v", err) + return + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + errCh <- fmt.Errorf("read body failed: %v", err) + } + } + + var wg sync.WaitGroup + for i := 0; i < 300; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + close(errCh) + + for err := range errCh { + t.Errorf("error occurred: %v", err) + } +} + +// Issue 36820 +// Test that we use the older backward compatible cancellation protocol +// when a RoundTripper is registered via RegisterProtocol. +func TestAltProtoCancellation(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + c := &Client{ + Transport: tr, + Timeout: time.Millisecond, + } + tr.RegisterProtocol("timeout", timeoutProto{}) + _, err := c.Get("timeout://bar.com/path") + if err == nil { + t.Error("request unexpectedly succeeded") + } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) { + t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr) + } +} + +var timeoutProtoErr = errors.New("canceled as expected") + +type timeoutProto struct{} + +func (timeoutProto) RoundTrip(req *Request) (*Response, error) { + select { + case <-req.Cancel: + return nil, timeoutProtoErr + case <-time.After(5 * time.Second): + return nil, errors.New("request was not canceled") + } +} + +type roundTripFunc func(r *Request) (*Response, error) + +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } + +// Issue 32441: body is not reset after ErrSkipAltProtocol +func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } +func testIssue32441(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { + t.Error("body length is zero") + } + })).ts + c := ts.Client() + c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { + // Draining body to trigger failure condition on actual request to server. + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { + t.Error("body length is zero during round trip") + } + return nil, ErrSkipAltProtocol + })) + if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { + t.Error(err) + } +} + +// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers +// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. +func TestTransportRejectsSignInContentLength(t *testing.T) { + run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) +} +func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "+3") + w.Write([]byte("abc")) + })).ts + + c := cst.Client() + res, err := c.Get(cst.URL) + if err == nil || res != nil { + t.Fatal("Expected a non-nil error and a nil http.Response") + } + if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { + t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) + } +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, errors.New("delegate closed") + } + } + return r.r.Read(p) +} + +func testTransportRace(req *Request) { + save := req.Body + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{pw, dr}, nil + }, + } + defer t.CloseIdleConnections() + + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + defer close(quitReadCh) + + req, err := ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(io.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case quitReadCh <- struct{}{}: + // Ensure delegate is closed so Read doesn't block forever. + close(dr.c) + } + }() + + t.RoundTrip(req) + + // Ensure the reader returns before we reset req.Body to prevent + // a data race on req.Body. + pw.Close() + <-quitReadCh + + req.Body = save +} + +// Issue 37669 +// Test that a cancellation doesn't result in a data race due to the writeLoop +// goroutine being left running, if the caller mutates the processed Request +// upon completion. +func TestErrorWriteLoopRace(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + for i := 0; i < 1000; i++ { + delay := time.Duration(mrand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + testTransportRace(req) + } +} + +// Issue 41600 +// Test that a new request which uses the connection of an active request +// cannot cause it to be canceled as well. +func TestCancelRequestWhenSharingConnection(t *testing.T) { + run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode}) +} +func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) { + reqc := make(chan chan struct{}, 2) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { + ch := make(chan struct{}, 1) + reqc <- ch + <-ch + w.Header().Add("Content-Length", "0") + })).ts + + client := ts.Client() + transport := client.Transport.(*Transport) + transport.MaxIdleConns = 1 + transport.MaxConnsPerHost = 1 + + var wg sync.WaitGroup + + wg.Add(1) + putidlec := make(chan chan struct{}, 1) + reqerrc := make(chan error, 1) + go func() { + defer wg.Done() + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + PutIdleConn: func(error) { + // Signal that the idle conn has been returned to the pool, + // and wait for the order to proceed. + ch := make(chan struct{}) + putidlec <- ch + close(putidlec) // panic if PutIdleConn runs twice for some reason + <-ch + }, + }) + req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) + res, err := client.Do(req) + reqerrc <- err + if err == nil { + res.Body.Close() + } + }() + + // Wait for the first request to receive a response and return the + // connection to the idle pool. + r1c := <-reqc + close(r1c) + var idlec chan struct{} + select { + case err := <-reqerrc: + if err != nil { + t.Fatalf("request 1: got err %v, want nil", err) + } + idlec = <-putidlec + case idlec = <-putidlec: + } + + wg.Add(1) + cancelctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil) + res, err := client.Do(req) + if err == nil { + res.Body.Close() + } + if !errors.Is(err, context.Canceled) { + t.Errorf("request 2: got err %v, want Canceled", err) + } + + // Unblock the first request. + close(idlec) + }() + + // Wait for the second request to arrive at the server, and then cancel + // the request context. + r2c := <-reqc + cancel() + + <-idlec + + close(r2c) + wg.Wait() +} + +func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) } +func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + go io.Copy(io.Discard, req.Body) + panic(ErrAbortHandler) + })).ts + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := ts.Client().Transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + } + }() + } + wg.Wait() +} + +func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) } +func testRequestSanitization(t *testing.T, mode testMode) { + if mode == http2Mode { + // Remove this after updating x/net. + t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { + if h, ok := req.Header["X-Evil"]; ok { + t.Errorf("request has X-Evil header: %q", h) + } + })).ts + req, _ := NewRequest("GET", ts.URL, nil) + req.Host = "go.dev\r\nX-Evil:evil" + resp, _ := ts.Client().Do(req) + if resp != nil { + resp.Body.Close() + } +} |