summaryrefslogtreecommitdiffstats
path: root/src/net/http/transport_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/transport_test.go')
-rw-r--r--src/net/http/transport_test.go6675
1 files changed, 6675 insertions, 0 deletions
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
new file mode 100644
index 0000000..f4896c5
--- /dev/null
+++ b/src/net/http/transport_test.go
@@ -0,0 +1,6675 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests for transport.go.
+//
+// More tests are in clientserver_test.go (for things testing both client & server for both
+// HTTP/1 and HTTP/2). This
+
+package http_test
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "context"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "go/token"
+ "internal/nettrace"
+ "io"
+ "log"
+ mrand "math/rand"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/httputil"
+ "net/http/internal/testcert"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "testing/iotest"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+// and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.FormValue("close") == "true" {
+ w.Header().Set("Connection", "close")
+ }
+ w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
+ w.Write([]byte(r.RemoteAddr))
+
+ // Include the address of the net.Conn in addition to the RemoteAddr,
+ // in case kernels reuse source ports quickly (see Issue 52450)
+ if c, ok := ResponseWriterConnForTesting(w); ok {
+ fmt.Fprintf(w, ", %T %p", c, c)
+ }
+})
+
+// testCloseConn is a net.Conn tracked by a testConnSet.
+type testCloseConn struct {
+ net.Conn
+ set *testConnSet
+}
+
+func (c *testCloseConn) Close() error {
+ c.set.remove(c)
+ return c.Conn.Close()
+}
+
+// testConnSet tracks a set of TCP connections and whether they've
+// been closed.
+type testConnSet struct {
+ t *testing.T
+ mu sync.Mutex // guards closed and list
+ closed map[net.Conn]bool
+ list []net.Conn // in order created
+}
+
+func (tcs *testConnSet) insert(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = false
+ tcs.list = append(tcs.list, c)
+}
+
+func (tcs *testConnSet) remove(c net.Conn) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ tcs.closed[c] = true
+}
+
+// some tests use this to manage raw tcp connections for later inspection
+func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
+ connSet := &testConnSet{
+ t: t,
+ closed: make(map[net.Conn]bool),
+ }
+ dial := func(n, addr string) (net.Conn, error) {
+ c, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testCloseConn{c, connSet}
+ connSet.insert(tc)
+ return tc, nil
+ }
+ return connSet, dial
+}
+
+func (tcs *testConnSet) check(t *testing.T) {
+ tcs.mu.Lock()
+ defer tcs.mu.Unlock()
+ for i := 4; i >= 0; i-- {
+ for i, c := range tcs.list {
+ if tcs.closed[c] {
+ continue
+ }
+ if i != 0 {
+ tcs.mu.Unlock()
+ time.Sleep(50 * time.Millisecond)
+ tcs.mu.Lock()
+ continue
+ }
+ t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
+ }
+ }
+}
+
+func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }
+func testReuseRequest(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("{}"))
+ })).ts
+
+ c := ts.Client()
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ res, err = c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Two subsequent requests and verify their response is the same.
+// The response from the server is our own IP:port
+func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
+func testTransportKeepAlives(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ for _, disableKeepAlive := range []bool{false, true} {
+ c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
+ fetch := func(n int) string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != disableKeepAlive {
+ t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ disableKeepAlive, bodiesDiffer, body1, body2)
+ }
+ }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+ run(t, testTransportConnectionCloseOnResponse)
+}
+func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ connSet, testDial := makeTestDial(t)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = testDial
+
+ for _, connectionClose := range []bool{false, true} {
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+ }
+ defer res.Body.Close()
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+ bodiesDiffer := body1 != body2
+ if bodiesDiffer != connectionClose {
+ t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+ connectionClose, bodiesDiffer, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+// TestTransportConnectionCloseOnRequest tests that the Transport's doesn't reuse
+// an underlying TCP connection after making an http.Request with Request.Close set.
+//
+// It tests the behavior by making an HTTP request to a server which
+// describes the source source connection it got (remote port number +
+// address of its net.Conn).
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+ run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
+}
+func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ connSet, testDial := makeTestDial(t)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = testDial
+ for _, reqClose := range []bool{false, true} {
+ fetch := func(n int) string {
+ req := new(Request)
+ var err error
+ req.URL, err = url.Parse(ts.URL)
+ if err != nil {
+ t.Fatalf("URL parse error: %v", err)
+ }
+ req.Method = "GET"
+ req.Proto = "HTTP/1.1"
+ req.ProtoMajor = 1
+ req.ProtoMinor = 1
+ req.Close = reqClose
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
+ }
+ if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
+ t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
+ reqClose, got, !reqClose)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
+ }
+ return string(body)
+ }
+
+ body1 := fetch(1)
+ body2 := fetch(2)
+
+ got := 1
+ if body1 != body2 {
+ got++
+ }
+ want := 1
+ if reqClose {
+ want = 2
+ }
+ if got != want {
+ t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
+ reqClose, got, want, body1, body2)
+ }
+
+ tr.CloseIdleConnections()
+ }
+
+ connSet.check(t)
+}
+
+// if the Transport's DisableKeepAlives is set, all requests should
+// send Connection: close.
+// HTTP/1-only (Connection: close doesn't exist in h2)
+func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
+ run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
+}
+func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).DisableKeepAlives = true
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.Header.Get("X-Saw-Close") != "true" {
+ t.Errorf("handler didn't see Connection: close ")
+ }
+}
+
+// Test that Transport only sends one "Connection: close", regardless of
+// how "close" was indicated.
+func TestTransportRespectRequestWantsClose(t *testing.T) {
+ run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
+}
+func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
+ tests := []struct {
+ disableKeepAlives bool
+ close bool
+ }{
+ {disableKeepAlives: false, close: false},
+ {disableKeepAlives: false, close: true},
+ {disableKeepAlives: true, close: false},
+ {disableKeepAlives: true, close: true},
+ }
+
+ for _, tc := range tests {
+ t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close),
+ func(t *testing.T) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ count := 0
+ trace := &httptrace.ClientTrace{
+ WroteHeaderField: func(key string, field []string) {
+ if key != "Connection" {
+ return
+ }
+ if httpguts.HeaderValuesContainsToken(field, "close") {
+ count += 1
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ req.Close = tc.close
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if want := tc.disableKeepAlives || tc.close; count > 1 || (count == 1) != want {
+ t.Errorf("expecting want:%v, got 'Connection: close':%d", want, count)
+ }
+ })
+ }
+
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+ run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
+}
+func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ io.ReadAll(resp.Body)
+
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+ t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
+ }
+
+ tr.CloseIdleConnections()
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+ }
+}
+
+// Tests that the HTTP transport re-uses connections when a client
+// reads to the end of a response Body without closing it.
+func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
+func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
+ const msg = "foobar"
+
+ var addrSeen map[string]int
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ addrSeen[r.RemoteAddr]++
+ if r.URL.Path == "/chunked/" {
+ w.WriteHeader(200)
+ w.(Flusher).Flush()
+ } else {
+ w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
+ w.WriteHeader(200)
+ }
+ w.Write([]byte(msg))
+ })).ts
+
+ for pi, path := range []string{"/content-length/", "/chunked/"} {
+ wantLen := []int{len(msg), -1}[pi]
+ addrSeen = make(map[string]int)
+ for i := 0; i < 3; i++ {
+ res, err := ts.Client().Get(ts.URL + path)
+ if err != nil {
+ t.Errorf("Get %s: %v", path, err)
+ continue
+ }
+ // We want to close this body eventually (before the
+ // defer afterTest at top runs), but not before the
+ // len(addrSeen) check at the bottom of this test,
+ // since Closing this early in the loop would risk
+ // making connections be re-used for the wrong reason.
+ defer res.Body.Close()
+
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen)
+ }
+ got, err := io.ReadAll(res.Body)
+ if string(got) != msg || err != nil {
+ t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg)
+ }
+ }
+ if len(addrSeen) != 1 {
+ t.Errorf("for %s, server saw %d distinct client addresses; want 1", path, len(addrSeen))
+ }
+ }
+}
+
+func TestTransportMaxPerHostIdleConns(t *testing.T) {
+ run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
+}
+func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
+ stop := make(chan struct{}) // stop marks the exit of main Test goroutine
+ defer close(stop)
+
+ resch := make(chan string)
+ gotReq := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReq <- true
+ var msg string
+ select {
+ case <-stop:
+ return
+ case msg = <-resch:
+ }
+ _, err := w.Write([]byte(msg))
+ if err != nil {
+ t.Errorf("Write: %v", err)
+ return
+ }
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ maxIdleConnsPerHost := 2
+ tr.MaxIdleConnsPerHost = maxIdleConnsPerHost
+
+ // Start 3 outstanding requests and wait for the server to get them.
+ // Their responses will hang until we write to resch, though.
+ donech := make(chan bool)
+ doReq := func() {
+ defer func() {
+ select {
+ case <-stop:
+ return
+ case donech <- t.Failed():
+ }
+ }()
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := io.ReadAll(resp.Body); err != nil {
+ t.Errorf("ReadAll: %v", err)
+ return
+ }
+ }
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+ go doReq()
+ <-gotReq
+
+ if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+ t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
+ }
+
+ resch <- "res1"
+ <-donech
+ keys := tr.IdleConnKeysForTesting()
+ if e, g := 1, len(keys); e != g {
+ t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
+ }
+ addr := ts.Listener.Addr().String()
+ cacheKey := "|http|" + addr
+ if keys[0] != cacheKey {
+ t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
+ }
+ if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
+ t.Errorf("after first response, expected %d idle conns; got %d", e, g)
+ }
+
+ resch <- "res2"
+ <-donech
+ if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
+ t.Errorf("after second response, idle conns = %d; want %d", g, w)
+ }
+
+ resch <- "res3"
+ <-donech
+ if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
+ t.Errorf("after third response, idle conns = %d; want %d", g, w)
+ }
+}
+
+func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
+ run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
+}
+func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ dialStarted := make(chan struct{})
+ stallDial := make(chan struct{})
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ dialStarted <- struct{}{}
+ <-stallDial
+ return net.Dial(network, addr)
+ }
+
+ tr.DisableKeepAlives = true
+ tr.MaxConnsPerHost = 1
+
+ preDial := make(chan struct{})
+ reqComplete := make(chan struct{})
+ doReq := func(reqId string) {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ preDial <- struct{}{}
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ reqComplete <- struct{}{}
+ }
+ // get req1 to dial-in-progress
+ go doReq("req1")
+ <-preDial
+ <-dialStarted
+
+ // get req2 to waiting on conns per host to go down below max
+ go doReq("req2")
+ <-preDial
+ select {
+ case <-dialStarted:
+ t.Error("req2 dial started while req1 dial in progress")
+ return
+ default:
+ }
+
+ // let req1 complete
+ stallDial <- struct{}{}
+ <-reqComplete
+
+ // let req2 complete
+ <-dialStarted
+ stallDial <- struct{}{}
+ <-reqComplete
+}
+
+func TestTransportMaxConnsPerHost(t *testing.T) {
+ run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
+ CondSkipHTTP2(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ ts := newClientServerTest(t, mode, h).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+
+ mu := sync.Mutex{}
+ var conns []net.Conn
+ var dialCnt, gotConnCnt, tlsHandshakeCnt int32
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ atomic.AddInt32(&dialCnt, 1)
+ c, err := net.Dial(network, addr)
+ mu.Lock()
+ defer mu.Unlock()
+ conns = append(conns, c)
+ return c, err
+ }
+
+ doReq := func() {
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ TLSHandshakeStart: func() {
+ atomic.AddInt32(&tlsHandshakeCnt, 1)
+ },
+ }
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("read body failed: %v", err)
+ }
+ }
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+
+ expected := int32(tr.MaxConnsPerHost)
+ if dialCnt != expected {
+ t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ mu.Lock()
+ for _, c := range conns {
+ c.Close()
+ }
+ conns = nil
+ mu.Unlock()
+ tr.CloseIdleConnections()
+
+ doReq()
+ expected++
+ if dialCnt != expected {
+ t.Errorf("round 2: too many dials: %d", dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected)
+ }
+}
+
+func TestTransportRemovesDeadIdleConnections(t *testing.T) {
+ run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
+}
+func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, r.RemoteAddr)
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ doReq := func(name string) string {
+ // Do a POST instead of a GET to prevent the Transport's
+ // idempotent request retry logic from kicking in...
+ res, err := c.Post(ts.URL, "", nil)
+ if err != nil {
+ t.Fatalf("%s: %v", name, err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatalf("%s: %v", name, res.Status)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("%s: %v", name, err)
+ }
+ return string(slurp)
+ }
+
+ first := doReq("first")
+ keys1 := tr.IdleConnKeysForTesting()
+
+ ts.CloseClientConnections()
+
+ var keys2 []string
+ if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool {
+ keys2 = tr.IdleConnKeysForTesting()
+ return len(keys2) == 0
+ }) {
+ t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2)
+ }
+
+ second := doReq("second")
+ if first == second {
+ t.Errorf("expected a different connection between requests. got %q both times", first)
+ }
+}
+
+// Test that the Transport notices when a server hangs up on its
+// unexpectedly (a keep-alive connection is closed).
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
+}
+func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, hostPortHandler).ts
+ c := ts.Client()
+
+ fetch := func(n, retries int) string {
+ condFatalf := func(format string, arg ...any) {
+ if retries <= 0 {
+ t.Fatalf(format, arg...)
+ }
+ t.Logf("retrying shortly after expected error: "+format, arg...)
+ time.Sleep(time.Second / time.Duration(retries))
+ }
+ for retries >= 0 {
+ retries--
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ condFatalf("error in req #%d, GET: %v", n, err)
+ continue
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ condFatalf("error in req #%d, ReadAll: %v", n, err)
+ continue
+ }
+ res.Body.Close()
+ return string(body)
+ }
+ panic("unreachable")
+ }
+
+ body1 := fetch(1, 0)
+ body2 := fetch(2, 0)
+
+ // Close all the idle connections in a way that's similar to
+ // the server hanging up on us. We don't use
+ // httptest.Server.CloseClientConnections because it's
+ // best-effort and stops blocking after 5 seconds. On a loaded
+ // machine running many tests concurrently it's possible for
+ // that method to be async and cause the body3 fetch below to
+ // run on an old connection. This function is synchronous.
+ ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
+
+ body3 := fetch(3, 5)
+
+ if body1 != body2 {
+ t.Errorf("expected body1 and body2 to be equal")
+ }
+ if body2 == body3 {
+ t.Errorf("expected body2 and body3 to be different")
+ }
+}
+
+// Test for https://golang.org/issue/2616 (appropriate issue number)
+// This fails pretty reliably with GOMAXPROCS=100 or something high.
+func TestStressSurpriseServerCloses(t *testing.T) {
+ run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
+}
+func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in short mode")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "5")
+ w.Header().Set("Content-Type", "text/plain")
+ w.Write([]byte("Hello"))
+ w.(Flusher).Flush()
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Flush()
+ conn.Close()
+ })).ts
+ c := ts.Client()
+
+ // Do a bunch of traffic from different goroutines. Send to activityc
+ // after each request completes, regardless of whether it failed.
+ // If these are too high, OS X exhausts its ephemeral ports
+ // and hangs waiting for them to transition TCP states. That's
+ // not what we want to test. TODO(bradfitz): use an io.Pipe
+ // dialer for this test instead?
+ const (
+ numClients = 20
+ reqsPerClient = 25
+ )
+ activityc := make(chan bool)
+ for i := 0; i < numClients; i++ {
+ go func() {
+ for i := 0; i < reqsPerClient; i++ {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ // We expect errors since the server is
+ // hanging up on us after telling us to
+ // send more requests, so we don't
+ // actually care what the error is.
+ // But we want to close the body in cases
+ // where we won the race.
+ res.Body.Close()
+ }
+ if !<-activityc { // Receives false when close(activityc) is executed
+ return
+ }
+ }
+ }()
+ }
+
+ // Make sure all the request come back, one way or another.
+ for i := 0; i < numClients*reqsPerClient; i++ {
+ select {
+ case activityc <- true:
+ case <-time.After(5 * time.Second):
+ close(activityc)
+ t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile")
+ }
+ }
+}
+
+// TestTransportHeadResponses verifies that we deal with Content-Lengths
+// with no bodies properly
+func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
+func testTransportHeadResponses(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Content-Length", "123")
+ w.WriteHeader(200)
+ })).ts
+ c := ts.Client()
+
+ for i := 0; i < 2; i++ {
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Errorf("error on loop %d: %v", i, err)
+ continue
+ }
+ if e, g := "123", res.Header.Get("Content-Length"); e != g {
+ t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
+ }
+ if e, g := int64(123), res.ContentLength; e != g {
+ t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
+ }
+ if all, err := io.ReadAll(res.Body); err != nil {
+ t.Errorf("loop %d: Body ReadAll: %v", i, err)
+ } else if len(all) != 0 {
+ t.Errorf("Bogus body %q", all)
+ }
+ }
+}
+
+// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
+// on responses to HEAD requests.
+func TestTransportHeadChunkedResponse(t *testing.T) {
+ run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
+}
+func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ panic("expected HEAD; got " + r.Method)
+ }
+ w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
+ w.Header().Set("x-client-ipport", r.RemoteAddr)
+ w.WriteHeader(200)
+ })).ts
+ c := ts.Client()
+
+ // Ensure that we wait for the readLoop to complete before
+ // calling Head again
+ didRead := make(chan bool)
+ SetReadLoopBeforeNextReadHook(func() { didRead <- true })
+ defer SetReadLoopBeforeNextReadHook(nil)
+
+ res1, err := c.Head(ts.URL)
+ <-didRead
+
+ if err != nil {
+ t.Fatalf("request 1 error: %v", err)
+ }
+
+ res2, err := c.Head(ts.URL)
+ <-didRead
+
+ if err != nil {
+ t.Fatalf("request 2 error: %v", err)
+ }
+ if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
+ t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
+ }
+}
+
+var roundTripTests = []struct {
+ accept string
+ expectAccept string
+ compressed bool
+}{
+ // Requests with no accept-encoding header use transparent compression
+ {"", "gzip", false},
+ // Requests with other accept-encoding should pass through unmodified
+ {"foo", "foo", false},
+ // Requests with accept-encoding == gzip should be passed through
+ {"gzip", "gzip", true},
+}
+
+// Test that the modification made to the Request by the RoundTripper is cleaned up
+func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
+func testRoundTripGzip(t *testing.T, mode testMode) {
+ const responseBody = "test response body"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ accept := req.Header.Get("Accept-Encoding")
+ if expect := req.FormValue("expect_accept"); accept != expect {
+ t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+ req.FormValue("testnum"), accept, expect)
+ }
+ if accept == "gzip" {
+ rw.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(rw)
+ gz.Write([]byte(responseBody))
+ gz.Close()
+ } else {
+ rw.Header().Set("Content-Encoding", accept)
+ rw.Write([]byte(responseBody))
+ }
+ })).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ for i, test := range roundTripTests {
+ // Test basic request (no accept-encoding)
+ req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+ if test.accept != "" {
+ req.Header.Set("Accept-Encoding", test.accept)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip: %v", i, err)
+ continue
+ }
+ var body []byte
+ if test.compressed {
+ var r *gzip.Reader
+ r, err = gzip.NewReader(res.Body)
+ if err != nil {
+ t.Errorf("%d. gzip NewReader: %v", i, err)
+ continue
+ }
+ body, err = io.ReadAll(r)
+ res.Body.Close()
+ } else {
+ body, err = io.ReadAll(res.Body)
+ }
+ if err != nil {
+ t.Errorf("%d. Error: %q", i, err)
+ continue
+ }
+ if g, e := string(body), responseBody; g != e {
+ t.Errorf("%d. body = %q; want %q", i, g, e)
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
+ }
+ }
+
+}
+
+func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
+func testTransportGzip(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56020")
+ }
+ const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ const nRandBytes = 1024 * 1024
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if req.Method == "HEAD" {
+ if g := req.Header.Get("Accept-Encoding"); g != "" {
+ t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
+ }
+ return
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
+ t.Errorf("Accept-Encoding = %q, want %q", g, e)
+ }
+ rw.Header().Set("Content-Encoding", "gzip")
+
+ var w io.Writer = rw
+ var buf bytes.Buffer
+ if req.FormValue("chunked") == "0" {
+ w = &buf
+ defer io.Copy(rw, &buf)
+ defer func() {
+ rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
+ }()
+ }
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte(testString))
+ if req.FormValue("body") == "large" {
+ io.CopyN(gz, rand.Reader, nRandBytes)
+ }
+ gz.Close()
+ })).ts
+ c := ts.Client()
+
+ for _, chunked := range []string{"1", "0"} {
+ // First fetch something large, but only read some of it.
+ res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
+ if err != nil {
+ t.Fatalf("large get: %v", err)
+ }
+ buf := make([]byte, len(testString))
+ n, err := io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatalf("partial read of large response: size=%d, %v", n, err)
+ }
+ if e, g := testString, string(buf); e != g {
+ t.Errorf("partial read got %q, expected %q", g, e)
+ }
+ res.Body.Close()
+ // Read on the body, even though it's closed
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
+ }
+
+ // Then something small.
+ res, err = c.Get(ts.URL + "/?chunked=" + chunked)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, e := string(body), testString; g != e {
+ t.Fatalf("body = %q; want %q", g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+
+ // Read on the body after it's been fully read:
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
+ }
+ res.Body.Close()
+ n, err = res.Body.Read(buf)
+ if n != 0 || err == nil {
+ t.Errorf("expected Read error after Close; got %d, %v", n, err)
+ }
+ }
+
+ // And a HEAD request too, because they're always weird.
+ res, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatalf("Head: %v", err)
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("Head status=%d; want=200", res.StatusCode)
+ }
+}
+
+// If a request has Expect:100-continue header, the request blocks sending body until the first response.
+// Premature consumption of the request body should not be occurred.
+func TestTransportExpect100Continue(t *testing.T) {
+ run(t, testTransportExpect100Continue, []testMode{http1Mode})
+}
+func testTransportExpect100Continue(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ switch req.URL.Path {
+ case "/100":
+ // This endpoint implicitly responds 100 Continue and reads body.
+ if _, err := io.Copy(io.Discard, req.Body); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ rw.WriteHeader(StatusOK)
+ case "/200":
+ // Go 1.5 adds Connection: close header if the client expect
+ // continue but not entire request body is consumed.
+ rw.WriteHeader(StatusOK)
+ case "/500":
+ rw.WriteHeader(StatusInternalServerError)
+ case "/keepalive":
+ // This hijacked endpoint responds error without Connection:close.
+ _, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
+ bufrw.WriteString("Content-Length: 0\r\n\r\n")
+ bufrw.Flush()
+ case "/timeout":
+ // This endpoint tries to read body without 100 (Continue) response.
+ // After ExpectContinueTimeout, the reading will be started.
+ conn, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
+ bufrw.Flush()
+ conn.Close()
+ }
+
+ })).ts
+
+ tests := []struct {
+ path string
+ body []byte
+ sent int
+ status int
+ }{
+ {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent.
+ {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent.
+ {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent.
+ {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent.
+ {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
+ }
+
+ c := ts.Client()
+ for i, v := range tests {
+ tr := &Transport{
+ ExpectContinueTimeout: 2 * time.Second,
+ }
+ defer tr.CloseIdleConnections()
+ c.Transport = tr
+ body := bytes.NewReader(v.body)
+ req, err := NewRequest("PUT", ts.URL+v.path, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Expect", "100-continue")
+ req.ContentLength = int64(len(v.body))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+
+ sent := len(v.body) - body.Len()
+ if v.status != resp.StatusCode {
+ t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
+ }
+ if v.sent != sent {
+ t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
+ }
+ }
+}
+
+func TestSOCKS5Proxy(t *testing.T) {
+ run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
+}
+func testSOCKS5Proxy(t *testing.T, mode testMode) {
+ ch := make(chan string, 1)
+ l := newLocalListener(t)
+ defer l.Close()
+ defer close(ch)
+ proxy := func(t *testing.T) {
+ s, err := l.Accept()
+ if err != nil {
+ t.Errorf("socks5 proxy Accept(): %v", err)
+ return
+ }
+ defer s.Close()
+ var buf [22]byte
+ if _, err := io.ReadFull(s, buf[:3]); err != nil {
+ t.Errorf("socks5 proxy initial read: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
+ t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
+ return
+ }
+ if _, err := s.Write([]byte{5, 0}); err != nil {
+ t.Errorf("socks5 proxy initial write: %v", err)
+ return
+ }
+ if _, err := io.ReadFull(s, buf[:4]); err != nil {
+ t.Errorf("socks5 proxy second read: %v", err)
+ return
+ }
+ if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
+ t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
+ return
+ }
+ var ipLen int
+ switch buf[3] {
+ case 1:
+ ipLen = net.IPv4len
+ case 4:
+ ipLen = net.IPv6len
+ default:
+ t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
+ return
+ }
+ if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
+ t.Errorf("socks5 proxy address read: %v", err)
+ return
+ }
+ ip := net.IP(buf[4 : ipLen+4])
+ port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
+ copy(buf[:3], []byte{5, 0, 0})
+ if _, err := s.Write(buf[:ipLen+6]); err != nil {
+ t.Errorf("socks5 proxy connect write: %v", err)
+ return
+ }
+ ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
+
+ // Implement proxying.
+ targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
+ targetConn, err := net.Dial("tcp", targetHost)
+ if err != nil {
+ t.Errorf("net.Dial failed")
+ return
+ }
+ go io.Copy(targetConn, s)
+ io.Copy(s, targetConn) // Wait for the client to close the socket.
+ targetConn.Close()
+ }
+
+ pu, err := url.Parse("socks5://" + l.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ sentinelHeader := "X-Sentinel"
+ sentinelValue := "12345"
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set(sentinelHeader, sentinelValue)
+ })
+ for _, useTLS := range []bool{false, true} {
+ t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
+ ts := newClientServerTest(t, mode, h).ts
+ go proxy(t)
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ r, err := c.Head(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if r.Header.Get(sentinelHeader) != sentinelValue {
+ t.Errorf("Failed to retrieve sentinel value")
+ }
+ var got string
+ select {
+ case got = <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout connecting to socks5 proxy")
+ }
+ ts.Close()
+ tsu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "proxy for " + tsu.Host
+ if got != want {
+ t.Errorf("got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
+func TestTransportProxy(t *testing.T) {
+ defer afterTest(t)
+ testCases := []struct{ siteMode, proxyMode testMode }{
+ {http1Mode, http1Mode},
+ {http1Mode, https1Mode},
+ {https1Mode, http1Mode},
+ {https1Mode, https1Mode},
+ }
+ for _, testCase := range testCases {
+ siteMode := testCase.siteMode
+ proxyMode := testCase.proxyMode
+ t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
+ siteCh := make(chan *Request, 1)
+ h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ siteCh <- r
+ })
+ proxyCh := make(chan *Request, 1)
+ h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ proxyCh <- r
+ // Implement an entire CONNECT proxy
+ if r.Method == "CONNECT" {
+ hijacker, ok := w.(Hijacker)
+ if !ok {
+ t.Errorf("hijack not allowed")
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ t.Errorf("hijacking failed")
+ return
+ }
+ res := &Response{
+ StatusCode: StatusOK,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ }
+
+ targetConn, err := net.Dial("tcp", r.URL.Host)
+ if err != nil {
+ t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
+ return
+ }
+
+ if err := res.Write(clientConn); err != nil {
+ t.Errorf("Writing 200 OK failed: %v", err)
+ return
+ }
+
+ go io.Copy(targetConn, clientConn)
+ go func() {
+ io.Copy(clientConn, targetConn)
+ targetConn.Close()
+ }()
+ }
+ })
+ ts := newClientServerTest(t, siteMode, h1).ts
+ proxy := newClientServerTest(t, proxyMode, h2).ts
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If neither server is HTTPS or both are, then c may be derived from either.
+ // If only one server is HTTPS, c must be derived from that server in order
+ // to ensure that it is configured to use the fake root CA from testcert.go.
+ c := proxy.Client()
+ if siteMode == https1Mode {
+ c = ts.Client()
+ }
+
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ if _, err := c.Head(ts.URL); err != nil {
+ t.Error(err)
+ }
+ var got *Request
+ select {
+ case got = <-proxyCh:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout connecting to http proxy")
+ }
+ c.Transport.(*Transport).CloseIdleConnections()
+ ts.Close()
+ proxy.Close()
+ if siteMode == https1Mode {
+ // First message should be a CONNECT, asking for a socket to the real server,
+ if got.Method != "CONNECT" {
+ t.Errorf("Wrong method for secure proxying: %q", got.Method)
+ }
+ gotHost := got.URL.Host
+ pu, err := url.Parse(ts.URL)
+ if err != nil {
+ t.Fatal("Invalid site URL")
+ }
+ if wantHost := pu.Host; gotHost != wantHost {
+ t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
+ }
+
+ // The next message on the channel should be from the site's server.
+ next := <-siteCh
+ if next.Method != "HEAD" {
+ t.Errorf("Wrong method at destination: %s", next.Method)
+ }
+ if nextURL := next.URL.String(); nextURL != "/" {
+ t.Errorf("Wrong URL at destination: %s", nextURL)
+ }
+ } else {
+ if got.Method != "HEAD" {
+ t.Errorf("Wrong method for destination: %q", got.Method)
+ }
+ gotURL := got.URL.String()
+ wantURL := ts.URL + "/"
+ if gotURL != wantURL {
+ t.Errorf("Got URL %q, want %q", gotURL, wantURL)
+ }
+ }
+ })
+ }
+}
+
+func TestOnProxyConnectResponse(t *testing.T) {
+
+ var tcases = []struct {
+ proxyStatusCode int
+ err error
+ }{
+ {
+ StatusOK,
+ nil,
+ },
+ {
+ StatusForbidden,
+ errors.New("403"),
+ },
+ }
+ for _, tcase := range tcases {
+ h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
+
+ })
+
+ h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Implement an entire CONNECT proxy
+ if r.Method == "CONNECT" {
+ if tcase.proxyStatusCode != StatusOK {
+ w.WriteHeader(tcase.proxyStatusCode)
+ return
+ }
+ hijacker, ok := w.(Hijacker)
+ if !ok {
+ t.Errorf("hijack not allowed")
+ return
+ }
+ clientConn, _, err := hijacker.Hijack()
+ if err != nil {
+ t.Errorf("hijacking failed")
+ return
+ }
+ res := &Response{
+ StatusCode: StatusOK,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Header: make(Header),
+ }
+
+ targetConn, err := net.Dial("tcp", r.URL.Host)
+ if err != nil {
+ t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
+ return
+ }
+
+ if err := res.Write(clientConn); err != nil {
+ t.Errorf("Writing 200 OK failed: %v", err)
+ return
+ }
+
+ go io.Copy(targetConn, clientConn)
+ go func() {
+ io.Copy(clientConn, targetConn)
+ targetConn.Close()
+ }()
+ }
+ })
+ ts := newClientServerTest(t, https1Mode, h1).ts
+ proxy := newClientServerTest(t, https1Mode, h2).ts
+
+ pu, err := url.Parse(proxy.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c := proxy.Client()
+
+ c.Transport.(*Transport).Proxy = ProxyURL(pu)
+ c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+ if proxyURL.String() != pu.String() {
+ t.Errorf("proxy url got %s, want %s", proxyURL, pu)
+ }
+
+ if "https://"+connectReq.URL.String() != ts.URL {
+ t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
+ }
+ return tcase.err
+ }
+ if _, err := c.Head(ts.URL); err != nil {
+ if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
+ t.Errorf("got %v, want %v", err, tcase.err)
+ }
+ }
+ }
+}
+
+// Issue 28012: verify that the Transport closes its TCP connection to http proxies
+// when they're slow to reply to HTTPS CONNECT responses.
+func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+ listenerDone := make(chan struct{})
+ go func() {
+ defer close(listenerDone)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Errorf("Accept: %v", err)
+ return
+ }
+ defer c.Close()
+ // Read the CONNECT request
+ br := bufio.NewReader(c)
+ cr, err := ReadRequest(br)
+ if err != nil {
+ t.Errorf("proxy server failed to read CONNECT request")
+ return
+ }
+ if cr.Method != "CONNECT" {
+ t.Errorf("unexpected method %q", cr.Method)
+ return
+ }
+
+ // Now hang and never write a response; instead, cancel the request and wait
+ // for the client to close.
+ // (Prior to Issue 28012 being fixed, we never closed.)
+ cancel()
+ var buf [1]byte
+ _, err = br.Read(buf[:])
+ if err != io.EOF {
+ t.Errorf("proxy server Read err = %v; want EOF", err)
+ }
+ return
+ }()
+
+ c := &Client{
+ Transport: &Transport{
+ Proxy: func(*Request) (*url.URL, error) {
+ return url.Parse("http://" + ln.Addr().String())
+ },
+ },
+ }
+ req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Errorf("unexpected Get success")
+ }
+
+ // Wait unconditionally for the listener goroutine to exit: this should never
+ // hang, so if it does we want a full goroutine dump — and that's exactly what
+ // the testing package will give us when the test run times out.
+ <-listenerDone
+}
+
+// Issue 16997: test transport dial preserves typed errors
+func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
+ defer afterTest(t)
+
+ var errDial = errors.New("some dial error")
+
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) {
+ return url.Parse("http://proxy.fake.tld/")
+ },
+ Dial: func(string, string) (net.Conn, error) {
+ return nil, errDial
+ },
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+ req, _ := NewRequest("GET", "http://fake.tld", nil)
+ res, err := c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("wanted a non-nil error")
+ }
+
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %T, want *url.Error", err)
+ }
+ oe, ok := uerr.Err.(*net.OpError)
+ if !ok {
+ t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
+ }
+ want := &net.OpError{
+ Op: "proxyconnect",
+ Net: "tcp",
+ Err: errDial, // original error, unwrapped.
+ }
+ if !reflect.DeepEqual(oe, want) {
+ t.Errorf("Got error %#v; want %#v", oe, want)
+ }
+}
+
+// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
+//
+// (A bug caused dialConn to instead write the per-request Proxy-Authorization
+// header through to the shared Header instance, introducing a data race.)
+func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
+ run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
+}
+func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
+ proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
+ defer proxy.Close()
+ c := proxy.Client()
+
+ tr := c.Transport.(*Transport)
+ tr.Proxy = func(*Request) (*url.URL, error) {
+ u, _ := url.Parse(proxy.URL)
+ u.User = url.UserPassword("aladdin", "opensesame")
+ return u, nil
+ }
+ h := tr.ProxyConnectHeader
+ if h == nil {
+ h = make(Header)
+ }
+ tr.ProxyConnectHeader = h.Clone()
+
+ req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Errorf("unexpected Get success")
+ }
+
+ if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
+ t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
+ }
+}
+
+// TestTransportGzipRecursive sends a gzip quine and checks that the
+// client gets the same value back. This is more cute than anything,
+// but checks that we don't recurse forever, and checks that
+// Content-Encoding is removed.
+func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
+func testTransportGzipRecursive(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write(rgz)
+ })).ts
+
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, rgz) {
+ t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
+ body, rgz)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
+ t.Fatalf("Content-Encoding = %q; want %q", g, e)
+ }
+}
+
+// golang.org/issue/7750: request fails when server replies with
+// a short gzip body
+func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
+func testTransportGzipShort(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", "gzip")
+ w.Write([]byte{0x1f, 0x8b})
+ })).ts
+
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ _, err = io.ReadAll(res.Body)
+ if err == nil {
+ t.Fatal("Expect an error from reading a body.")
+ }
+ if err != io.ErrUnexpectedEOF {
+ t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
+ }
+}
+
+// Wait until number of goroutines is no greater than nmax, or time out.
+func waitNumGoroutine(nmax int) int {
+ nfinal := runtime.NumGoroutine()
+ for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- {
+ time.Sleep(50 * time.Millisecond)
+ runtime.GC()
+ nfinal = runtime.NumGoroutine()
+ }
+ return nfinal
+}
+
+// tests that persistent goroutine connections shut down when no longer desired.
+func TestTransportPersistConnLeak(t *testing.T) {
+ run(t, testTransportPersistConnLeak, testNotParallel)
+}
+func testTransportPersistConnLeak(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("flaky in HTTP/2")
+ }
+ // Not parallel: counts goroutines
+
+ const numReq = 25
+ gotReqCh := make(chan bool, numReq)
+ unblockCh := make(chan bool, numReq)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotReqCh <- true
+ <-unblockCh
+ w.Header().Set("Content-Length", "0")
+ w.WriteHeader(204)
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ n0 := runtime.NumGoroutine()
+
+ didReqCh := make(chan bool, numReq)
+ failed := make(chan bool, numReq)
+ for i := 0; i < numReq; i++ {
+ go func() {
+ res, err := c.Get(ts.URL)
+ didReqCh <- true
+ if err != nil {
+ t.Logf("client fetch error: %v", err)
+ failed <- true
+ return
+ }
+ res.Body.Close()
+ }()
+ }
+
+ // Wait for all goroutines to be stuck in the Handler.
+ for i := 0; i < numReq; i++ {
+ select {
+ case <-gotReqCh:
+ // ok
+ case <-failed:
+ // Not great but not what we are testing:
+ // sometimes an overloaded system will fail to make all the connections.
+ }
+ }
+
+ nhigh := runtime.NumGoroutine()
+
+ // Tell all handlers to unblock and reply.
+ close(unblockCh)
+
+ // Wait for all HTTP clients to be done.
+ for i := 0; i < numReq; i++ {
+ <-didReqCh
+ }
+
+ tr.CloseIdleConnections()
+ nfinal := waitNumGoroutine(n0 + 5)
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ if int(growth) > 5 {
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ t.Error("too many new goroutines")
+ }
+}
+
+// golang.org/issue/4531: Transport leaks goroutines when
+// request.ContentLength is explicitly short
+func TestTransportPersistConnLeakShortBody(t *testing.T) {
+ run(t, testTransportPersistConnLeakShortBody, testNotParallel)
+}
+func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("flaky in HTTP/2")
+ }
+
+ // Not parallel: measures goroutines.
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ n0 := runtime.NumGoroutine()
+ body := []byte("Hello")
+ for i := 0; i < 20; i++ {
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = int64(len(body) - 2) // explicitly short
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("Expect an error from writing too long of a body.")
+ }
+ }
+ nhigh := runtime.NumGoroutine()
+ tr.CloseIdleConnections()
+ nfinal := waitNumGoroutine(n0 + 5)
+
+ growth := nfinal - n0
+
+ // We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
+ // Previously we were leaking one per numReq.
+ t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+ if int(growth) > 5 {
+ t.Error("too many new goroutines")
+ }
+}
+
+// A countedConn is a net.Conn that decrements an atomic counter when finalized.
+type countedConn struct {
+ net.Conn
+}
+
+// A countingDialer dials connections and counts the number that remain reachable.
+type countingDialer struct {
+ dialer net.Dialer
+ mu sync.Mutex
+ total, live int64
+}
+
+func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ conn, err := d.dialer.DialContext(ctx, network, address)
+ if err != nil {
+ return nil, err
+ }
+
+ counted := new(countedConn)
+ counted.Conn = conn
+
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.total++
+ d.live++
+
+ runtime.SetFinalizer(counted, d.decrement)
+ return counted, nil
+}
+
+func (d *countingDialer) decrement(*countedConn) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.live--
+}
+
+func (d *countingDialer) Read() (total, live int64) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.total, d.live
+}
+
+func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
+ run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
+}
+func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Close every connection so that it cannot be kept alive.
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack failed unexpectedly: %v", err)
+ return
+ }
+ conn.Close()
+ })).ts
+
+ var d countingDialer
+ c := ts.Client()
+ c.Transport.(*Transport).DialContext = d.DialContext
+
+ body := []byte("Hello")
+ for i := 0; ; i++ {
+ total, live := d.Read()
+ if live < total {
+ break
+ }
+ if i >= 1<<12 {
+ t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
+ }
+
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = c.Do(req)
+ if err == nil {
+ t.Fatal("expected broken connection")
+ }
+
+ runtime.GC()
+ }
+}
+
+type countedContext struct {
+ context.Context
+}
+
+type contextCounter struct {
+ mu sync.Mutex
+ live int64
+}
+
+func (cc *contextCounter) Track(ctx context.Context) context.Context {
+ counted := new(countedContext)
+ counted.Context = ctx
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.live++
+ runtime.SetFinalizer(counted, cc.decrement)
+ return counted
+}
+
+func (cc *contextCounter) decrement(*countedContext) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.live--
+}
+
+func (cc *contextCounter) Read() (live int64) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.live
+}
+
+func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
+ run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
+}
+func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("https://go.dev/issue/56021")
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ runtime.Gosched()
+ w.WriteHeader(StatusOK)
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).MaxConnsPerHost = 1
+
+ ctx := context.Background()
+ body := []byte("Hello")
+ doPosts := func(cc *contextCounter) {
+ var wg sync.WaitGroup
+ for n := 64; n > 0; n-- {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+
+ ctx := cc.Track(ctx)
+ req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+ if err != nil {
+ t.Error(err)
+ }
+
+ _, err = c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Errorf("Do failed with error: %v", err)
+ }
+ }()
+ }
+ wg.Wait()
+ }
+
+ var initialCC contextCounter
+ doPosts(&initialCC)
+
+ // flushCC exists only to put pressure on the GC to finalize the initialCC
+ // contexts: the flushCC allocations should eventually displace the initialCC
+ // allocations.
+ var flushCC contextCounter
+ for i := 0; ; i++ {
+ live := initialCC.Read()
+ if live == 0 {
+ break
+ }
+ if i >= 100 {
+ t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
+ }
+ doPosts(&flushCC)
+ runtime.GC()
+ }
+}
+
+// This used to crash; https://golang.org/issue/3266
+func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
+func testTransportIdleConnCrash(t *testing.T, mode testMode) {
+ var tr *Transport
+
+ unblockCh := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockCh
+ tr.CloseIdleConnections()
+ })).ts
+ c := ts.Client()
+ tr = c.Transport.(*Transport)
+
+ didreq := make(chan bool)
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ } else {
+ res.Body.Close() // returns idle conn
+ }
+ didreq <- true
+ }()
+ unblockCh <- true
+ <-didreq
+}
+
+// Test that the transport doesn't close the TCP connection early,
+// before the response body has been read. This was a regression
+// which sadly lacked a triggering test. The large response body made
+// the old race easier to trigger.
+func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
+func testIssue3644(t *testing.T, mode testMode) {
+ const numFoos = 5000
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Connection", "close")
+ for i := 0; i < numFoos; i++ {
+ w.Write([]byte("foo "))
+ }
+ })).ts
+ c := ts.Client()
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ bs, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(bs) != numFoos*len("foo ") {
+ t.Errorf("unexpected response length")
+ }
+}
+
+// Test that a client receives a server's reply, even if the server doesn't read
+// the entire request body.
+func TestIssue3595(t *testing.T) { run(t, testIssue3595) }
+func testIssue3595(t *testing.T, mode testMode) {
+ const deniedMsg = "sorry, denied."
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ Error(w, deniedMsg, StatusUnauthorized)
+ })).ts
+ c := ts.Client()
+ res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+ if err != nil {
+ t.Errorf("Post: %v", err)
+ return
+ }
+ got, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Body ReadAll: %v", err)
+ }
+ if !strings.Contains(string(got), deniedMsg) {
+ t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+ }
+}
+
+// From https://golang.org/issue/4454 ,
+// "client fails to handle requests with no body and chunked encoding"
+func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
+func testChunkedNoContent(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNoContent)
+ })).ts
+
+ c := ts.Client()
+ for _, closeBody := range []bool{true, false} {
+ const n = 4
+ for i := 1; i <= n; i++ {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
+ } else {
+ if closeBody {
+ res.Body.Close()
+ }
+ }
+ }
+ }
+}
+
+func TestTransportConcurrency(t *testing.T) {
+ run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
+}
+func testTransportConcurrency(t *testing.T, mode testMode) {
+ // Not parallel: uses global test hooks.
+ maxProcs, numReqs := 16, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.FormValue("echo"))
+ })).ts
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ // Due to the Transport's "socket late binding" (see
+ // idleConnCh in transport.go), the numReqs HTTP requests
+ // below can finish with a dial still outstanding. To keep
+ // the leak checker happy, keep track of pending dials and
+ // wait for them to finish (and be closed or returned to the
+ // idle pool) before we close idle connections.
+ SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
+ defer SetPendingDialHooks(nil, nil)
+
+ c := ts.Client()
+ reqs := make(chan string)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for req := range reqs {
+ res, err := c.Get(ts.URL + "/?echo=" + req)
+ if err != nil {
+ if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
+ // https://go.dev/issue/52168: this test was observed to fail with
+ // ECONNRESET errors in Dial on various netbsd builders.
+ t.Logf("error on req %s: %v", req, err)
+ t.Logf("(see https://go.dev/issue/52168)")
+ } else {
+ t.Errorf("error on req %s: %v", req, err)
+ }
+ wg.Done()
+ continue
+ }
+ all, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Errorf("read error on req %s: %v", req, err)
+ } else if string(all) != req {
+ t.Errorf("body of req %s = %q; want %q", req, all, req)
+ }
+ res.Body.Close()
+ wg.Done()
+ }
+ }()
+ }
+ for i := 0; i < numReqs; i++ {
+ reqs <- fmt.Sprintf("request-%d", i)
+ }
+ wg.Wait()
+}
+
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
+func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ connc := make(chan net.Conn, 1)
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ select {
+ case connc <- conn:
+ default:
+ }
+ return conn, nil
+ }
+
+ res, err := c.Get(ts.URL + "/get")
+ if err != nil {
+ t.Fatalf("Error issuing GET: %v", err)
+ }
+ defer res.Body.Close()
+
+ conn := <-connc
+ conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
+ _, err = io.Copy(io.Discard, res.Body)
+ if err == nil {
+ t.Errorf("Unexpected successful copy")
+ }
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+ run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
+}
+func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+ io.Copy(w, neverEnding('a'))
+ })
+ mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+ defer r.Body.Close()
+ io.Copy(io.Discard, r.Body)
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+ timeout := 100 * time.Millisecond
+
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ conn, err := net.Dial(n, addr)
+ if err != nil {
+ return nil, err
+ }
+ conn.SetDeadline(time.Now().Add(timeout))
+ if debug {
+ conn = NewLoggingConn("client", conn)
+ }
+ return conn, nil
+ }
+
+ getFailed := false
+ nRuns := 5
+ if testing.Short() {
+ nRuns = 1
+ }
+ for i := 0; i < nRuns; i++ {
+ if debug {
+ println("run", i+1, "of", nRuns)
+ }
+ sres, err := c.Get(ts.URL + "/get")
+ if err != nil {
+ if !getFailed {
+ // Make the timeout longer, once.
+ getFailed = true
+ t.Logf("increasing timeout")
+ i--
+ timeout *= 10
+ continue
+ }
+ t.Errorf("Error issuing GET: %v", err)
+ break
+ }
+ req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+ _, err = c.Do(req)
+ if err == nil {
+ sres.Body.Close()
+ t.Errorf("Unexpected successful PUT")
+ break
+ }
+ sres.Body.Close()
+ }
+ if debug {
+ println("tests complete; waiting for handlers to finish")
+ }
+ ts.Close()
+}
+
+func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
+func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ inHandler := make(chan bool, 1)
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ })
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ inHandler <- true
+ time.Sleep(2 * time.Second)
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ req, _ := NewRequest("GET", ts.URL+tt.path, nil)
+ req = req.WithT(t)
+ res, err := c.Do(req)
+ select {
+ case <-inHandler:
+ case <-time.After(5 * time.Second):
+ t.Errorf("never entered handler for test index %d, %s", i, tt.path)
+ continue
+ }
+ if err != nil {
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("error is not an url.Error; got: %#v", err)
+ continue
+ }
+ nerr, ok := uerr.Err.(net.Error)
+ if !ok {
+ t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
+ continue
+ }
+ if !nerr.Timeout() {
+ t.Errorf("want timeout error; got: %q", nerr)
+ continue
+ }
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
+func TestTransportCancelRequest(t *testing.T) {
+ run(t, testTransportCancelRequest, []testMode{http1Mode})
+}
+func testTransportCancelRequest(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ tr.CancelRequest(req)
+ }()
+ t0 := time.Now()
+ body, err := io.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err != ExportErrRequestCanceled {
+ t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 5; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
+func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ donec := make(chan bool)
+ req, _ := NewRequest("GET", ts.URL, body)
+ go func() {
+ defer close(donec)
+ c.Do(req)
+ }()
+ start := time.Now()
+ timeout := 10 * time.Second
+ for time.Since(start) < timeout {
+ time.Sleep(100 * time.Millisecond)
+ tr.CancelRequest(req)
+ select {
+ case <-donec:
+ return
+ default:
+ }
+ }
+ t.Errorf("Do of canceled request has not returned after %v", timeout)
+}
+
+func TestTransportCancelRequestInDo(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportCancelRequestInDo(t, mode, nil)
+ }, []testMode{http1Mode})
+}
+
+func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
+ }, []testMode{http1Mode})
+}
+
+func TestTransportCancelRequestInDial(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ var logbuf strings.Builder
+ eventLog := log.New(&logbuf, "", 0)
+
+ unblockDial := make(chan bool)
+ defer close(unblockDial)
+
+ inDial := make(chan bool)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ eventLog.Println("dial: blocking")
+ if !<-inDial {
+ return nil, errors.New("main Test goroutine exited")
+ }
+ <-unblockDial
+ return nil, errors.New("nope")
+ },
+ }
+ cl := &Client{Transport: tr}
+ gotres := make(chan bool)
+ req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ go func() {
+ _, err := cl.Do(req)
+ eventLog.Printf("Get = %v", err)
+ gotres <- true
+ }()
+
+ select {
+ case inDial <- true:
+ case <-time.After(5 * time.Second):
+ close(inDial)
+ t.Fatal("timeout; never saw blocking dial")
+ }
+
+ eventLog.Printf("canceling")
+ tr.CancelRequest(req)
+ tr.CancelRequest(req) // used to panic on second call
+
+ select {
+ case <-gotres:
+ case <-time.After(5 * time.Second):
+ panic("hang. events are: " + logbuf.String())
+ }
+
+ got := logbuf.String()
+ want := `dial: blocking
+canceling
+Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+`
+ if got != want {
+ t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
+ }
+}
+
+func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
+func testCancelRequestWithChannel(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping test in -short mode")
+ }
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello")
+ w.(Flusher).Flush() // send headers and some body
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ ch := make(chan struct{})
+ req.Cancel = ch
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ go func() {
+ time.Sleep(1 * time.Second)
+ close(ch)
+ }()
+ t0 := time.Now()
+ body, err := io.ReadAll(res.Body)
+ d := time.Since(t0)
+
+ if err != ExportErrRequestCanceled {
+ t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
+ }
+ if string(body) != "Hello" {
+ t.Errorf("Body = %q; want Hello", body)
+ }
+ if d < 500*time.Millisecond {
+ t.Errorf("expected ~1 second delay; got %v", d)
+ }
+ // Verify no outstanding requests after readLoop/writeLoop
+ // goroutines shut down.
+ for tries := 5; tries > 0; tries-- {
+ n := tr.NumPendingRequestsForTesting()
+ if n == 0 {
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ if tries == 1 {
+ t.Errorf("pending requests = %d; want 0", n)
+ }
+ }
+}
+
+func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testCancelRequestWithChannelBeforeDo(t, mode, false)
+ })
+}
+func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testCancelRequestWithChannelBeforeDo(t, mode, true)
+ })
+}
+func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+ unblockc := make(chan bool)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-unblockc
+ })).ts
+ defer close(unblockc)
+
+ c := ts.Client()
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ if withCtx {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ req = req.WithContext(ctx)
+ } else {
+ ch := make(chan struct{})
+ req.Cancel = ch
+ close(ch)
+ }
+
+ _, err := c.Do(req)
+ if ue, ok := err.(*url.Error); ok {
+ err = ue.Err
+ }
+ if withCtx {
+ if err != context.Canceled {
+ t.Errorf("Do error = %v; want %v", err, context.Canceled)
+ }
+ } else {
+ if err == nil || !strings.Contains(err.Error(), "canceled") {
+ t.Errorf("Do error = %v; want cancellation", err)
+ }
+ }
+}
+
+// Issue 11020. The returned error message should be errRequestCanceled
+func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+ defer afterTest(t)
+
+ serverConnCh := make(chan net.Conn, 1)
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ cc, sc := net.Pipe()
+ serverConnCh <- sc
+ return cc, nil
+ },
+ }
+ defer tr.CloseIdleConnections()
+ errc := make(chan error, 1)
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ go func() {
+ _, err := tr.RoundTrip(req)
+ errc <- err
+ }()
+
+ sc := <-serverConnCh
+ verb := make([]byte, 3)
+ if _, err := io.ReadFull(sc, verb); err != nil {
+ t.Errorf("Error reading HTTP verb from server: %v", err)
+ }
+ if string(verb) != "GET" {
+ t.Errorf("server received %q; want GET", verb)
+ }
+ defer sc.Close()
+
+ tr.CancelRequest(req)
+
+ err := <-errc
+ if err == nil {
+ t.Fatalf("unexpected success from RoundTrip")
+ }
+ if err != ExportErrRequestCanceled {
+ t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
+ }
+}
+
+// golang.org/issue/3672 -- Client can't close HTTP stream
+// Calling Close on a Response.Body used to just read until EOF.
+// Now it actually closes the TCP connection.
+func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
+func testTransportCloseResponseBody(t *testing.T, mode testMode) {
+ writeErr := make(chan error, 1)
+ msg := []byte("young\n")
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ for {
+ _, err := w.Write(msg)
+ if err != nil {
+ writeErr <- err
+ return
+ }
+ w.(Flusher).Flush()
+ }
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ defer tr.CancelRequest(req)
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ const repeats = 3
+ buf := make([]byte, len(msg)*repeats)
+ want := bytes.Repeat(msg, repeats)
+
+ _, err = io.ReadFull(res.Body, buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("read %q; want %q", buf, want)
+ }
+ didClose := make(chan error, 1)
+ go func() {
+ didClose <- res.Body.Close()
+ }()
+ select {
+ case err := <-didClose:
+ if err != nil {
+ t.Errorf("Close = %v", err)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("too long waiting for close")
+ }
+ select {
+ case err := <-writeErr:
+ if err == nil {
+ t.Errorf("expected non-nil write error")
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("too long waiting for write error")
+ }
+}
+
+type fooProto struct{}
+
+func (fooProto) RoundTrip(req *Request) (*Response, error) {
+ res := &Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Header: make(Header),
+ Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())),
+ }
+ return res, nil
+}
+
+func TestTransportAltProto(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+ tr.RegisterProtocol("foo", fooProto{})
+ res, err := c.Get("foo://bar.com/path")
+ if err != nil {
+ t.Fatal(err)
+ }
+ bodyb, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body := string(bodyb)
+ if e := "You wanted foo://bar.com/path"; body != e {
+ t.Errorf("got response %q, want %q", body, e)
+ }
+}
+
+func TestTransportNoHost(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ _, err := tr.RoundTrip(&Request{
+ Header: make(Header),
+ URL: &url.URL{
+ Scheme: "http",
+ },
+ })
+ want := "http: no Host in request URL"
+ if got := fmt.Sprint(err); got != want {
+ t.Errorf("error = %v; want %q", err, want)
+ }
+}
+
+// Issue 13311
+func TestTransportEmptyMethod(t *testing.T) {
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Method = "" // docs say "For client requests an empty string means GET"
+ got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.Contains(string(got), "GET ") {
+ t.Fatalf("expected substring 'GET '; got: %s", got)
+ }
+}
+
+func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }
+func testTransportSocketLateBinding(t *testing.T, mode testMode) {
+ mux := NewServeMux()
+ fooGate := make(chan bool, 1)
+ mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo-ipport", r.RemoteAddr)
+ w.(Flusher).Flush()
+ <-fooGate
+ })
+ mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
+ w.Header().Set("bar-ipport", r.RemoteAddr)
+ })
+ ts := newClientServerTest(t, mode, mux).ts
+
+ dialGate := make(chan bool, 1)
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
+ if <-dialGate {
+ return net.Dial(n, addr)
+ }
+ return nil, errors.New("manually closed")
+ }
+
+ dialGate <- true // only allow one dial
+ fooRes, err := c.Get(ts.URL + "/foo")
+ if err != nil {
+ t.Fatal(err)
+ }
+ fooAddr := fooRes.Header.Get("foo-ipport")
+ if fooAddr == "" {
+ t.Fatal("No addr on /foo request")
+ }
+ time.AfterFunc(200*time.Millisecond, func() {
+ // let the foo response finish so we can use its
+ // connection for /bar
+ fooGate <- true
+ io.Copy(io.Discard, fooRes.Body)
+ fooRes.Body.Close()
+ })
+
+ barRes, err := c.Get(ts.URL + "/bar")
+ if err != nil {
+ t.Fatal(err)
+ }
+ barAddr := barRes.Header.Get("bar-ipport")
+ if barAddr != fooAddr {
+ t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
+ }
+ barRes.Body.Close()
+ dialGate <- false
+}
+
+// Issue 2184
+func TestTransportReading100Continue(t *testing.T) {
+ defer afterTest(t)
+
+ const numReqs = 5
+ reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
+ reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
+
+ send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
+ defer w.Close()
+ defer r.Close()
+ br := bufio.NewReader(r)
+ n := 0
+ for {
+ n++
+ req, err := ReadRequest(br)
+ if err == io.EOF {
+ return
+ }
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ slurp, err := io.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("Server request body slurp: %v", err)
+ return
+ }
+ id := req.Header.Get("Request-Id")
+ resCode := req.Header.Get("X-Want-Response-Code")
+ if resCode == "" {
+ resCode = "100 Continue"
+ if string(slurp) != reqBody(n) {
+ t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
+ }
+ }
+ body := fmt.Sprintf("Response number %d", n)
+ v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
+Date: Thu, 28 Feb 2013 17:55:41 GMT
+
+HTTP/1.1 200 OK
+Content-Type: text/html
+Echo-Request-Id: %s
+Content-Length: %d
+
+%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
+ w.Write(v)
+ if id == reqID(numReqs) {
+ return
+ }
+ }
+
+ }
+
+ tr := &Transport{
+ Dial: func(n, addr string) (net.Conn, error) {
+ sr, sw := io.Pipe() // server read/write
+ cr, cw := io.Pipe() // client read/write
+ conn := &rwTestConn{
+ Reader: cr,
+ Writer: sw,
+ closeFunc: func() error {
+ sw.Close()
+ cw.Close()
+ return nil
+ },
+ }
+ go send100Response(cw, sr)
+ return conn, nil
+ },
+ DisableKeepAlives: false,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ testResponse := func(req *Request, name string, wantCode int) {
+ t.Helper()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("%s: Do: %v", name, err)
+ }
+ if res.StatusCode != wantCode {
+ t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
+ }
+ if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
+ t.Errorf("%s: response id %q != request id %q", name, idBack, id)
+ }
+ _, err = io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("%s: Slurp error: %v", name, err)
+ }
+ }
+
+ // Few 100 responses, making sure we're not off-by-one.
+ for i := 1; i <= numReqs; i++ {
+ req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
+ req.Header.Set("Request-Id", reqID(i))
+ testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
+ }
+}
+
+// Issue 17739: the HTTP client must ignore any unknown 1xx
+// informational responses before the actual response.
+func TestTransportIgnore1xxResponses(t *testing.T) {
+ run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
+}
+func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
+ buf.Flush()
+ conn.Close()
+ }))
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ var got strings.Builder
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
+ return nil
+ },
+ }))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ res.Write(&got)
+ want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
+ if got.String() != want {
+ t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
+ }
+}
+
+func TestTransportLimits1xxResponses(t *testing.T) {
+ run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
+}
+func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ for i := 0; i < 10; i++ {
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
+ }
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ got := fmt.Sprint(err)
+ wantSub := "too many 1xx informational responses"
+ if !strings.Contains(got, wantSub) {
+ t.Errorf("Get error = %v; want substring %q", err, wantSub)
+ }
+}
+
+// Issue 26161: the HTTP client must treat 101 responses
+// as the final response.
+func TestTransportTreat101Terminal(t *testing.T) {
+ run(t, testTransportTreat101Terminal, []testMode{http1Mode})
+}
+func testTransportTreat101Terminal(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusSwitchingProtocols {
+ t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
+ }
+}
+
+type proxyFromEnvTest struct {
+ req string // URL to fetch; blank means "http://example.com"
+
+ env string // HTTP_PROXY
+ httpsenv string // HTTPS_PROXY
+ noenv string // NO_PROXY
+ reqmeth string // REQUEST_METHOD
+
+ want string
+ wanterr error
+}
+
+func (t proxyFromEnvTest) String() string {
+ var buf strings.Builder
+ space := func() {
+ if buf.Len() > 0 {
+ buf.WriteByte(' ')
+ }
+ }
+ if t.env != "" {
+ fmt.Fprintf(&buf, "http_proxy=%q", t.env)
+ }
+ if t.httpsenv != "" {
+ space()
+ fmt.Fprintf(&buf, "https_proxy=%q", t.httpsenv)
+ }
+ if t.noenv != "" {
+ space()
+ fmt.Fprintf(&buf, "no_proxy=%q", t.noenv)
+ }
+ if t.reqmeth != "" {
+ space()
+ fmt.Fprintf(&buf, "request_method=%q", t.reqmeth)
+ }
+ req := "http://example.com"
+ if t.req != "" {
+ req = t.req
+ }
+ space()
+ fmt.Fprintf(&buf, "req=%q", req)
+ return strings.TrimSpace(buf.String())
+}
+
+var proxyFromEnvTests = []proxyFromEnvTest{
+ {env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
+ {env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
+ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
+ {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
+ {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
+ {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
+
+ // Don't use secure for http
+ {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
+ // Use secure for https.
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
+ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},
+
+ // Issue 16405: don't use HTTP_PROXY in a CGI environment,
+ // where HTTP_PROXY can be attacker-controlled.
+ {env: "http://10.1.2.3:8080", reqmeth: "POST",
+ want: "<nil>",
+ wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
+
+ {want: "<nil>"},
+
+ {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+ {noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
+}
+
+func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
+ t.Helper()
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
+ url, err := proxyForRequest(req)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
+ return
+ }
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
+ }
+}
+
+func TestProxyFromEnvironment(t *testing.T) {
+ ResetProxyEnv()
+ defer ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("HTTP_PROXY", tt.env)
+ os.Setenv("HTTPS_PROXY", tt.httpsenv)
+ os.Setenv("NO_PROXY", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
+ }
+}
+
+func TestProxyFromEnvironmentLowerCase(t *testing.T) {
+ ResetProxyEnv()
+ defer ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("http_proxy", tt.env)
+ os.Setenv("https_proxy", tt.httpsenv)
+ os.Setenv("no_proxy", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
+ }
+}
+
+func TestIdleConnChannelLeak(t *testing.T) {
+ run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
+}
+func testIdleConnChannelLeak(t *testing.T, mode testMode) {
+ // Not parallel: uses global test hooks.
+ var mu sync.Mutex
+ var n int
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ n++
+ mu.Unlock()
+ })).ts
+
+ const nReqs = 5
+ didRead := make(chan bool, nReqs)
+ SetReadLoopBeforeNextReadHook(func() { didRead <- true })
+ defer SetReadLoopBeforeNextReadHook(nil)
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.Dial = func(netw, addr string) (net.Conn, error) {
+ return net.Dial(netw, ts.Listener.Addr().String())
+ }
+
+ // First, without keep-alives.
+ for _, disableKeep := range []bool{true, false} {
+ tr.DisableKeepAlives = disableKeep
+ for i := 0; i < nReqs; i++ {
+ _, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Note: no res.Body.Close is needed here, since the
+ // response Content-Length is zero. Perhaps the test
+ // should be more explicit and use a HEAD, but tests
+ // elsewhere guarantee that zero byte responses generate
+ // a "Content-Length: 0" instead of chunking.
+ }
+
+ // At this point, each of the 5 Transport.readLoop goroutines
+ // are scheduling noting that there are no response bodies (see
+ // earlier comment), and are then calling putIdleConn, which
+ // decrements this count. Usually that happens quickly, which is
+ // why this test has seemed to work for ages. But it's still
+ // racey: we have wait for them to finish first. See Issue 10427
+ for i := 0; i < nReqs; i++ {
+ <-didRead
+ }
+
+ if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
+ t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
+ }
+ }
+}
+
+// Verify the status quo: that the Client.Post function coerces its
+// body into a ReadCloser if it's a Closer, and that the Transport
+// then closes it.
+func TestTransportClosesRequestBody(t *testing.T) {
+ run(t, testTransportClosesRequestBody, []testMode{http1Mode})
+}
+func testTransportClosesRequestBody(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ })).ts
+
+ c := ts.Client()
+
+ closes := 0
+
+ res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+}
+
+func TestTransportTLSHandshakeTimeout(t *testing.T) {
+ defer afterTest(t)
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+ testdonec := make(chan struct{})
+ defer close(testdonec)
+
+ go func() {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ <-testdonec
+ c.Close()
+ }()
+
+ getdonec := make(chan struct{})
+ go func() {
+ defer close(getdonec)
+ tr := &Transport{
+ Dial: func(_, _ string) (net.Conn, error) {
+ return net.Dial("tcp", ln.Addr().String())
+ },
+ TLSHandshakeTimeout: 250 * time.Millisecond,
+ }
+ cl := &Client{Transport: tr}
+ _, err := cl.Get("https://dummy.tld/")
+ if err == nil {
+ t.Error("expected error")
+ return
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Errorf("expected url.Error; got %#v", err)
+ return
+ }
+ ne, ok := ue.Err.(net.Error)
+ if !ok {
+ t.Errorf("expected net.Error; got %#v", err)
+ return
+ }
+ if !ne.Timeout() {
+ t.Errorf("expected timeout error; got %v", err)
+ }
+ if !strings.Contains(err.Error(), "handshake timeout") {
+ t.Errorf("expected 'handshake timeout' in error; got %v", err)
+ }
+ }()
+ select {
+ case <-getdonec:
+ case <-time.After(5 * time.Second):
+ t.Error("test timeout; TLS handshake hung?")
+ }
+}
+
+// Trying to repro golang.org/issue/3514
+func TestTLSServerClosesConnection(t *testing.T) {
+ run(t, testTLSServerClosesConnection, []testMode{https1Mode})
+}
+func testTLSServerClosesConnection(t *testing.T, mode testMode) {
+ closedc := make(chan bool, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
+ conn.Close()
+ closedc <- true
+ return
+ }
+ fmt.Fprintf(w, "hello")
+ })).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+
+ var nSuccess = 0
+ var errs []error
+ const trials = 20
+ for i := 0; i < trials; i++ {
+ tr.CloseIdleConnections()
+ res, err := c.Get(ts.URL + "/keep-alive-then-die")
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-closedc
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Got %q, want foo", slurp)
+ }
+
+ // Now try again and see if we successfully
+ // pick a new connection.
+ res, err = c.Get(ts.URL + "/")
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ slurp, err = io.ReadAll(res.Body)
+ if err != nil {
+ errs = append(errs, err)
+ continue
+ }
+ nSuccess++
+ }
+ if nSuccess > 0 {
+ t.Logf("successes = %d of %d", nSuccess, trials)
+ } else {
+ t.Errorf("All runs failed:")
+ }
+ for _, err := range errs {
+ t.Logf(" err: %v", err)
+ }
+}
+
+// byteFromChanReader is an io.Reader that reads a single byte at a
+// time from the channel. When the channel is closed, the reader
+// returns io.EOF.
+type byteFromChanReader chan byte
+
+func (c byteFromChanReader) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ b, ok := <-c
+ if !ok {
+ return 0, io.EOF
+ }
+ p[0] = b
+ return 1, nil
+}
+
+// Verifies that the Transport doesn't reuse a connection in the case
+// where the server replies before the request has been fully
+// written. We still honor that reply (see TestIssue3595), but don't
+// send future requests on the connection because it's then in a
+// questionable state.
+// golang.org/issue/7569
+func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+ run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode})
+}
+func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
+ var sconn struct {
+ sync.Mutex
+ c net.Conn
+ }
+ var getOkay bool
+ closeConn := func() {
+ sconn.Lock()
+ defer sconn.Unlock()
+ if sconn.c != nil {
+ sconn.c.Close()
+ sconn.c = nil
+ if !getOkay {
+ t.Logf("Closed server connection")
+ }
+ }
+ }
+ defer closeConn()
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ io.WriteString(w, "bar")
+ return
+ }
+ conn, _, _ := w.(Hijacker).Hijack()
+ sconn.Lock()
+ sconn.c = conn
+ sconn.Unlock()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
+ go io.Copy(io.Discard, conn)
+ })).ts
+ c := ts.Client()
+
+ const bodySize = 256 << 10
+ finalBit := make(byteFromChanReader, 1)
+ req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
+ req.ContentLength = bodySize
+ res, err := c.Do(req)
+ if err := wantBody(res, err, "foo"); err != nil {
+ t.Errorf("POST response: %v", err)
+ }
+ donec := make(chan bool)
+ go func() {
+ defer close(donec)
+ res, err = c.Get(ts.URL)
+ if err := wantBody(res, err, "bar"); err != nil {
+ t.Errorf("GET response: %v", err)
+ return
+ }
+ getOkay = true // suppress test noise
+ }()
+ time.AfterFunc(5*time.Second, closeConn)
+ select {
+ case <-donec:
+ finalBit <- 'x' // unblock the writeloop of the first Post
+ close(finalBit)
+ case <-time.After(7 * time.Second):
+ t.Fatal("timeout waiting for GET request to finish")
+ }
+}
+
+// Tests that we don't leak Transport persistConn.readLoop goroutines
+// when a server hangs up immediately after saying it would keep-alive.
+func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
+func testTransportIssue10457(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // Send a response with no body, keep-alive
+ // (implicit), and then lie and immediately close the
+ // connection. This forces the Transport's readLoop to
+ // immediately Peek an io.EOF and get to the point
+ // that used to hang.
+ conn, _, _ := w.(Hijacker).Hijack()
+ conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
+ conn.Close()
+ })).ts
+ c := ts.Client()
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ // Just a sanity check that we at least get the response. The real
+ // test here is that the "defer afterTest" above doesn't find any
+ // leaked goroutines.
+ if got, want := res.Header.Get("Foo"), "Bar"; got != want {
+ t.Errorf("Foo header = %q; want %q", got, want)
+ }
+}
+
+type closerFunc func() error
+
+func (f closerFunc) Close() error { return f() }
+
+type writerFuncConn struct {
+ net.Conn
+ write func(p []byte) (n int, err error)
+}
+
+func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) }
+
+// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
+// server is in the process of closing, we may end up successfully writing out
+// our request (or a portion of our request) only to find a connection error
+// when we try to read from (or finish writing to) the socket.
+//
+// NOTE: we resend a request only if:
+// - we reused a keep-alive connection
+// - we haven't yet received any header data
+// - either we wrote no bytes to the server, or the request is idempotent
+//
+// This automatically prevents an infinite resend loop because we'll run out of
+// the cached keep-alive connections eventually.
+func TestRetryRequestsOnError(t *testing.T) {
+ run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
+}
+func testRetryRequestsOnError(t *testing.T, mode testMode) {
+ newRequest := func(method, urlStr string, body io.Reader) *Request {
+ req, err := NewRequest(method, urlStr, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return req
+ }
+
+ testCases := []struct {
+ name string
+ failureN int
+ failureErr error
+ // Note that we can't just re-use the Request object across calls to c.Do
+ // because we need to rewind Body between calls. (GetBody is only used to
+ // rewind Body on failure and redirects, not just because it's done.)
+ req func() *Request
+ reqString string
+ }{
+ {
+ name: "IdempotentNoBodySomeWritten",
+ // Believe that we've written some bytes to the server, so we know we're
+ // not just in the "retry when no bytes sent" case".
+ failureN: 1,
+ // Use the specific error that shouldRetryRequest looks for with idempotent requests.
+ failureErr: ExportErrServerClosedIdle,
+ req: func() *Request {
+ return newRequest("GET", "http://fake.golang", nil)
+ },
+ reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
+ },
+ {
+ name: "IdempotentGetBodySomeWritten",
+ // Believe that we've written some bytes to the server, so we know we're
+ // not just in the "retry when no bytes sent" case".
+ failureN: 1,
+ // Use the specific error that shouldRetryRequest looks for with idempotent requests.
+ failureErr: ExportErrServerClosedIdle,
+ req: func() *Request {
+ return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
+ },
+ reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
+ },
+ {
+ name: "NothingWrittenNoBody",
+ // It's key that we return 0 here -- that's what enables Transport to know
+ // that nothing was written, even though this is a non-idempotent request.
+ failureN: 0,
+ failureErr: errors.New("second write fails"),
+ req: func() *Request {
+ return newRequest("DELETE", "http://fake.golang", nil)
+ },
+ reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
+ },
+ {
+ name: "NothingWrittenGetBody",
+ // It's key that we return 0 here -- that's what enables Transport to know
+ // that nothing was written, even though this is a non-idempotent request.
+ failureN: 0,
+ failureErr: errors.New("second write fails"),
+ // Note that NewRequest will set up GetBody for strings.Reader, which is
+ // required for the retry to occur
+ req: func() *Request {
+ return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
+ },
+ reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ var (
+ mu sync.Mutex
+ logbuf strings.Builder
+ )
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&logbuf, format, args...)
+ logbuf.WriteByte('\n')
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ logf("Handler")
+ w.Header().Set("X-Status", "ok")
+ })).ts
+
+ var writeNumAtomic int32
+ c := ts.Client()
+ c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
+ logf("Dial")
+ c, err := net.Dial(network, ts.Listener.Addr().String())
+ if err != nil {
+ logf("Dial error: %v", err)
+ return nil, err
+ }
+ return &writerFuncConn{
+ Conn: c,
+ write: func(p []byte) (n int, err error) {
+ if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
+ logf("intentional write failure")
+ return tc.failureN, tc.failureErr
+ }
+ logf("Write(%q)", p)
+ return c.Write(p)
+ },
+ }, nil
+ }
+
+ SetRoundTripRetried(func() {
+ logf("Retried.")
+ })
+ defer SetRoundTripRetried(nil)
+
+ for i := 0; i < 3; i++ {
+ t0 := time.Now()
+ req := tc.req()
+ res, err := c.Do(req)
+ if err != nil {
+ if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 {
+ mu.Lock()
+ got := logbuf.String()
+ mu.Unlock()
+ t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
+ }
+ t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse)
+ }
+ res.Body.Close()
+ if res.Request != req {
+ t.Errorf("Response.Request != original request; want identical Request")
+ }
+ }
+
+ mu.Lock()
+ got := logbuf.String()
+ mu.Unlock()
+ want := fmt.Sprintf(`Dial
+Write("%s")
+Handler
+intentional write failure
+Retried.
+Dial
+Write("%s")
+Handler
+Write("%s")
+Handler
+`, tc.reqString, tc.reqString, tc.reqString)
+ if got != want {
+ t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
+ }
+ })
+ }
+}
+
+// Issue 6981
+func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
+func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
+ readBody := make(chan error, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := io.ReadAll(r.Body)
+ readBody <- err
+ })).ts
+ c := ts.Client()
+ fakeErr := errors.New("fake error")
+ didClose := make(chan bool, 1)
+ req, _ := NewRequest("POST", ts.URL, struct {
+ io.Reader
+ io.Closer
+ }{
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
+ closerFunc(func() error {
+ select {
+ case didClose <- true:
+ default:
+ }
+ return nil
+ }),
+ })
+ res, err := c.Do(req)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
+ t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
+ }
+ select {
+ case err := <-readBody:
+ if err == nil {
+ t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
+ }
+ case <-time.After(5 * time.Second):
+ t.Error("timeout waiting for server handler to complete")
+ }
+ select {
+ case <-didClose:
+ default:
+ t.Errorf("didn't see Body.Close")
+ }
+}
+
+func TestTransportDialTLS(t *testing.T) {
+ run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
+}
+func testTransportDialTLS(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq, didDial bool
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ didDial = true
+ mu.Unlock()
+ c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.Handshake()
+ }
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if !didDial {
+ t.Error("didn't use dial hook")
+ }
+}
+
+func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
+func testTransportDialContext(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq bool
+ var receivedContext context.Context
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ receivedContext = ctx
+ mu.Unlock()
+ return net.Dial(netw, addr)
+ }
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ res, err := c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if receivedContext != ctx {
+ t.Error("didn't receive correct context")
+ }
+}
+
+func TestTransportDialTLSContext(t *testing.T) {
+ run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
+}
+func testTransportDialTLSContext(t *testing.T, mode testMode) {
+ var mu sync.Mutex // guards following
+ var gotReq bool
+ var receivedContext context.Context
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ mu.Lock()
+ gotReq = true
+ mu.Unlock()
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
+ mu.Lock()
+ receivedContext = ctx
+ mu.Unlock()
+ c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
+ if err != nil {
+ return nil, err
+ }
+ return c, c.HandshakeContext(ctx)
+ }
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), "some-key", "some-value")
+ res, err := c.Do(req.WithContext(ctx))
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ mu.Lock()
+ if !gotReq {
+ t.Error("didn't get request")
+ }
+ if receivedContext != ctx {
+ t.Error("didn't receive correct context")
+ }
+}
+
+// Test for issue 8755
+// Ensure that if a proxy returns an error, it is exposed by RoundTrip
+func TestRoundTripReturnsProxyError(t *testing.T) {
+ badProxy := func(*Request) (*url.URL, error) {
+ return nil, errors.New("errorMessage")
+ }
+
+ tr := &Transport{Proxy: badProxy}
+
+ req, _ := NewRequest("GET", "http://example.com", nil)
+
+ _, err := tr.RoundTrip(req)
+
+ if err == nil {
+ t.Error("Expected proxy error to be returned by RoundTrip")
+ }
+}
+
+// tests that putting an idle conn after a call to CloseIdleConns does return it
+func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
+ tr := &Transport{}
+ wantIdle := func(when string, n int) bool {
+ got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
+ if got == n {
+ return true
+ }
+ t.Errorf("%s: idle conns = %d; want %d", when, got, n)
+ return false
+ }
+ wantIdle("start", 0)
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("put failed")
+ }
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("second put failed")
+ }
+ wantIdle("after put", 2)
+ tr.CloseIdleConnections()
+ if !tr.IsIdleForTesting() {
+ t.Error("should be idle after CloseIdleConnections")
+ }
+ wantIdle("after close idle", 0)
+ if tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("put didn't fail")
+ }
+ wantIdle("after second put", 0)
+
+ tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
+ if tr.IsIdleForTesting() {
+ t.Error("shouldn't be idle after QueueForIdleConnForTesting")
+ }
+ if !tr.PutIdleTestConn("http", "example.com") {
+ t.Fatal("after re-activation")
+ }
+ wantIdle("after final put", 1)
+}
+
+// Test for issue 34282
+// Ensure that getConn doesn't call the GotConn trace hook on a HTTP/2 idle conn
+func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
+ tr := &Transport{}
+ wantIdle := func(when string, n int) bool {
+ got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
+ if got == n {
+ return true
+ }
+ t.Errorf("%s: idle conns = %d; want %d", when, got, n)
+ return false
+ }
+ wantIdle("start", 0)
+ alt := funcRoundTripper(func() {})
+ if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
+ t.Fatal("put failed")
+ }
+ wantIdle("after put", 1)
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(httptrace.GotConnInfo) {
+ // tr.getConn should leave it for the HTTP/2 alt to call GotConn.
+ t.Error("GotConn called")
+ },
+ })
+ req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
+ _, err := tr.RoundTrip(req)
+ if err != errFakeRoundTrip {
+ t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
+ }
+ wantIdle("after round trip", 1)
+}
+
+func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
+ run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
+}
+func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ trFunc := func(tr *Transport) {
+ tr.MaxConnsPerHost = 1
+ tr.MaxIdleConnsPerHost = 1
+ tr.IdleConnTimeout = 10 * time.Millisecond
+ }
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
+
+ if _, err := cst.c.Get(cst.ts.URL); err != nil {
+ t.Fatalf("got error: %s", err)
+ }
+
+ time.Sleep(100 * time.Millisecond)
+ got := make(chan error)
+ go func() {
+ if _, err := cst.c.Get(cst.ts.URL); err != nil {
+ got <- err
+ }
+ close(got)
+ }()
+
+ timeout := time.NewTimer(5 * time.Second)
+ defer timeout.Stop()
+ select {
+ case err := <-got:
+ if err != nil {
+ t.Fatalf("got error: %s", err)
+ }
+ case <-timeout.C:
+ t.Fatal("request never completed")
+ }
+}
+
+// This tests that a client requesting a content range won't also
+// implicitly ask for gzip support. If they want that, they need to do it
+// on their own.
+// golang.org/issue/8923
+func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
+func testTransportRangeAndGzip(t *testing.T, mode testMode) {
+ reqc := make(chan *Request, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ reqc <- r
+ })).ts
+ c := ts.Client()
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Header.Set("Range", "bytes=7-11")
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ select {
+ case r := <-reqc:
+ if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
+ t.Error("Transport advertised gzip support in the Accept header")
+ }
+ if r.Header.Get("Range") == "" {
+ t.Error("no Range in request")
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatal("timeout")
+ }
+ res.Body.Close()
+}
+
+// Test for issue 10474
+func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
+func testTransportResponseCancelRace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // important that this response has a body.
+ var b [1024]byte
+ w.Write(b[:])
+ })).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // If we do an early close, Transport just throws the connection away and
+ // doesn't reuse it. In order to trigger the bug, it has to reuse the connection
+ // so read the body
+ if _, err := io.Copy(io.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+
+ req2, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tr.CancelRequest(req)
+ res, err = tr.RoundTrip(req2)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+}
+
+// Test for issue 19248: Content-Encoding's value is case insensitive.
+func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
+ run(t, testTransportContentEncodingCaseInsensitive)
+}
+func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
+ for _, ce := range []string{"gzip", "GZIP"} {
+ ce := ce
+ t.Run(ce, func(t *testing.T) {
+ const encodedString = "Hello Gopher"
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Encoding", ce)
+ gz := gzip.NewWriter(w)
+ gz.Write([]byte(encodedString))
+ gz.Close()
+ })).ts
+
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ body, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if string(body) != encodedString {
+ t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
+ }
+ })
+ }
+}
+
+func TestTransportDialCancelRace(t *testing.T) {
+ run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode})
+}
+func testTransportDialCancelRace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+ tr := ts.Client().Transport.(*Transport)
+
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ SetEnterRoundTripHook(func() {
+ tr.CancelRequest(req)
+ })
+ defer SetEnterRoundTripHook(nil)
+ res, err := tr.RoundTrip(req)
+ if err != ExportErrRequestCanceled {
+ t.Errorf("expected canceled request error; got %v", err)
+ if err == nil {
+ res.Body.Close()
+ }
+ }
+}
+
+// logWritesConn is a net.Conn that logs each Write call to writes
+// and then proxies to w.
+// It proxies Read calls to a reader it receives from rch.
+type logWritesConn struct {
+ net.Conn // nil. crash on use.
+
+ w io.Writer
+
+ rch <-chan io.Reader
+ r io.Reader // nil until received by rch
+
+ mu sync.Mutex
+ writes []string
+}
+
+func (c *logWritesConn) Write(p []byte) (n int, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.writes = append(c.writes, string(p))
+ return c.w.Write(p)
+}
+
+func (c *logWritesConn) Read(p []byte) (n int, err error) {
+ if c.r == nil {
+ c.r = <-c.rch
+ }
+ return c.r.Read(p)
+}
+
+func (c *logWritesConn) Close() error { return nil }
+
+// Issue 6574
+func TestTransportFlushesBodyChunks(t *testing.T) {
+ defer afterTest(t)
+ resBody := make(chan io.Reader, 1)
+ connr, connw := io.Pipe() // connection pipe pair
+ lw := &logWritesConn{
+ rch: resBody,
+ w: connw,
+ }
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ return lw, nil
+ },
+ }
+ bodyr, bodyw := io.Pipe() // body pipe pair
+ go func() {
+ defer bodyw.Close()
+ for i := 0; i < 3; i++ {
+ fmt.Fprintf(bodyw, "num%d\n", i)
+ }
+ }()
+ resc := make(chan *Response)
+ go func() {
+ req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
+ req.Header.Set("User-Agent", "x") // known value for test
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("RoundTrip: %v", err)
+ close(resc)
+ return
+ }
+ resc <- res
+
+ }()
+ // Fully consume the request before checking the Write log vs. want.
+ req, err := ReadRequest(bufio.NewReader(connr))
+ if err != nil {
+ t.Fatal(err)
+ }
+ io.Copy(io.Discard, req.Body)
+
+ // Unblock the transport's roundTrip goroutine.
+ resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
+ res, ok := <-resc
+ if !ok {
+ return
+ }
+ defer res.Body.Close()
+
+ want := []string{
+ "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
+ "5\r\nnum0\n\r\n",
+ "5\r\nnum1\n\r\n",
+ "5\r\nnum2\n\r\n",
+ "0\r\n\r\n",
+ }
+ if !reflect.DeepEqual(lw.writes, want) {
+ t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
+ }
+}
+
+// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
+func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
+func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
+ gotReq := make(chan struct{})
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(gotReq)
+ }))
+
+ pr, pw := io.Pipe()
+ req, err := NewRequest("POST", cst.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotRes := make(chan struct{})
+ go func() {
+ defer close(gotRes)
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+ }()
+
+ select {
+ case <-gotReq:
+ pw.Close()
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for handler to get request")
+ }
+ <-gotRes
+}
+
+// Issue 11745.
+func TestTransportPrefersResponseOverWriteError(t *testing.T) {
+ run(t, testTransportPrefersResponseOverWriteError)
+}
+func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ const contentLengthLimit = 1024 * 1024 // 1MB
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.ContentLength >= contentLengthLimit {
+ w.WriteHeader(StatusBadRequest)
+ r.Body.Close()
+ return
+ }
+ w.WriteHeader(StatusOK)
+ })).ts
+ c := ts.Client()
+
+ fail := 0
+ count := 100
+ bigBody := strings.Repeat("a", contentLengthLimit*2)
+ for i := 0; i < count; i++ {
+ req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp, err := c.Do(req)
+ if err != nil {
+ fail++
+ t.Logf("%d = %#v", i, err)
+ if ue, ok := err.(*url.Error); ok {
+ t.Logf("urlErr = %#v", ue.Err)
+ if ne, ok := ue.Err.(*net.OpError); ok {
+ t.Logf("netOpError = %#v", ne.Err)
+ }
+ }
+ } else {
+ resp.Body.Close()
+ if resp.StatusCode != 400 {
+ t.Errorf("Expected status code 400, got %v", resp.Status)
+ }
+ }
+ }
+ if fail > 0 {
+ t.Errorf("Failed %v out of %v\n", fail, count)
+ }
+}
+
+func TestTransportAutomaticHTTP2(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{}, true)
+}
+
+func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ ForceAttemptHTTP2: true,
+ TLSClientConfig: new(tls.Config),
+ }, true)
+}
+
+// golang.org/issue/14391: also check DefaultTransport
+func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
+ testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
+}
+
+func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ TLSClientConfig: new(tls.Config),
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ ExpectContinueTimeout: 1 * time.Second,
+ }, true)
+}
+
+func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
+ var d net.Dialer
+ testTransportAutoHTTP(t, &Transport{
+ Dial: d.Dial,
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
+ var d net.Dialer
+ testTransportAutoHTTP(t, &Transport{
+ DialContext: d.DialContext,
+ }, false)
+}
+
+func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ DialTLS: func(network, addr string) (net.Conn, error) {
+ panic("unused")
+ },
+ }, false)
+}
+
+func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
+ CondSkipHTTP2(t)
+ _, err := tr.RoundTrip(new(Request))
+ if err == nil {
+ t.Error("expected error from RoundTrip")
+ }
+ if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
+ t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
+ }
+}
+
+// Issue 13633: there was a race where we returned bodyless responses
+// to callers before recycling the persistent connection, which meant
+// a client doing two subsequent requests could end up on different
+// connections. It's somewhat harmless but enough tests assume it's
+// not true in order to test other things that it's worth fixing.
+// Plus it's nice to be consistent and not have timing-dependent
+// behavior.
+func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
+ run(t, testTransportReuseConnEmptyResponseBody)
+}
+func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ // Empty response body.
+ }))
+ n := 100
+ if testing.Short() {
+ n = 10
+ }
+ var firstAddr string
+ for i := 0; i < n; i++ {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ addr := res.Header.Get("X-Addr")
+ if i == 0 {
+ firstAddr = addr
+ } else if addr != firstAddr {
+ t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
+ }
+ res.Body.Close()
+ }
+}
+
+// Issue 13839
+func TestNoCrashReturningTransportAltConn(t *testing.T) {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ var wg sync.WaitGroup
+ SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
+ defer SetPendingDialHooks(nil, nil)
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ <-testDone
+ sc.Close()
+ }()
+
+ addr := ln.Addr().String()
+
+ req, _ := NewRequest("GET", "https://fake.tld/", nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ doReturned := make(chan bool, 1)
+ madeRoundTripper := make(chan bool, 1)
+
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ madeRoundTripper <- true
+ return funcRoundTripper(func() {
+ t.Error("foo RoundTripper should not be called")
+ })
+ },
+ },
+ Dial: func(_, _ string) (net.Conn, error) {
+ panic("shouldn't be called")
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ close(cancel)
+ <-doReturned
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+
+ _, err = c.Do(req)
+ if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
+ t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
+ }
+
+ doReturned <- true
+ <-madeRoundTripper
+ wg.Wait()
+}
+
+func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportReuseConnection_Gzip(t, mode, true)
+ })
+}
+
+func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportReuseConnection_Gzip(t, mode, false)
+ })
+}
+
+// Make sure we re-use underlying TCP connection for gzipped responses too.
+func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
+ addr := make(chan string, 2)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ addr <- r.RemoteAddr
+ w.Header().Set("Content-Encoding", "gzip")
+ if chunked {
+ w.(Flusher).Flush()
+ }
+ w.Write(rgz) // arbitrary gzip response
+ })).ts
+ c := ts.Client()
+
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
+ GotConn: func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
+ PutIdleConn: func(err error) { t.Logf("PutIdleConn(%v)", err) },
+ ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
+ ConnectDone: func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
+ }
+ ctx := httptrace.WithClientTrace(context.Background(), trace)
+
+ for i := 0; i < 2; i++ {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(ctx)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, len(rgz))
+ if n, err := io.ReadFull(res.Body, buf); err != nil {
+ t.Errorf("%d. ReadFull = %v, %v", i, n, err)
+ }
+ // Note: no res.Body.Close call. It should work without it,
+ // since the flate.Reader's internal buffering will hit EOF
+ // and that should be sufficient.
+ }
+ a1, a2 := <-addr, <-addr
+ if a1 != a2 {
+ t.Fatalf("didn't reuse connection")
+ }
+}
+
+func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
+func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/long" {
+ w.Header().Set("Long", strings.Repeat("a", 1<<20))
+ }
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
+
+ if res, err := c.Get(ts.URL); err != nil {
+ t.Fatal(err)
+ } else {
+ res.Body.Close()
+ }
+
+ res, err := c.Get(ts.URL + "/long")
+ if err == nil {
+ defer res.Body.Close()
+ var n int64
+ for k, vv := range res.Header {
+ for _, v := range vv {
+ n += int64(len(k)) + int64(len(v))
+ }
+ }
+ t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
+ }
+ if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
+ t.Errorf("got error: %v; want %q", err, want)
+ }
+}
+
+func TestTransportEventTrace(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportEventTrace(t, mode, false)
+ }, testNotParallel)
+}
+
+// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
+func TestTransportEventTrace_NoHooks(t *testing.T) {
+ run(t, func(t *testing.T, mode testMode) {
+ testTransportEventTrace(t, mode, true)
+ }, testNotParallel)
+}
+
+func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
+ const resBody = "some body"
+ gotWroteReqEvent := make(chan struct{}, 500)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ // Do nothing for the second request.
+ return
+ }
+ if _, err := io.ReadAll(r.Body); err != nil {
+ t.Error(err)
+ }
+ if !noHooks {
+ select {
+ case <-gotWroteReqEvent:
+ case <-time.After(5 * time.Second):
+ t.Error("timeout waiting for WroteRequest event")
+ }
+ }
+ io.WriteString(w, resBody)
+ }), func(tr *Transport) {
+ if tr.TLSClientConfig != nil {
+ tr.TLSClientConfig.InsecureSkipVerify = true
+ }
+ })
+ defer cst.close()
+
+ cst.tr.ExpectContinueTimeout = 1 * time.Second
+
+ var mu sync.Mutex // guards buf
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ addrStr := cst.ts.Listener.Addr().String()
+ ip, port, err := net.SplitHostPort(addrStr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Install a fake DNS server.
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
+ if host != "dns-is-faked.golang" {
+ t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
+ return nil, nil
+ }
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ body := "some body"
+ req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
+ req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
+ GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
+ GotFirstResponseByte: func() { logf("first response byte") },
+ PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) },
+ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
+ DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
+ ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
+ ConnectDone: func(network, addr string, err error) {
+ if err != nil {
+ t.Errorf("ConnectDone: %v", err)
+ }
+ logf("ConnectDone: connected to %s %s = %v", network, addr, err)
+ },
+ WroteHeaderField: func(key string, value []string) {
+ logf("WroteHeaderField: %s: %v", key, value)
+ },
+ WroteHeaders: func() {
+ logf("WroteHeaders")
+ },
+ Wait100Continue: func() { logf("Wait100Continue") },
+ Got100Continue: func() { logf("Got100Continue") },
+ WroteRequest: func(e httptrace.WroteRequestInfo) {
+ logf("WroteRequest: %+v", e)
+ gotWroteReqEvent <- struct{}{}
+ },
+ }
+ if mode == http2Mode {
+ trace.TLSHandshakeStart = func() { logf("tls handshake start") }
+ trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
+ logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
+ }
+ }
+ if noHooks {
+ // zero out all func pointers, trying to get some path to crash
+ *trace = httptrace.ClientTrace{}
+ }
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ logf("got roundtrip.response")
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ logf("consumed body")
+ if string(slurp) != resBody || res.StatusCode != 200 {
+ t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
+ }
+ res.Body.Close()
+
+ if noHooks {
+ // Done at this point. Just testing a full HTTP
+ // requests can happen with a trace pointing to a zero
+ // ClientTrace, full of nil func pointers.
+ return
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+ wantOnceOrMore := func(sub string) {
+ if strings.Count(got, sub) == 0 {
+ t.Errorf("expected substring %q at least once in output.", sub)
+ }
+ }
+ wantOnce("Getting conn for dns-is-faked.golang:" + port)
+ wantOnce("DNS start: {Host:dns-is-faked.golang}")
+ wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
+ wantOnce("got conn: {")
+ wantOnceOrMore("Connecting to tcp " + addrStr)
+ wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
+ wantOnce("Reused:false WasIdle:false IdleTime:0s")
+ wantOnce("first response byte")
+ if mode == http2Mode {
+ wantOnce("tls handshake start")
+ wantOnce("tls handshake done")
+ } else {
+ wantOnce("PutIdleConn = <nil>")
+ wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
+ // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
+ // WroteHeaderField hook is not yet implemented in h2.)
+ wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
+ wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
+ wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
+ wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
+ }
+ wantOnce("WroteHeaders")
+ wantOnce("Wait100Continue")
+ wantOnce("Got100Continue")
+ wantOnce("WroteRequest: {Err:<nil>}")
+ if strings.Contains(got, " to udp ") {
+ t.Errorf("should not see UDP (DNS) connections")
+ }
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+
+ // And do a second request:
+ req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+ res, err = cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ res.Body.Close()
+
+ mu.Lock()
+ got = buf.String()
+ mu.Unlock()
+
+ sub := "Getting conn for dns-is-faked.golang:"
+ if gotn, want := strings.Count(got, sub), 2; gotn != want {
+ t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
+ }
+
+}
+
+func TestTransportEventTraceTLSVerify(t *testing.T) {
+ run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
+}
+func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
+ var mu sync.Mutex
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Error("Unexpected request")
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
+ logf("%s", p)
+ return len(p), nil
+ }), "", 0)
+ }).ts
+
+ certpool := x509.NewCertPool()
+ certpool.AddCert(ts.Certificate())
+
+ c := &Client{Transport: &Transport{
+ TLSClientConfig: &tls.Config{
+ ServerName: "dns-is-faked.golang",
+ RootCAs: certpool,
+ },
+ }}
+
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
+ },
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+ _, err := c.Do(req)
+ if err == nil {
+ t.Error("Expected request to fail TLS verification")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+
+ wantOnce("TLSHandshakeStart")
+ wantOnce("TLSHandshakeDone")
+ wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
+
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+}
+
+var (
+ isDNSHijackedOnce sync.Once
+ isDNSHijacked bool
+)
+
+func skipIfDNSHijacked(t *testing.T) {
+ // Skip this test if the user is using a shady/ISP
+ // DNS server hijacking queries.
+ // See issues 16732, 16716.
+ isDNSHijackedOnce.Do(func() {
+ addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
+ isDNSHijacked = len(addrs) != 0
+ })
+ if isDNSHijacked {
+ t.Skip("skipping; test requires non-hijacking DNS server")
+ }
+}
+
+func TestTransportEventTraceRealDNS(t *testing.T) {
+ skipIfDNSHijacked(t)
+ defer afterTest(t)
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ var mu sync.Mutex // guards buf
+ var buf strings.Builder
+ logf := func(format string, args ...any) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
+ trace := &httptrace.ClientTrace{
+ DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
+ DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
+ ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
+ ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+
+ resp, err := c.Do(req)
+ if err == nil {
+ resp.Body.Close()
+ t.Fatal("expected error during DNS lookup")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantSub := func(sub string) {
+ if !strings.Contains(got, sub) {
+ t.Errorf("expected substring %q in output.", sub)
+ }
+ }
+ wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
+ wantSub("DNSDone: {Addrs:[] Err:")
+ if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
+ t.Errorf("should not see Connect events")
+ }
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
+}
+
+// Issue 14353: port can only contain digits.
+func TestTransportRejectsAlphaPort(t *testing.T) {
+ res, err := Get("http://dummy.tld:123foo/bar")
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("unexpected success")
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %#v; want *url.Error", err)
+ }
+ got := ue.Err.Error()
+ want := `invalid port ":123foo" after host`
+ if got != want {
+ t.Errorf("got error %q; want %q", got, want)
+ }
+}
+
+// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
+// connections. The http2 test is done in TestTransportEventTrace_h2
+func TestTLSHandshakeTrace(t *testing.T) {
+ run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
+}
+func testTLSHandshakeTrace(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
+
+ var mu sync.Mutex
+ var start, done bool
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() {
+ mu.Lock()
+ defer mu.Unlock()
+ start = true
+ },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ mu.Lock()
+ defer mu.Unlock()
+ done = true
+ if err != nil {
+ t.Fatal("Expected error to be nil but was:", err)
+ }
+ },
+ }
+
+ c := ts.Client()
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal("Unable to construct test request:", err)
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ r, err := c.Do(req)
+ if err != nil {
+ t.Fatal("Unexpected error making request:", err)
+ }
+ r.Body.Close()
+ mu.Lock()
+ defer mu.Unlock()
+ if !start {
+ t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
+ }
+ if !done {
+ t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't")
+ }
+}
+
+func TestTransportMaxIdleConns(t *testing.T) {
+ run(t, testTransportMaxIdleConns, []testMode{http1Mode})
+}
+func testTransportMaxIdleConns(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // No body for convenience.
+ })).ts
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxIdleConns = 4
+
+ ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ hitHost := func(n int) {
+ req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
+ req = req.WithContext(ctx)
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ }
+ for i := 0; i < 4; i++ {
+ hitHost(i)
+ }
+ want := []string{
+ "|http|host-0.dns-is-faked.golang:" + port,
+ "|http|host-1.dns-is-faked.golang:" + port,
+ "|http|host-2.dns-is-faked.golang:" + port,
+ "|http|host-3.dns-is-faked.golang:" + port,
+ }
+ if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
+ }
+
+ // Now hitting the 5th host should kick out the first host:
+ hitHost(4)
+ want = []string{
+ "|http|host-1.dns-is-faked.golang:" + port,
+ "|http|host-2.dns-is-faked.golang:" + port,
+ "|http|host-3.dns-is-faked.golang:" + port,
+ "|http|host-4.dns-is-faked.golang:" + port,
+ }
+ if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
+ }
+}
+
+func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
+func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+
+ const timeout = 1 * time.Second
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // No body for convenience.
+ }))
+ tr := cst.tr
+ tr.IdleConnTimeout = timeout
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ idleConns := func() []string {
+ if mode == http2Mode {
+ return tr.IdleConnStrsForTesting_h2()
+ } else {
+ return tr.IdleConnStrsForTesting()
+ }
+ }
+
+ var conn string
+ doReq := func(n int) {
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ PutIdleConn: func(err error) {
+ if err != nil {
+ t.Errorf("failed to keep idle conn: %v", err)
+ }
+ },
+ }))
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ conns := idleConns()
+ if len(conns) != 1 {
+ t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
+ }
+ if conn == "" {
+ conn = conns[0]
+ }
+ if conn != conns[0] {
+ t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n)
+ }
+ }
+ for i := 0; i < 3; i++ {
+ doReq(i)
+ time.Sleep(timeout / 2)
+ }
+ time.Sleep(timeout * 3 / 2)
+ if got := idleConns(); len(got) != 0 {
+ t.Errorf("idle conns = %q; want none", got)
+ }
+}
+
+// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
+// HTTP/2 connection was established but its caller no longer
+// wanted it. (Assuming the connection cache was enabled, which it is
+// by default)
+//
+// This test reproduced the crash by setting the IdleConnTimeout low
+// (to make the test reasonable) and then making a request which is
+// canceled by the DialTLS hook, which then also waits to return the
+// real connection until after the RoundTrip saw the error. Then we
+// know the successful tls.Dial from DialTLS will need to go into the
+// idle pool. Then we give it a of time to explode.
+func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
+func testIdleConnH2Crash(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ // nothing
+ }))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ sawDoErr := make(chan bool, 1)
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ cst.tr.IdleConnTimeout = 5 * time.Millisecond
+ cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
+ c, err := tls.Dial(network, addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"h2"},
+ })
+ if err != nil {
+ t.Error(err)
+ return nil, err
+ }
+ if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
+ t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
+ c.Close()
+ return nil, errors.New("bogus")
+ }
+
+ cancel()
+
+ failTimer := time.NewTimer(5 * time.Second)
+ defer failTimer.Stop()
+ select {
+ case <-sawDoErr:
+ case <-testDone:
+ case <-failTimer.C:
+ t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
+ }
+ return c, nil
+ }
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(ctx)
+ res, err := cst.c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("unexpected success")
+ }
+ sawDoErr <- true
+
+ // Wait for the explosion.
+ time.Sleep(cst.tr.IdleConnTimeout * 10)
+}
+
+type funcConn struct {
+ net.Conn
+ read func([]byte) (int, error)
+ write func([]byte) (int, error)
+}
+
+func (c funcConn) Read(p []byte) (int, error) { return c.read(p) }
+func (c funcConn) Write(p []byte) (int, error) { return c.write(p) }
+func (c funcConn) Close() error { return nil }
+
+// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek
+// back to the caller.
+func TestTransportReturnsPeekError(t *testing.T) {
+ errValue := errors.New("specific error value")
+
+ wrote := make(chan struct{})
+ var wroteOnce sync.Once
+
+ tr := &Transport{
+ Dial: func(network, addr string) (net.Conn, error) {
+ c := funcConn{
+ read: func([]byte) (int, error) {
+ <-wrote
+ return 0, errValue
+ },
+ write: func(p []byte) (int, error) {
+ wroteOnce.Do(func() { close(wrote) })
+ return len(p), nil
+ },
+ }
+ return c, nil
+ },
+ }
+ _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil))
+ if err != errValue {
+ t.Errorf("error = %#v; want %v", err, errValue)
+ }
+}
+
+// Issue 13835: international domain names should work
+func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
+func testTransportIDNA(t *testing.T, mode testMode) {
+ const uniDomain = "гофер.го"
+ const punyDomain = "xn--c1ae0ajs.xn--c1aw"
+
+ var port string
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ want := punyDomain + ":" + port
+ if r.Host != want {
+ t.Errorf("Host header = %q; want %q", r.Host, want)
+ }
+ if mode == http2Mode {
+ if r.TLS == nil {
+ t.Errorf("r.TLS == nil")
+ } else if r.TLS.ServerName != punyDomain {
+ t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
+ }
+ }
+ w.Header().Set("Hit-Handler", "1")
+ }), func(tr *Transport) {
+ if tr.TLSClientConfig != nil {
+ tr.TLSClientConfig.InsecureSkipVerify = true
+ }
+ })
+
+ ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Install a fake DNS server.
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
+ if host != punyDomain {
+ t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
+ return nil, nil
+ }
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ want := net.JoinHostPort(punyDomain, port)
+ if hostPort != want {
+ t.Errorf("getting conn for %q; want %q", hostPort, want)
+ }
+ },
+ DNSStart: func(e httptrace.DNSStartInfo) {
+ if e.Host != punyDomain {
+ t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.Header.Get("Hit-Handler") != "1" {
+ out, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
+ }
+}
+
+// Issue 13290: send User-Agent in proxy CONNECT
+func TestTransportProxyConnectHeader(t *testing.T) {
+ run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
+}
+func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
+ reqc := make(chan *Request, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ case r := <-reqc:
+ if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+ }
+}
+
+func TestTransportProxyGetConnectHeader(t *testing.T) {
+ run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
+}
+func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
+ reqc := make(chan *Request, 1)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ })).ts
+
+ c := ts.Client()
+ c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ }
+ // These should be ignored:
+ c.Transport.(*Transport).ProxyConnectHeader = Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ }
+ c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
+ return Header{
+ "User-Agent": {"foo2"},
+ "Other": {"bar2"},
+ }, nil
+ }
+
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ case r := <-reqc:
+ if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar2"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+ }
+}
+
+var errFakeRoundTrip = errors.New("fake roundtrip")
+
+type funcRoundTripper func()
+
+func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
+ fn()
+ return nil, errFakeRoundTrip
+}
+
+func wantBody(res *Response, err error, want string) error {
+ if err != nil {
+ return err
+ }
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("error reading body: %v", err)
+ }
+ if string(slurp) != want {
+ return fmt.Errorf("body = %q; want %q", slurp, want)
+ }
+ if err := res.Body.Close(); err != nil {
+ return fmt.Errorf("body Close = %v", err)
+ }
+ return nil
+}
+
+func newLocalListener(t *testing.T) net.Listener {
+ ln, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ ln, err = net.Listen("tcp6", "[::1]:0")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ return ln
+}
+
+type countCloseReader struct {
+ n *int
+ io.Reader
+}
+
+func (cr countCloseReader) Close() error {
+ (*cr.n)++
+ return nil
+}
+
+// rgz is a gzip quine that uncompresses to itself.
+var rgz = []byte{
+ 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
+ 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
+ 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
+ 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
+ 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
+ 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
+ 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
+ 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
+ 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
+ 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
+ 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
+ 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
+ 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
+ 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
+ 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
+ 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
+ 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
+ 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
+ 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
+ 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
+ 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
+ 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
+ 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
+ 0x00, 0x00,
+}
+
+// Ensure that a missing status doesn't make the server panic
+// See Issue https://golang.org/issues/21701
+func TestMissingStatusNoPanic(t *testing.T) {
+ t.Parallel()
+
+ const want = "unknown status code"
+
+ ln := newLocalListener(t)
+ addr := ln.Addr().String()
+ done := make(chan bool)
+ fullAddrURL := fmt.Sprintf("http://%s", addr)
+ raw := "HTTP/1.1 400\r\n" +
+ "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
+ "Content-Type: text/html; charset=utf-8\r\n" +
+ "Content-Length: 10\r\n" +
+ "Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
+ "Vary: Accept-Encoding\r\n\r\n" +
+ "Aloha Olaa"
+
+ go func() {
+ defer close(done)
+
+ conn, _ := ln.Accept()
+ if conn != nil {
+ io.WriteString(conn, raw)
+ io.ReadAll(conn)
+ conn.Close()
+ }
+ }()
+
+ proxyURL, err := url.Parse(fullAddrURL)
+ if err != nil {
+ t.Fatalf("proxyURL: %v", err)
+ }
+
+ tr := &Transport{Proxy: ProxyURL(proxyURL)}
+
+ req, _ := NewRequest("GET", "https://golang.org/", nil)
+ res, err, panicked := doFetchCheckPanic(tr, req)
+ if panicked {
+ t.Error("panicked, expecting an error")
+ }
+ if res != nil && res.Body != nil {
+ io.Copy(io.Discard, res.Body)
+ res.Body.Close()
+ }
+
+ if err == nil || !strings.Contains(err.Error(), want) {
+ t.Errorf("got=%v want=%q", err, want)
+ }
+
+ ln.Close()
+ <-done
+}
+
+func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
+ defer func() {
+ if r := recover(); r != nil {
+ panicked = true
+ }
+ }()
+ res, err = tr.RoundTrip(req)
+ return
+}
+
+// Issue 22330: do not allow the response body to be read when the status code
+// forbids a response body.
+func TestNoBodyOnChunked304Response(t *testing.T) {
+ run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
+}
+func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+
+ // Our test server above is sending back bogus data after the
+ // response (the "0\r\n\r\n" part), which causes the Transport
+ // code to log spam. Disable keep-alives so we never even try
+ // to reuse the connection.
+ cst.tr.DisableKeepAlives = true
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if res.Body != NoBody {
+ t.Errorf("Unexpected body on 304 response")
+ }
+}
+
+type funcWriter func([]byte) (int, error)
+
+func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
+
+type doneContext struct {
+ context.Context
+ err error
+}
+
+func (doneContext) Done() <-chan struct{} {
+ c := make(chan struct{})
+ close(c)
+ return c
+}
+
+func (d doneContext) Err() error { return d.err }
+
+// Issue 25852: Transport should check whether Context is done early.
+func TestTransportCheckContextDoneEarly(t *testing.T) {
+ tr := &Transport{}
+ req, _ := NewRequest("GET", "http://fake.example/", nil)
+ wantErr := errors.New("some error")
+ req = req.WithContext(doneContext{context.Background(), wantErr})
+ _, err := tr.RoundTrip(req)
+ if err != wantErr {
+ t.Errorf("error = %v; want %v", err, wantErr)
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that times out before the server replies with
+// any response headers.
+func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
+ run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
+}
+func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
+ inHandler := make(chan net.Conn, 1)
+ handlerReadReturned := make(chan bool, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ inHandler <- conn
+ n, err := conn.Read([]byte{0})
+ if n != 0 || err != io.EOF {
+ t.Errorf("unexpected Read result: %v, %v", n, err)
+ }
+ handlerReadReturned <- true
+ }))
+
+ const timeout = 50 * time.Millisecond
+ cst.c.Timeout = timeout
+
+ _, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ t.Fatal("unexpected Get succeess")
+ }
+
+ select {
+ case c := <-inHandler:
+ select {
+ case <-handlerReadReturned:
+ // Success.
+ return
+ case <-time.After(5 * time.Second):
+ t.Error("Handler's conn.Read seems to be stuck in Read")
+ c.Close() // close it to unblock Handler
+ }
+ case <-time.After(timeout * 10):
+ // If we didn't get into the Handler in 50ms, that probably means
+ // the builder was just slow and the Get failed in that time
+ // but never made it to the server. That's fine. We'll usually
+ // test the part above on faster machines.
+ t.Skip("skipping test on slow builder")
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that has the server send response headers
+// first, and time out during the write of the response body.
+func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
+ run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
+}
+func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
+ inHandler := make(chan net.Conn, 1)
+ handlerResult := make(chan error, 1)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "100")
+ w.(Flusher).Flush()
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ conn.Write([]byte("foo"))
+ inHandler <- conn
+ n, err := conn.Read([]byte{0})
+ // The error should be io.EOF or "read tcp
+ // 127.0.0.1:35827->127.0.0.1:40290: read: connection
+ // reset by peer" depending on timing. Really we just
+ // care that it returns at all. But if it returns with
+ // data, that's weird.
+ if n != 0 || err == nil {
+ handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
+ return
+ }
+ handlerResult <- nil
+ }))
+
+ // Set Timeout to something very long but non-zero to exercise
+ // the codepaths that check for it. But rather than wait for it to fire
+ // (which would make the test slow), we send on the req.Cancel channel instead,
+ // which happens to exercise the same code paths.
+ cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ select {
+ case <-inHandler:
+ t.Fatalf("Get error: %v", err)
+ default:
+ // Failed before entering handler. Ignore result.
+ t.Skip("skipping test on slow builder")
+ }
+ }
+
+ close(cancel)
+ got, err := io.ReadAll(res.Body)
+ if err == nil {
+ t.Fatalf("unexpected success; read %q, nil", got)
+ }
+
+ select {
+ case c := <-inHandler:
+ select {
+ case err := <-handlerResult:
+ if err != nil {
+ t.Errorf("handler: %v", err)
+ }
+ return
+ case <-time.After(5 * time.Second):
+ t.Error("Handler's conn.Read seems to be stuck in Read")
+ c.Close() // close it to unblock Handler
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+}
+
+func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
+ run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
+}
+func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
+ done := make(chan struct{})
+ defer close(done)
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+ io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
+ bs := bufio.NewScanner(conn)
+ bs.Scan()
+ fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
+ <-done
+ }))
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Header.Set("Upgrade", "foo")
+ req.Header.Set("Connection", "upgrade")
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
+ }
+ defer rwc.Close()
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("expected readable input")
+ }
+ if got, want := bs.Text(), "Some buffered data"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+ io.WriteString(rwc, "echo\n")
+ if !bs.Scan() {
+ t.Fatalf("expected another line")
+ }
+ if got, want := bs.Text(), "ECHO"; got != want {
+ t.Errorf("read %q; want %q", got, want)
+ }
+}
+
+func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
+func testTransportCONNECTBidi(t *testing.T, mode testMode) {
+ const target = "backend:443"
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("unexpected method %q", r.Method)
+ w.WriteHeader(500)
+ return
+ }
+ if r.RequestURI != target {
+ t.Errorf("unexpected CONNECT target %q", r.RequestURI)
+ w.WriteHeader(500)
+ return
+ }
+ nc, brw, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer nc.Close()
+ nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+ // Switch to a little protocol that capitalize its input lines:
+ for {
+ line, err := brw.ReadString('\n')
+ if err != nil {
+ if err != io.EOF {
+ t.Error(err)
+ }
+ return
+ }
+ io.WriteString(brw, strings.ToUpper(line))
+ brw.Flush()
+ }
+ }))
+ pr, pw := io.Pipe()
+ defer pw.Close()
+ req, err := NewRequest("CONNECT", cst.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.URL.Opaque = target
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", res.StatusCode)
+ }
+ br := bufio.NewReader(res.Body)
+ for _, str := range []string{"foo", "bar", "baz"} {
+ fmt.Fprintf(pw, "%s\n", str)
+ got, err := br.ReadString('\n')
+ if err != nil {
+ t.Fatal(err)
+ }
+ got = strings.TrimSpace(got)
+ want := strings.ToUpper(str)
+ if got != want {
+ t.Fatalf("got %q; want %q", got, want)
+ }
+ }
+}
+
+func TestTransportRequestReplayable(t *testing.T) {
+ someBody := io.NopCloser(strings.NewReader(""))
+ tests := []struct {
+ name string
+ req *Request
+ want bool
+ }{
+ {
+ name: "GET",
+ req: &Request{Method: "GET"},
+ want: true,
+ },
+ {
+ name: "GET_http.NoBody",
+ req: &Request{Method: "GET", Body: NoBody},
+ want: true,
+ },
+ {
+ name: "GET_body",
+ req: &Request{Method: "GET", Body: someBody},
+ want: false,
+ },
+ {
+ name: "POST",
+ req: &Request{Method: "POST"},
+ want: false,
+ },
+ {
+ name: "POST_idempotency-key",
+ req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
+ want: true,
+ },
+ {
+ name: "POST_x-idempotency-key",
+ req: &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
+ want: true,
+ },
+ {
+ name: "POST_body",
+ req: &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
+ want: false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.req.ExportIsReplayable()
+ if got != tt.want {
+ t.Errorf("replyable = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+// testMockTCPConn is a mock TCP connection used to test that
+// ReadFrom is called when sending the request body.
+type testMockTCPConn struct {
+ *net.TCPConn
+
+ ReadFromCalled bool
+}
+
+func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
+ c.ReadFromCalled = true
+ return c.TCPConn.ReadFrom(r)
+}
+
+func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
+func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := os.CreateTemp("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ readerFunc func() (io.Reader, func(), error)
+ contentLength int64
+ expectedReadFrom bool
+ }{
+ {
+ name: "file, length",
+ readerFunc: newFileFunc,
+ contentLength: nBytes,
+ expectedReadFrom: true,
+ },
+ {
+ name: "file, no length",
+ readerFunc: newFileFunc,
+ },
+ {
+ name: "file, negative length",
+ readerFunc: newFileFunc,
+ contentLength: -1,
+ },
+ {
+ name: "buffer",
+ contentLength: nBytes,
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, no length",
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, length -1",
+ contentLength: -1,
+ readerFunc: newBufferFunc,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ r, cleanup, err := tc.readerFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ tConn := &testMockTCPConn{}
+ trFunc := func(tr *Transport) {
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ tcpConn, ok := conn.(*net.TCPConn)
+ if !ok {
+ return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
+ }
+
+ tConn.TCPConn = tcpConn
+ return tConn, nil
+ }
+ }
+
+ cst := newClientServerTest(
+ t,
+ mode,
+ HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(io.Discard, r.Body)
+ r.Body.Close()
+ w.WriteHeader(200)
+ }),
+ trFunc,
+ )
+
+ req, err := NewRequest("PUT", cst.ts.URL, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = tc.contentLength
+ req.Header.Set("Content-Type", "application/octet-stream")
+ resp, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", resp.StatusCode)
+ }
+
+ expectedReadFrom := tc.expectedReadFrom
+ if mode != http1Mode {
+ expectedReadFrom = false
+ }
+ if !tConn.ReadFromCalled && expectedReadFrom {
+ t.Fatalf("did not call ReadFrom")
+ }
+
+ if tConn.ReadFromCalled && !expectedReadFrom {
+ t.Fatalf("ReadFrom was unexpectedly invoked")
+ }
+ })
+ }
+}
+
+func TestTransportClone(t *testing.T) {
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) { panic("") },
+ OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
+ return nil
+ },
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+ Dial: func(network, addr string) (net.Conn, error) { panic("") },
+ DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
+ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+ TLSClientConfig: new(tls.Config),
+ TLSHandshakeTimeout: time.Second,
+ DisableKeepAlives: true,
+ DisableCompression: true,
+ MaxIdleConns: 1,
+ MaxIdleConnsPerHost: 1,
+ MaxConnsPerHost: 1,
+ IdleConnTimeout: time.Second,
+ ResponseHeaderTimeout: time.Second,
+ ExpectContinueTimeout: time.Second,
+ ProxyConnectHeader: Header{},
+ GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
+ MaxResponseHeaderBytes: 1,
+ ForceAttemptHTTP2: true,
+ TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
+ },
+ ReadBufferSize: 1,
+ WriteBufferSize: 1,
+ }
+ tr2 := tr.Clone()
+ rv := reflect.ValueOf(tr2).Elem()
+ rt := rv.Type()
+ for i := 0; i < rt.NumField(); i++ {
+ sf := rt.Field(i)
+ if !token.IsExported(sf.Name) {
+ continue
+ }
+ if rv.Field(i).IsZero() {
+ t.Errorf("cloned field t2.%s is zero", sf.Name)
+ }
+ }
+
+ if _, ok := tr2.TLSNextProto["foo"]; !ok {
+ t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
+ }
+
+ // But test that a nil TLSNextProto is kept nil:
+ tr = new(Transport)
+ tr2 = tr.Clone()
+ if tr2.TLSNextProto != nil {
+ t.Errorf("Transport.TLSNextProto unexpected non-nil")
+ }
+}
+
+func TestIs408(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"HTTP/1.0 408", true},
+ {"HTTP/1.1 408", true},
+ {"HTTP/1.8 408", true},
+ {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
+ {"HTTP/1.1 408 ", true},
+ {"HTTP/1.1 40", false},
+ {"http/1.0 408", false},
+ {"HTTP/1-1 408", false},
+ }
+ for _, tt := range tests {
+ if got := Export_is408Message([]byte(tt.in)); got != tt.want {
+ t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestTransportIgnores408(t *testing.T) {
+ run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
+}
+func testTransportIgnores408(t *testing.T, mode testMode) {
+ // Not parallel. Relies on mutating the log package's global Output.
+ defer log.SetOutput(log.Writer())
+
+ var logout strings.Builder
+ log.SetOutput(&logout)
+
+ const target = "backend:443"
+
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ nc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer nc.Close()
+ nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
+ nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
+ }))
+ req, err := NewRequest("GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "ok" {
+ t.Fatalf("got %q; want ok", slurp)
+ }
+
+ t0 := time.Now()
+ for i := 0; i < 50; i++ {
+ time.Sleep(time.Duration(i) * 5 * time.Millisecond)
+ if cst.tr.IdleConnKeyCountForTesting() == 0 {
+ if got := logout.String(); got != "" {
+ t.Fatalf("expected no log output; got: %s", got)
+ }
+ return
+ }
+ }
+ t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0))
+}
+
+func TestInvalidHeaderResponse(t *testing.T) {
+ run(t, testInvalidHeaderResponse, []testMode{http1Mode})
+}
+func testInvalidHeaderResponse(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
+ "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
+ "Content-Type: text/html; charset=utf-8\r\n" +
+ "Content-Length: 0\r\n" +
+ "Foo : bar\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if v := res.Header.Get("Foo"); v != "" {
+ t.Errorf(`unexpected "Foo" header: %q`, v)
+ }
+ if v := res.Header.Get("Foo "); v != "bar" {
+ t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
+ }
+}
+
+type bodyCloser bool
+
+func (bc *bodyCloser) Close() error {
+ *bc = true
+ return nil
+}
+func (bc *bodyCloser) Read(b []byte) (n int, err error) {
+ return 0, io.EOF
+}
+
+// Issue 35015: ensure that Transport closes the body on any error
+// with an invalid request, as promised by Client.Do docs.
+func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
+ run(t, testTransportClosesBodyOnInvalidRequests)
+}
+func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Errorf("Should not have been invoked")
+ })).ts
+
+ u, _ := url.Parse(cst.URL)
+
+ tests := []struct {
+ name string
+ req *Request
+ wantErr string
+ }{
+ {
+ name: "invalid method",
+ req: &Request{
+ Method: " ",
+ URL: u,
+ },
+ wantErr: `invalid method " "`,
+ },
+ {
+ name: "nil URL",
+ req: &Request{
+ Method: "GET",
+ },
+ wantErr: `nil Request.URL`,
+ },
+ {
+ name: "invalid header key",
+ req: &Request{
+ Method: "GET",
+ Header: Header{"💡": {"emoji"}},
+ URL: u,
+ },
+ wantErr: `invalid header field name "💡"`,
+ },
+ {
+ name: "invalid header value",
+ req: &Request{
+ Method: "POST",
+ Header: Header{"key": {"\x19"}},
+ URL: u,
+ },
+ wantErr: `invalid header field value for "key"`,
+ },
+ {
+ name: "non HTTP(s) scheme",
+ req: &Request{
+ Method: "POST",
+ URL: &url.URL{Scheme: "faux"},
+ },
+ wantErr: `unsupported protocol scheme "faux"`,
+ },
+ {
+ name: "no Host in URL",
+ req: &Request{
+ Method: "POST",
+ URL: &url.URL{Scheme: "http"},
+ },
+ wantErr: `no Host in request URL`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var bc bodyCloser
+ req := tt.req
+ req.Body = &bc
+ _, err := cst.Client().Do(tt.req)
+ if err == nil {
+ t.Fatal("Expected an error")
+ }
+ if !bc {
+ t.Fatal("Expected body to have been closed")
+ }
+ if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
+ t.Fatalf("Error mismatch: %q does not end with %q", g, w)
+ }
+ })
+ }
+}
+
+// breakableConn is a net.Conn wrapper with a Write method
+// that will fail when its brokenState is true.
+type breakableConn struct {
+ net.Conn
+ *brokenState
+}
+
+type brokenState struct {
+ sync.Mutex
+ broken bool
+}
+
+func (w *breakableConn) Write(b []byte) (n int, err error) {
+ w.Lock()
+ defer w.Unlock()
+ if w.broken {
+ return 0, errors.New("some write error")
+ }
+ return w.Conn.Write(b)
+}
+
+// Issue 34978: don't cache a broken HTTP/2 connection
+func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
+ run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
+}
+func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
+
+ var brokenState brokenState
+
+ const numReqs = 5
+ var numDials, gotConns uint32 // atomic
+
+ cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
+ atomic.AddUint32(&numDials, 1)
+ c, err := net.Dial(netw, addr)
+ if err != nil {
+ t.Errorf("unexpected Dial error: %v", err)
+ return nil, err
+ }
+ return &breakableConn{c, &brokenState}, err
+ }
+
+ for i := 1; i <= numReqs; i++ {
+ brokenState.Lock()
+ brokenState.broken = false
+ brokenState.Unlock()
+
+ // doBreak controls whether we break the TCP connection after the TLS
+ // handshake (before the HTTP/2 handshake). We test a few failures
+ // in a row followed by a final success.
+ doBreak := i != numReqs
+
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(info httptrace.GotConnInfo) {
+ t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
+ atomic.AddUint32(&gotConns, 1)
+ },
+ TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
+ brokenState.Lock()
+ defer brokenState.Unlock()
+ if doBreak {
+ brokenState.broken = true
+ }
+ },
+ })
+ req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = cst.c.Do(req)
+ if doBreak != (err != nil) {
+ t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
+ }
+ }
+ if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
+ t.Errorf("GotConn calls = %v; want %v", got, want)
+ }
+ if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
+ t.Errorf("Dials = %v; want %v", got, want)
+ }
+}
+
+// Issue 34941
+// When the client has too many concurrent requests on a single connection,
+// http.http2noCachedConnError is reported on multiple requests. There should
+// only be one decrement regardless of the number of failures.
+func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
+ run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
+}
+func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
+ CondSkipHTTP2(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ ts := newClientServerTest(t, mode, h).ts
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+
+ errCh := make(chan error, 300)
+ doReq := func() {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ errCh <- fmt.Errorf("request failed: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+ _, err = io.ReadAll(resp.Body)
+ if err != nil {
+ errCh <- fmt.Errorf("read body failed: %v", err)
+ }
+ }
+
+ var wg sync.WaitGroup
+ for i := 0; i < 300; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+ close(errCh)
+
+ for err := range errCh {
+ t.Errorf("error occurred: %v", err)
+ }
+}
+
+// Issue 36820
+// Test that we use the older backward compatible cancellation protocol
+// when a RoundTripper is registered via RegisterProtocol.
+func TestAltProtoCancellation(t *testing.T) {
+ defer afterTest(t)
+ tr := &Transport{}
+ c := &Client{
+ Transport: tr,
+ Timeout: time.Millisecond,
+ }
+ tr.RegisterProtocol("timeout", timeoutProto{})
+ _, err := c.Get("timeout://bar.com/path")
+ if err == nil {
+ t.Error("request unexpectedly succeeded")
+ } else if !strings.Contains(err.Error(), timeoutProtoErr.Error()) {
+ t.Errorf("got error %q, does not contain expected string %q", err, timeoutProtoErr)
+ }
+}
+
+var timeoutProtoErr = errors.New("canceled as expected")
+
+type timeoutProto struct{}
+
+func (timeoutProto) RoundTrip(req *Request) (*Response, error) {
+ select {
+ case <-req.Cancel:
+ return nil, timeoutProtoErr
+ case <-time.After(5 * time.Second):
+ return nil, errors.New("request was not canceled")
+ }
+}
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) }
+
+// Issue 32441: body is not reset after ErrSkipAltProtocol
+func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
+func testIssue32441(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
+ t.Error("body length is zero")
+ }
+ })).ts
+ c := ts.Client()
+ c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
+ // Draining body to trigger failure condition on actual request to server.
+ if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
+ t.Error("body length is zero during round trip")
+ }
+ return nil, ErrSkipAltProtocol
+ }))
+ if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
+ t.Error(err)
+ }
+}
+
+// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
+// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
+func TestTransportRejectsSignInContentLength(t *testing.T) {
+ run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
+}
+func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "+3")
+ w.Write([]byte("abc"))
+ })).ts
+
+ c := cst.Client()
+ res, err := c.Get(cst.URL)
+ if err == nil || res != nil {
+ t.Fatal("Expected a non-nil error and a nil http.Response")
+ }
+ if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
+ t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
+ }
+}
+
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var ok bool
+ if r.r, ok = <-r.c; !ok {
+ return 0, errors.New("delegate closed")
+ }
+ }
+ return r.r.Read(p)
+}
+
+func testTransportRace(req *Request) {
+ save := req.Body
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{pw, dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ quitReadCh := make(chan struct{})
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ defer close(quitReadCh)
+
+ req, err := ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ select {
+ case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
+ case quitReadCh <- struct{}{}:
+ // Ensure delegate is closed so Read doesn't block forever.
+ close(dr.c)
+ }
+ }()
+
+ t.RoundTrip(req)
+
+ // Ensure the reader returns before we reset req.Body to prevent
+ // a data race on req.Body.
+ pw.Close()
+ <-quitReadCh
+
+ req.Body = save
+}
+
+// Issue 37669
+// Test that a cancellation doesn't result in a data race due to the writeLoop
+// goroutine being left running, if the caller mutates the processed Request
+// upon completion.
+func TestErrorWriteLoopRace(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+ t.Parallel()
+ for i := 0; i < 1000; i++ {
+ delay := time.Duration(mrand.Intn(5)) * time.Millisecond
+ ctx, cancel := context.WithTimeout(context.Background(), delay)
+ defer cancel()
+
+ r := bytes.NewBuffer(make([]byte, 10000))
+ req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ testTransportRace(req)
+ }
+}
+
+// Issue 41600
+// Test that a new request which uses the connection of an active request
+// cannot cause it to be canceled as well.
+func TestCancelRequestWhenSharingConnection(t *testing.T) {
+ run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
+}
+func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
+ reqc := make(chan chan struct{}, 2)
+ ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
+ ch := make(chan struct{}, 1)
+ reqc <- ch
+ <-ch
+ w.Header().Add("Content-Length", "0")
+ })).ts
+
+ client := ts.Client()
+ transport := client.Transport.(*Transport)
+ transport.MaxIdleConns = 1
+ transport.MaxConnsPerHost = 1
+
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ putidlec := make(chan chan struct{}, 1)
+ reqerrc := make(chan error, 1)
+ go func() {
+ defer wg.Done()
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ PutIdleConn: func(error) {
+ // Signal that the idle conn has been returned to the pool,
+ // and wait for the order to proceed.
+ ch := make(chan struct{})
+ putidlec <- ch
+ close(putidlec) // panic if PutIdleConn runs twice for some reason
+ <-ch
+ },
+ })
+ req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ reqerrc <- err
+ if err == nil {
+ res.Body.Close()
+ }
+ }()
+
+ // Wait for the first request to receive a response and return the
+ // connection to the idle pool.
+ r1c := <-reqc
+ close(r1c)
+ var idlec chan struct{}
+ select {
+ case err := <-reqerrc:
+ if err != nil {
+ t.Fatalf("request 1: got err %v, want nil", err)
+ }
+ idlec = <-putidlec
+ case idlec = <-putidlec:
+ }
+
+ wg.Add(1)
+ cancelctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ defer wg.Done()
+ req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
+ res, err := client.Do(req)
+ if err == nil {
+ res.Body.Close()
+ }
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("request 2: got err %v, want Canceled", err)
+ }
+
+ // Unblock the first request.
+ close(idlec)
+ }()
+
+ // Wait for the second request to arrive at the server, and then cancel
+ // the request context.
+ r2c := <-reqc
+ cancel()
+
+ <-idlec
+
+ close(r2c)
+ wg.Wait()
+}
+
+func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
+func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ go io.Copy(io.Discard, req.Body)
+ panic(ErrAbortHandler)
+ })).ts
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := ts.Client().Transport.RoundTrip(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
+func testRequestSanitization(t *testing.T, mode testMode) {
+ if mode == http2Mode {
+ // Remove this after updating x/net.
+ t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
+ }
+ ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ if h, ok := req.Header["X-Evil"]; ok {
+ t.Errorf("request has X-Evil header: %q", h)
+ }
+ })).ts
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.Host = "go.dev\r\nX-Evil:evil"
+ resp, _ := ts.Client().Do(req)
+ if resp != nil {
+ resp.Body.Close()
+ }
+}