summaryrefslogtreecommitdiffstats
path: root/src/net/http/httputil
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/httputil')
-rw-r--r--src/net/http/httputil/dump.go340
-rw-r--r--src/net/http/httputil/dump_test.go519
-rw-r--r--src/net/http/httputil/example_test.go123
-rw-r--r--src/net/http/httputil/httputil.go41
-rw-r--r--src/net/http/httputil/persist.go431
-rw-r--r--src/net/http/httputil/reverseproxy.go677
-rw-r--r--src/net/http/httputil/reverseproxy_test.go1613
7 files changed, 3744 insertions, 0 deletions
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go
new file mode 100644
index 0000000..d7baecd
--- /dev/null
+++ b/src/net/http/httputil/dump.go
@@ -0,0 +1,340 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+)
+
+// drainBody reads all of b to memory and then returns two equivalent
+// ReadClosers yielding the same bytes.
+//
+// It returns an error if the initial slurp of all bytes fails. It does not attempt
+// to make the returned ReadClosers have identical error-matching behavior.
+func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
+ if b == nil || b == http.NoBody {
+ // No copying needed. Preserve the magic sentinel meaning of NoBody.
+ return http.NoBody, http.NoBody, nil
+ }
+ var buf bytes.Buffer
+ if _, err = buf.ReadFrom(b); err != nil {
+ return nil, b, err
+ }
+ if err = b.Close(); err != nil {
+ return nil, b, err
+ }
+ return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil
+}
+
+// dumpConn is a net.Conn which writes to Writer and reads from Reader
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
+
+type neverEnding byte
+
+func (b neverEnding) Read(p []byte) (n int, err error) {
+ for i := range p {
+ p[i] = byte(b)
+ }
+ return len(p), nil
+}
+
+// outGoingLength is a copy of the unexported
+// (*http.Request).outgoingLength method.
+func outgoingLength(req *http.Request) int64 {
+ if req.Body == nil || req.Body == http.NoBody {
+ return 0
+ }
+ if req.ContentLength != 0 {
+ return req.ContentLength
+ }
+ return -1
+}
+
+// DumpRequestOut is like DumpRequest but for outgoing client requests. It
+// includes any headers that the standard http.Transport adds, such as
+// User-Agent.
+func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
+ save := req.Body
+ dummyBody := false
+ if !body {
+ contentLength := outgoingLength(req)
+ if contentLength != 0 {
+ req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength))
+ dummyBody = true
+ }
+ } else {
+ var err error
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // Since we're using the actual Transport code to write the request,
+ // switch to http so the Transport doesn't try to do an SSL
+ // negotiation with our dumpConn and its bytes.Buffer & pipe.
+ // The wire format for https and http are the same, anyway.
+ reqSend := req
+ if req.URL.Scheme == "https" {
+ reqSend = new(http.Request)
+ *reqSend = *req
+ reqSend.URL = new(url.URL)
+ *reqSend.URL = *req.URL
+ reqSend.URL.Scheme = "http"
+ }
+
+ // Use the actual Transport code to record what we would send
+ // on the wire, but not using TCP. Use a Transport with a
+ // custom dialer that returns a fake net.Conn that waits
+ // for the full input (and recording it), and then responds
+ // with a dummy response.
+ var buf bytes.Buffer // records the output
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &http.Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ // We need this channel to ensure that the reader
+ // goroutine exits if t.RoundTrip returns an error.
+ // See golang.org/issue/32571.
+ quitReadCh := make(chan struct{})
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ req, err := http.ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(io.Discard, req.Body)
+ req.Body.Close()
+ }
+ select {
+ case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
+ case <-quitReadCh:
+ // Ensure delegateReader.Read doesn't block forever if we get an error.
+ close(dr.c)
+ }
+ }()
+
+ _, err := t.RoundTrip(reqSend)
+
+ req.Body = save
+ if err != nil {
+ pw.Close()
+ dr.err = err
+ close(quitReadCh)
+ return nil, err
+ }
+ dump := buf.Bytes()
+
+ // If we used a dummy body above, remove it now.
+ // TODO: if the req.ContentLength is large, we allocate memory
+ // unnecessarily just to slice it off here. But this is just
+ // a debug function, so this is acceptable for now. We could
+ // discard the body earlier if this matters.
+ if dummyBody {
+ if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 {
+ dump = dump[:i+4]
+ }
+ }
+ return dump, nil
+}
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ err error // only used if r is nil and c is closed.
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ var ok bool
+ if r.r, ok = <-r.c; !ok {
+ return 0, r.err
+ }
+ }
+ return r.r.Read(p)
+}
+
+// Return value if nonempty, def otherwise.
+func valueOrDefault(value, def string) string {
+ if value != "" {
+ return value
+ }
+ return def
+}
+
+var reqWriteExcludeHeaderDump = map[string]bool{
+ "Host": true, // not in Header map anyway
+ "Transfer-Encoding": true,
+ "Trailer": true,
+}
+
+// DumpRequest returns the given request in its HTTP/1.x wire
+// representation. It should only be used by servers to debug client
+// requests. The returned representation is an approximation only;
+// some details of the initial request are lost while parsing it into
+// an http.Request. In particular, the order and case of header field
+// names are lost. The order of values in multi-valued headers is kept
+// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their
+// original binary representations.
+//
+// If body is true, DumpRequest also returns the body. To do so, it
+// consumes req.Body and then replaces it with a new io.ReadCloser
+// that yields the same bytes. If DumpRequest returns an error,
+// the state of req is undefined.
+//
+// The documentation for http.Request.Write details which fields
+// of req are included in the dump.
+func DumpRequest(req *http.Request, body bool) ([]byte, error) {
+ var err error
+ save := req.Body
+ if !body || req.Body == nil {
+ req.Body = nil
+ } else {
+ save, req.Body, err = drainBody(req.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ var b bytes.Buffer
+
+ // By default, print out the unmodified req.RequestURI, which
+ // is always set for incoming server requests. But because we
+ // previously used req.URL.RequestURI and the docs weren't
+ // always so clear about when to use DumpRequest vs
+ // DumpRequestOut, fall back to the old way if the caller
+ // provides a non-server Request.
+ reqURI := req.RequestURI
+ if reqURI == "" {
+ reqURI = req.URL.RequestURI()
+ }
+
+ fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
+ reqURI, req.ProtoMajor, req.ProtoMinor)
+
+ absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://")
+ if !absRequestURI {
+ host := req.Host
+ if host == "" && req.URL != nil {
+ host = req.URL.Host
+ }
+ if host != "" {
+ fmt.Fprintf(&b, "Host: %s\r\n", host)
+ }
+ }
+
+ chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
+ if len(req.TransferEncoding) > 0 {
+ fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
+ }
+ if req.Close {
+ fmt.Fprintf(&b, "Connection: close\r\n")
+ }
+
+ err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
+ if err != nil {
+ return nil, err
+ }
+
+ io.WriteString(&b, "\r\n")
+
+ if req.Body != nil {
+ var dest io.Writer = &b
+ if chunked {
+ dest = NewChunkedWriter(dest)
+ }
+ _, err = io.Copy(dest, req.Body)
+ if chunked {
+ dest.(io.Closer).Close()
+ io.WriteString(&b, "\r\n")
+ }
+ }
+
+ req.Body = save
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
+
+// errNoBody is a sentinel error value used by failureToReadBody so we
+// can detect that the lack of body was intentional.
+var errNoBody = errors.New("sentinel error value")
+
+// failureToReadBody is an io.ReadCloser that just returns errNoBody on
+// Read. It's swapped in when we don't actually want to consume
+// the body, but need a non-nil one, and want to distinguish the
+// error from reading the dummy body.
+type failureToReadBody struct{}
+
+func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
+func (failureToReadBody) Close() error { return nil }
+
+// emptyBody is an instance of empty reader.
+var emptyBody = io.NopCloser(strings.NewReader(""))
+
+// DumpResponse is like DumpRequest but dumps a response.
+func DumpResponse(resp *http.Response, body bool) ([]byte, error) {
+ var b bytes.Buffer
+ var err error
+ save := resp.Body
+ savecl := resp.ContentLength
+
+ if !body {
+ // For content length of zero. Make sure the body is an empty
+ // reader, instead of returning error through failureToReadBody{}.
+ if resp.ContentLength == 0 {
+ resp.Body = emptyBody
+ } else {
+ resp.Body = failureToReadBody{}
+ }
+ } else if resp.Body == nil {
+ resp.Body = emptyBody
+ } else {
+ save, resp.Body, err = drainBody(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ }
+ err = resp.Write(&b)
+ if err == errNoBody {
+ err = nil
+ }
+ resp.Body = save
+ resp.ContentLength = savecl
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go
new file mode 100644
index 0000000..5df2ee8
--- /dev/null
+++ b/src/net/http/httputil/dump_test.go
@@ -0,0 +1,519 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "math/rand"
+ "net/http"
+ "net/url"
+ "runtime"
+ "runtime/pprof"
+ "strings"
+ "testing"
+ "time"
+)
+
+type eofReader struct{}
+
+func (n eofReader) Close() error { return nil }
+
+func (n eofReader) Read([]byte) (int, error) { return 0, io.EOF }
+
+type dumpTest struct {
+ // Either Req or GetReq can be set/nil but not both.
+ Req *http.Request
+ GetReq func() *http.Request
+
+ Body any // optional []byte or func() io.ReadCloser to populate Req.Body
+
+ WantDump string
+ WantDumpOut string
+ MustError bool // if true, the test is expected to throw an error
+ NoBody bool // if true, set DumpRequest{,Out} body to false
+}
+
+var dumpTests = []dumpTest{
+ // HTTP/1.1 => chunked coding; body; empty trailer
+ {
+ Req: &http.Request{
+ Method: "GET",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "/search",
+ },
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ TransferEncoding: []string{"chunked"},
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDump: "GET /search HTTP/1.1\r\n" +
+ "Host: www.google.com\r\n" +
+ "Transfer-Encoding: chunked\r\n\r\n" +
+ chunk("abcdef") + chunk(""),
+ },
+
+ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
+ // and doesn't add a User-Agent.
+ {
+ Req: &http.Request{
+ Method: "GET",
+ URL: mustParseURL("/foo"),
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Header: http.Header{
+ "X-Foo": []string{"X-Bar"},
+ },
+ },
+
+ WantDump: "GET /foo HTTP/1.0\r\n" +
+ "X-Foo: X-Bar\r\n\r\n",
+ },
+
+ {
+ Req: mustNewRequest("GET", "http://example.com/foo", nil),
+
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Test that an https URL doesn't try to do an SSL negotiation
+ // with a bytes.Buffer and hang with all goroutines not
+ // runnable.
+ {
+ Req: mustNewRequest("GET", "https://example.com/foo", nil),
+ WantDumpOut: "GET /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Request with Body, but Dump requested without it.
+ {
+ Req: &http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ ContentLength: 6,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: []byte("abcdef"),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 6\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+
+ NoBody: true,
+ },
+
+ // Request with Body > 8196 (default buffer size)
+ {
+ Req: &http.Request{
+ Method: "POST",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/",
+ },
+ Header: http.Header{
+ "Content-Length": []string{"8193"},
+ },
+
+ ContentLength: 8193,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ },
+
+ Body: bytes.Repeat([]byte("a"), 8193),
+
+ WantDumpOut: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 8193\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n" +
+ strings.Repeat("a", 8193),
+ WantDump: "POST / HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "Content-Length: 8193\r\n\r\n" +
+ strings.Repeat("a", 8193),
+ },
+
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n")
+ },
+ NoBody: true,
+ WantDump: "GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n",
+ },
+
+ // Issue #7215. DumpRequest should return the "Content-Length" when set
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 3\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 3\r\n" +
+ "\r\nkey",
+ },
+ // Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 0\r\n\r\n",
+ },
+
+ // Issue #7215. DumpRequest should not return the "Content-Length" if unset
+ {
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
+ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n\r\n",
+ },
+
+ // Issue 18506: make drainBody recognize NoBody. Otherwise
+ // this was turning into a chunked request.
+ {
+ Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody),
+ WantDumpOut: "POST /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+
+ // Issue 34504: a non-nil Body without ContentLength set should be chunked
+ {
+ Req: &http.Request{
+ Method: "PUT",
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "post.tld",
+ Path: "/test",
+ },
+ ContentLength: 0,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ Body: &eofReader{},
+ },
+ NoBody: true,
+ WantDumpOut: "PUT /test HTTP/1.1\r\n" +
+ "Host: post.tld\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
+}
+
+func TestDumpRequest(t *testing.T) {
+ // Make a copy of dumpTests and add 10 new cases with an empty URL
+ // to test that no goroutines are leaked. See golang.org/issue/32571.
+ // 10 seems to be a decent number which always triggers the failure.
+ dumpTests := dumpTests[:]
+ for i := 0; i < 10; i++ {
+ dumpTests = append(dumpTests, dumpTest{
+ Req: mustNewRequest("GET", "", nil),
+ MustError: true,
+ })
+ }
+ numg0 := runtime.NumGoroutine()
+ for i, tt := range dumpTests {
+ if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil {
+ t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq)
+ continue
+ }
+
+ freshReq := func(ti dumpTest) *http.Request {
+ req := ti.Req
+ if req == nil {
+ req = ti.GetReq()
+ }
+
+ if req.Header == nil {
+ req.Header = make(http.Header)
+ }
+
+ if ti.Body == nil {
+ return req
+ }
+ switch b := ti.Body.(type) {
+ case []byte:
+ req.Body = io.NopCloser(bytes.NewReader(b))
+ case func() io.ReadCloser:
+ req.Body = b()
+ default:
+ t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body)
+ }
+ return req
+ }
+
+ if tt.WantDump != "" {
+ req := freshReq(tt)
+ dump, err := DumpRequest(req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump)
+ continue
+ }
+ if string(dump) != tt.WantDump {
+ t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump))
+ continue
+ }
+ }
+
+ if tt.MustError {
+ req := freshReq(tt)
+ _, err := DumpRequestOut(req, !tt.NoBody)
+ if err == nil {
+ t.Errorf("DumpRequestOut #%d: expected an error, got nil", i)
+ }
+ continue
+ }
+
+ if tt.WantDumpOut != "" {
+ req := freshReq(tt)
+ dump, err := DumpRequestOut(req, !tt.NoBody)
+ if err != nil {
+ t.Errorf("DumpRequestOut #%d: %s", i, err)
+ continue
+ }
+ if string(dump) != tt.WantDumpOut {
+ t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump))
+ continue
+ }
+ }
+ }
+
+ // Validate we haven't leaked any goroutines.
+ var dg int
+ dl := deadline(t, 5*time.Second, time.Second)
+ for time.Now().Before(dl) {
+ if dg = runtime.NumGoroutine() - numg0; dg <= 4 {
+ // No unexpected goroutines.
+ return
+ }
+
+ // Allow goroutines to schedule and die off.
+ runtime.Gosched()
+ }
+
+ buf := make([]byte, 4096)
+ buf = buf[:runtime.Stack(buf, true)]
+ t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf)
+}
+
+// deadline returns the time which is needed before t.Deadline()
+// if one is configured and it is s greater than needed in the future,
+// otherwise defaultDelay from the current time.
+func deadline(t *testing.T, defaultDelay, needed time.Duration) time.Time {
+ if dl, ok := t.Deadline(); ok {
+ if dl = dl.Add(-needed); dl.After(time.Now()) {
+ // Allow an arbitrarily long delay.
+ return dl
+ }
+ }
+
+ // No deadline configured or its closer than needed from now
+ // so just use the default.
+ return time.Now().Add(defaultDelay)
+}
+
+func chunk(s string) string {
+ return fmt.Sprintf("%x\r\n%s\r\n", len(s), s)
+}
+
+func mustParseURL(s string) *url.URL {
+ u, err := url.Parse(s)
+ if err != nil {
+ panic(fmt.Sprintf("Error parsing URL %q: %v", s, err))
+ }
+ return u
+}
+
+func mustNewRequest(method, url string, body io.Reader) *http.Request {
+ req, err := http.NewRequest(method, url, body)
+ if err != nil {
+ panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err))
+ }
+ return req
+}
+
+func mustReadRequest(s string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s)))
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+
+var dumpResTests = []struct {
+ res *http.Response
+ body bool
+ want string
+}{
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 50,
+ Header: http.Header{
+ "Foo": []string{"Bar"},
+ },
+ Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used
+ },
+ body: false, // to verify we see 50, not empty or 3.
+ want: `HTTP/1.1 200 OK
+Content-Length: 50
+Foo: Bar`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 3,
+ Body: io.NopCloser(strings.NewReader("foo")),
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Content-Length: 3
+
+foo`,
+ },
+
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: -1,
+ Body: io.NopCloser(strings.NewReader("foo")),
+ TransferEncoding: []string{"chunked"},
+ },
+ body: true,
+ want: `HTTP/1.1 200 OK
+Transfer-Encoding: chunked
+
+3
+foo
+0`,
+ },
+ {
+ res: &http.Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.1",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0,
+ Header: http.Header{
+ // To verify if headers are not filtered out.
+ "Foo1": []string{"Bar1"},
+ "Foo2": []string{"Bar2"},
+ },
+ Body: nil,
+ },
+ body: false, // to verify we see 0, not empty.
+ want: `HTTP/1.1 200 OK
+Foo1: Bar1
+Foo2: Bar2
+Content-Length: 0`,
+ },
+}
+
+func TestDumpResponse(t *testing.T) {
+ for i, tt := range dumpResTests {
+ gotb, err := DumpResponse(tt.res, tt.body)
+ if err != nil {
+ t.Errorf("%d. DumpResponse = %v", i, err)
+ continue
+ }
+ got := string(gotb)
+ got = strings.TrimSpace(got)
+ got = strings.ReplaceAll(got, "\r", "")
+
+ if got != tt.want {
+ t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want)
+ }
+ }
+}
+
+// Issue 38352: Check for deadlock on canceled requests.
+func TestDumpRequestOutIssue38352(t *testing.T) {
+ if testing.Short() {
+ return
+ }
+ t.Parallel()
+
+ timeout := 10 * time.Second
+ if deadline, ok := t.Deadline(); ok {
+ timeout = time.Until(deadline)
+ timeout -= time.Second * 2 // Leave 2 seconds to report failures.
+ }
+ for i := 0; i < 1000; i++ {
+ delay := time.Duration(rand.Intn(5)) * time.Millisecond
+ ctx, cancel := context.WithTimeout(context.Background(), delay)
+ defer cancel()
+
+ r := bytes.NewBuffer(make([]byte, 10000))
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ out := make(chan error)
+ go func() {
+ _, err = DumpRequestOut(req, true)
+ out <- err
+ }()
+
+ select {
+ case <-out:
+ case <-time.After(timeout):
+ b := &bytes.Buffer{}
+ fmt.Fprintf(b, "deadlock detected on iteration %d after %s with delay: %v\n", i, timeout, delay)
+ pprof.Lookup("goroutine").WriteTo(b, 1)
+ t.Fatal(b.String())
+ }
+ }
+}
diff --git a/src/net/http/httputil/example_test.go b/src/net/http/httputil/example_test.go
new file mode 100644
index 0000000..b77a243
--- /dev/null
+++ b/src/net/http/httputil/example_test.go
@@ -0,0 +1,123 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil_test
+
+import (
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "strings"
+)
+
+func ExampleDumpRequest() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dump, err := httputil.DumpRequest(r, true)
+ if err != nil {
+ http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
+ return
+ }
+
+ fmt.Fprintf(w, "%q", dump)
+ }))
+ defer ts.Close()
+
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+ req.Host = "www.example.org"
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ b, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nContent-Length: 75\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpRequestOut() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ dump, err := httputil.DumpRequestOut(req, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpResponse() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT")
+ fmt.Fprintln(w, body)
+ }))
+ defer ts.Close()
+
+ resp, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ dump, err := httputil.DumpResponse(resp, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n"
+}
+
+func ExampleReverseProxy() {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "this call was relayed by the reverse proxy")
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ frontendProxy := httptest.NewServer(httputil.NewSingleHostReverseProxy(rpURL))
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ b, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // this call was relayed by the reverse proxy
+}
diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go
new file mode 100644
index 0000000..09ea74d
--- /dev/null
+++ b/src/net/http/httputil/httputil.go
@@ -0,0 +1,41 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httputil provides HTTP utility functions, complementing the
+// more common ones in the net/http package.
+package httputil
+
+import (
+ "io"
+ "net/http/internal"
+)
+
+// NewChunkedReader returns a new chunkedReader that translates the data read from r
+// out of HTTP "chunked" format before returning it.
+// The chunkedReader returns io.EOF when the final 0-length chunk is read.
+//
+// NewChunkedReader is not needed by normal applications. The http package
+// automatically decodes chunking when reading response bodies.
+func NewChunkedReader(r io.Reader) io.Reader {
+ return internal.NewChunkedReader(r)
+}
+
+// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
+// "chunked" format before writing them to w. Closing the returned chunkedWriter
+// sends the final 0-length chunk that marks the end of the stream but does
+// not send the final CRLF that appears after trailers; trailers and the last
+// CRLF must be written separately.
+//
+// NewChunkedWriter is not needed by normal applications. The http
+// package adds chunking automatically if handlers don't set a
+// Content-Length header. Using NewChunkedWriter inside a handler
+// would result in double chunking or chunking with a Content-Length
+// length, both of which are wrong.
+func NewChunkedWriter(w io.Writer) io.WriteCloser {
+ return internal.NewChunkedWriter(w)
+}
+
+// ErrLineTooLong is returned when reading malformed chunked data
+// with lines that are too long.
+var ErrLineTooLong = internal.ErrLineTooLong
diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go
new file mode 100644
index 0000000..84b116d
--- /dev/null
+++ b/src/net/http/httputil/persist.go
@@ -0,0 +1,431 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httputil
+
+import (
+ "bufio"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/textproto"
+ "sync"
+)
+
+var (
+ // Deprecated: No longer used.
+ ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"}
+
+ // Deprecated: No longer used.
+ ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
+
+ // Deprecated: No longer used.
+ ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
+)
+
+// This is an API usage error - the local side is closed.
+// ErrPersistEOF (above) reports that the remote side is closed.
+var errClosed = errors.New("i/o operation on closed connection")
+
+// ServerConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Server in package net/http instead.
+type ServerConn struct {
+ mu sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+}
+
+// NewServerConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Server in package net/http instead.
+func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)}
+}
+
+// Hijack detaches the ServerConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before Read has signaled the end of the keep-alive logic. The user
+// should not call Hijack while Read or Write is in progress.
+func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ c := sc.c
+ r := sc.r
+ sc.c = nil
+ sc.r = nil
+ return c, r
+}
+
+// Close calls Hijack and then also closes the underlying connection.
+func (sc *ServerConn) Close() error {
+ c, _ := sc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Read returns the next request on the wire. An ErrPersistEOF is returned if
+// it is gracefully determined that there are no more requests (e.g. after the
+// first request on an HTTP/1.0 connection, or after a Connection:close on a
+// HTTP/1.1 connection).
+func (sc *ServerConn) Read() (*http.Request, error) {
+ var req *http.Request
+ var err error
+
+ // Ensure ordered execution of Reads and Writes
+ id := sc.pipe.Next()
+ sc.pipe.StartRequest(id)
+ defer func() {
+ sc.pipe.EndRequest(id)
+ if req == nil {
+ sc.pipe.StartResponse(id)
+ sc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ sc.mu.Lock()
+ sc.pipereq[req] = id
+ sc.mu.Unlock()
+ }
+ }()
+
+ sc.mu.Lock()
+ if sc.we != nil { // no point receiving if write-side broken or closed
+ defer sc.mu.Unlock()
+ return nil, sc.we
+ }
+ if sc.re != nil {
+ defer sc.mu.Unlock()
+ return nil, sc.re
+ }
+ if sc.r == nil { // connection closed by user in the meantime
+ defer sc.mu.Unlock()
+ return nil, errClosed
+ }
+ r := sc.r
+ lastbody := sc.lastbody
+ sc.lastbody = nil
+ sc.mu.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ sc.re = err
+ return nil, err
+ }
+ }
+
+ req, err = http.ReadRequest(r)
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ if err != nil {
+ if err == io.ErrUnexpectedEOF {
+ // A close from the opposing client is treated as a
+ // graceful close, even if there was some unparse-able
+ // data before the close.
+ sc.re = ErrPersistEOF
+ return nil, sc.re
+ } else {
+ sc.re = err
+ return req, err
+ }
+ }
+ sc.lastbody = req.Body
+ sc.nread++
+ if req.Close {
+ sc.re = ErrPersistEOF
+ return req, sc.re
+ }
+ return req, err
+}
+
+// Pending returns the number of unanswered requests
+// that have been received on the connection.
+func (sc *ServerConn) Pending() int {
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ return sc.nread - sc.nwritten
+}
+
+// Write writes resp in response to req. To close the connection gracefully, set the
+// Response.Close field to true. Write should be considered operational until
+// it returns an error, regardless of any errors returned on the Read side.
+func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error {
+
+ // Retrieve the pipeline ID of this request/response pair
+ sc.mu.Lock()
+ id, ok := sc.pipereq[req]
+ delete(sc.pipereq, req)
+ if !ok {
+ sc.mu.Unlock()
+ return ErrPipeline
+ }
+ sc.mu.Unlock()
+
+ // Ensure pipeline order
+ sc.pipe.StartResponse(id)
+ defer sc.pipe.EndResponse(id)
+
+ sc.mu.Lock()
+ if sc.we != nil {
+ defer sc.mu.Unlock()
+ return sc.we
+ }
+ if sc.c == nil { // connection closed by user in the meantime
+ defer sc.mu.Unlock()
+ return ErrClosed
+ }
+ c := sc.c
+ if sc.nread <= sc.nwritten {
+ defer sc.mu.Unlock()
+ return errors.New("persist server pipe count")
+ }
+ if resp.Close {
+ // After signaling a keep-alive close, any pipelined unread
+ // requests will be lost. It is up to the user to drain them
+ // before signaling.
+ sc.re = ErrPersistEOF
+ }
+ sc.mu.Unlock()
+
+ err := resp.Write(c)
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+ if err != nil {
+ sc.we = err
+ return err
+ }
+ sc.nwritten++
+
+ return nil
+}
+
+// ClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use Client or Transport in package net/http instead.
+type ClientConn struct {
+ mu sync.Mutex // read-write protects the following fields
+ c net.Conn
+ r *bufio.Reader
+ re, we error // read/write errors
+ lastbody io.ReadCloser
+ nread, nwritten int
+ pipereq map[*http.Request]uint
+
+ pipe textproto.Pipeline
+ writeReq func(*http.Request, io.Writer) error
+}
+
+// NewClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Client or Transport in package net/http instead.
+func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ if r == nil {
+ r = bufio.NewReader(c)
+ }
+ return &ClientConn{
+ c: c,
+ r: r,
+ pipereq: make(map[*http.Request]uint),
+ writeReq: (*http.Request).Write,
+ }
+}
+
+// NewProxyClientConn is an artifact of Go's early HTTP implementation.
+// It is low-level, old, and unused by Go's current HTTP stack.
+// We should have deleted it before Go 1.
+//
+// Deprecated: Use the Client or Transport in package net/http instead.
+func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
+ cc := NewClientConn(c, r)
+ cc.writeReq = (*http.Request).WriteProxy
+ return cc
+}
+
+// Hijack detaches the ClientConn and returns the underlying connection as well
+// as the read-side bufio which may have some left over data. Hijack may be
+// called before the user or Read have signaled the end of the keep-alive
+// logic. The user should not call Hijack while Read or Write is in progress.
+func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ c = cc.c
+ r = cc.r
+ cc.c = nil
+ cc.r = nil
+ return
+}
+
+// Close calls Hijack and then also closes the underlying connection.
+func (cc *ClientConn) Close() error {
+ c, _ := cc.Hijack()
+ if c != nil {
+ return c.Close()
+ }
+ return nil
+}
+
+// Write writes a request. An ErrPersistEOF error is returned if the connection
+// has been closed in an HTTP keep-alive sense. If req.Close equals true, the
+// keep-alive connection is logically closed after this request and the opposing
+// server is informed. An ErrUnexpectedEOF indicates the remote closed the
+// underlying TCP connection, which is usually considered as graceful close.
+func (cc *ClientConn) Write(req *http.Request) error {
+ var err error
+
+ // Ensure ordered execution of Writes
+ id := cc.pipe.Next()
+ cc.pipe.StartRequest(id)
+ defer func() {
+ cc.pipe.EndRequest(id)
+ if err != nil {
+ cc.pipe.StartResponse(id)
+ cc.pipe.EndResponse(id)
+ } else {
+ // Remember the pipeline id of this request
+ cc.mu.Lock()
+ cc.pipereq[req] = id
+ cc.mu.Unlock()
+ }
+ }()
+
+ cc.mu.Lock()
+ if cc.re != nil { // no point sending if read-side closed or broken
+ defer cc.mu.Unlock()
+ return cc.re
+ }
+ if cc.we != nil {
+ defer cc.mu.Unlock()
+ return cc.we
+ }
+ if cc.c == nil { // connection closed by user in the meantime
+ defer cc.mu.Unlock()
+ return errClosed
+ }
+ c := cc.c
+ if req.Close {
+ // We write the EOF to the write-side error, because there
+ // still might be some pipelined reads
+ cc.we = ErrPersistEOF
+ }
+ cc.mu.Unlock()
+
+ err = cc.writeReq(req, c)
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if err != nil {
+ cc.we = err
+ return err
+ }
+ cc.nwritten++
+
+ return nil
+}
+
+// Pending returns the number of unanswered requests
+// that have been sent on the connection.
+func (cc *ClientConn) Pending() int {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.nwritten - cc.nread
+}
+
+// Read reads the next response from the wire. A valid response might be
+// returned together with an ErrPersistEOF, which means that the remote
+// requested that this be the last request serviced. Read can be called
+// concurrently with Write, but not with another Read.
+func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) {
+ // Retrieve the pipeline ID of this request/response pair
+ cc.mu.Lock()
+ id, ok := cc.pipereq[req]
+ delete(cc.pipereq, req)
+ if !ok {
+ cc.mu.Unlock()
+ return nil, ErrPipeline
+ }
+ cc.mu.Unlock()
+
+ // Ensure pipeline order
+ cc.pipe.StartResponse(id)
+ defer cc.pipe.EndResponse(id)
+
+ cc.mu.Lock()
+ if cc.re != nil {
+ defer cc.mu.Unlock()
+ return nil, cc.re
+ }
+ if cc.r == nil { // connection closed by user in the meantime
+ defer cc.mu.Unlock()
+ return nil, errClosed
+ }
+ r := cc.r
+ lastbody := cc.lastbody
+ cc.lastbody = nil
+ cc.mu.Unlock()
+
+ // Make sure body is fully consumed, even if user does not call body.Close
+ if lastbody != nil {
+ // body.Close is assumed to be idempotent and multiple calls to
+ // it should return the error that its first invocation
+ // returned.
+ err = lastbody.Close()
+ if err != nil {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.re = err
+ return nil, err
+ }
+ }
+
+ resp, err = http.ReadResponse(r, req)
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if err != nil {
+ cc.re = err
+ return resp, err
+ }
+ cc.lastbody = resp.Body
+
+ cc.nread++
+
+ if resp.Close {
+ cc.re = ErrPersistEOF // don't send any more requests
+ return resp, cc.re
+ }
+ return resp, err
+}
+
+// Do is convenience method that writes a request and reads a response.
+func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) {
+ err := cc.Write(req)
+ if err != nil {
+ return nil, err
+ }
+ return cc.Read(req)
+}
diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go
new file mode 100644
index 0000000..cf39222
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy.go
@@ -0,0 +1,677 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// HTTP reverse proxy handler
+
+package httputil
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "mime"
+ "net"
+ "net/http"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/net/http/httpguts"
+)
+
+// ReverseProxy is an HTTP Handler that takes an incoming request and
+// sends it to another server, proxying the response back to the
+// client.
+//
+// ReverseProxy by default sets the client IP as the value of the
+// X-Forwarded-For header.
+//
+// If an X-Forwarded-For header already exists, the client IP is
+// appended to the existing values. As a special case, if the header
+// exists in the Request.Header map but has a nil value (such as when
+// set by the Director func), the X-Forwarded-For header is
+// not modified.
+//
+// To prevent IP spoofing, be sure to delete any pre-existing
+// X-Forwarded-For header coming from the client or
+// an untrusted proxy.
+type ReverseProxy struct {
+ // Director must be a function which modifies
+ // the request into a new request to be sent
+ // using Transport. Its response is then copied
+ // back to the original client unmodified.
+ // Director must not access the provided Request
+ // after returning.
+ Director func(*http.Request)
+
+ // The transport used to perform proxy requests.
+ // If nil, http.DefaultTransport is used.
+ Transport http.RoundTripper
+
+ // FlushInterval specifies the flush interval
+ // to flush to the client while copying the
+ // response body.
+ // If zero, no periodic flushing is done.
+ // A negative value means to flush immediately
+ // after each write to the client.
+ // The FlushInterval is ignored when ReverseProxy
+ // recognizes a response as a streaming response, or
+ // if its ContentLength is -1; for such responses, writes
+ // are flushed to the client immediately.
+ FlushInterval time.Duration
+
+ // ErrorLog specifies an optional logger for errors
+ // that occur when attempting to proxy the request.
+ // If nil, logging is done via the log package's standard logger.
+ ErrorLog *log.Logger
+
+ // BufferPool optionally specifies a buffer pool to
+ // get byte slices for use by io.CopyBuffer when
+ // copying HTTP response bodies.
+ BufferPool BufferPool
+
+ // ModifyResponse is an optional function that modifies the
+ // Response from the backend. It is called if the backend
+ // returns a response at all, with any HTTP status code.
+ // If the backend is unreachable, the optional ErrorHandler is
+ // called without any call to ModifyResponse.
+ //
+ // If ModifyResponse returns an error, ErrorHandler is called
+ // with its error value. If ErrorHandler is nil, its default
+ // implementation is used.
+ ModifyResponse func(*http.Response) error
+
+ // ErrorHandler is an optional function that handles errors
+ // reaching the backend or errors from ModifyResponse.
+ //
+ // If nil, the default is to log the provided error and return
+ // a 502 Status Bad Gateway response.
+ ErrorHandler func(http.ResponseWriter, *http.Request, error)
+}
+
+// A BufferPool is an interface for getting and returning temporary
+// byte slices for use by io.CopyBuffer.
+type BufferPool interface {
+ Get() []byte
+ Put([]byte)
+}
+
+func singleJoiningSlash(a, b string) string {
+ aslash := strings.HasSuffix(a, "/")
+ bslash := strings.HasPrefix(b, "/")
+ switch {
+ case aslash && bslash:
+ return a + b[1:]
+ case !aslash && !bslash:
+ return a + "/" + b
+ }
+ return a + b
+}
+
+func joinURLPath(a, b *url.URL) (path, rawpath string) {
+ if a.RawPath == "" && b.RawPath == "" {
+ return singleJoiningSlash(a.Path, b.Path), ""
+ }
+ // Same as singleJoiningSlash, but uses EscapedPath to determine
+ // whether a slash should be added
+ apath := a.EscapedPath()
+ bpath := b.EscapedPath()
+
+ aslash := strings.HasSuffix(apath, "/")
+ bslash := strings.HasPrefix(bpath, "/")
+
+ switch {
+ case aslash && bslash:
+ return a.Path + b.Path[1:], apath + bpath[1:]
+ case !aslash && !bslash:
+ return a.Path + "/" + b.Path, apath + "/" + bpath
+ }
+ return a.Path + b.Path, apath + bpath
+}
+
+// NewSingleHostReverseProxy returns a new ReverseProxy that routes
+// URLs to the scheme, host, and base path provided in target. If the
+// target's path is "/base" and the incoming request was for "/dir",
+// the target request will be for /base/dir.
+// NewSingleHostReverseProxy does not rewrite the Host header.
+// To rewrite Host headers, use ReverseProxy directly with a custom
+// Director policy.
+func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
+ targetQuery := target.RawQuery
+ director := func(req *http.Request) {
+ req.URL.Scheme = target.Scheme
+ req.URL.Host = target.Host
+ req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
+ if targetQuery == "" || req.URL.RawQuery == "" {
+ req.URL.RawQuery = targetQuery + req.URL.RawQuery
+ } else {
+ req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
+ }
+ if _, ok := req.Header["User-Agent"]; !ok {
+ // explicitly disable User-Agent so it's not set to default value
+ req.Header.Set("User-Agent", "")
+ }
+ }
+ return &ReverseProxy{Director: director}
+}
+
+func copyHeader(dst, src http.Header) {
+ for k, vv := range src {
+ for _, v := range vv {
+ dst.Add(k, v)
+ }
+ }
+}
+
+// Hop-by-hop headers. These are removed when sent to the backend.
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
+var hopHeaders = []string{
+ "Connection",
+ "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
+ "Keep-Alive",
+ "Proxy-Authenticate",
+ "Proxy-Authorization",
+ "Te", // canonicalized version of "TE"
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
+ "Transfer-Encoding",
+ "Upgrade",
+}
+
+func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+}
+
+func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
+ if p.ErrorHandler != nil {
+ return p.ErrorHandler
+ }
+ return p.defaultErrorHandler
+}
+
+// modifyResponse conditionally runs the optional ModifyResponse hook
+// and reports whether the request should proceed.
+func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
+ if p.ModifyResponse == nil {
+ return true
+ }
+ if err := p.ModifyResponse(res); err != nil {
+ res.Body.Close()
+ p.getErrorHandler()(rw, req, err)
+ return false
+ }
+ return true
+}
+
+func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
+ transport := p.Transport
+ if transport == nil {
+ transport = http.DefaultTransport
+ }
+
+ ctx := req.Context()
+ if ctx.Done() != nil {
+ // CloseNotifier predates context.Context, and has been
+ // entirely superseded by it. If the request contains
+ // a Context that carries a cancellation signal, don't
+ // bother spinning up a goroutine to watch the CloseNotify
+ // channel (if any).
+ //
+ // If the request Context has a nil Done channel (which
+ // means it is either context.Background, or a custom
+ // Context implementation with no cancellation signal),
+ // then consult the CloseNotifier if available.
+ } else if cn, ok := rw.(http.CloseNotifier); ok {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ notifyChan := cn.CloseNotify()
+ go func() {
+ select {
+ case <-notifyChan:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+ }
+
+ outreq := req.Clone(ctx)
+ if req.ContentLength == 0 {
+ outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
+ }
+ if outreq.Body != nil {
+ // Reading from the request body after returning from a handler is not
+ // allowed, and the RoundTrip goroutine that reads the Body can outlive
+ // this handler. This can lead to a crash if the handler panics (see
+ // Issue 46866). Although calling Close doesn't guarantee there isn't
+ // any Read in flight after the handle returns, in practice it's safe to
+ // read after closing it.
+ defer outreq.Body.Close()
+ }
+ if outreq.Header == nil {
+ outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
+ }
+
+ p.Director(outreq)
+ if outreq.Form != nil {
+ outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery)
+ }
+ outreq.Close = false
+
+ reqUpType := upgradeType(outreq.Header)
+ if !ascii.IsPrint(reqUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
+ return
+ }
+ removeConnectionHeaders(outreq.Header)
+
+ // Remove hop-by-hop headers to the backend. Especially
+ // important is "Connection" because we want a persistent
+ // connection, regardless of what the client sent to us.
+ for _, h := range hopHeaders {
+ outreq.Header.Del(h)
+ }
+
+ // Issue 21096: tell backend applications that care about trailer support
+ // that we support trailers. (We do, but we don't go out of our way to
+ // advertise that unless the incoming client request thought it was worth
+ // mentioning.) Note that we look at req.Header, not outreq.Header, since
+ // the latter has passed through removeConnectionHeaders.
+ if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
+ outreq.Header.Set("Te", "trailers")
+ }
+
+ // After stripping all the hop-by-hop connection headers above, add back any
+ // necessary for protocol upgrades, such as for websockets.
+ if reqUpType != "" {
+ outreq.Header.Set("Connection", "Upgrade")
+ outreq.Header.Set("Upgrade", reqUpType)
+ }
+
+ if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
+ // If we aren't the first proxy retain prior
+ // X-Forwarded-For information as a comma+space
+ // separated list and fold multiple headers into one.
+ prior, ok := outreq.Header["X-Forwarded-For"]
+ omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
+ if len(prior) > 0 {
+ clientIP = strings.Join(prior, ", ") + ", " + clientIP
+ }
+ if !omit {
+ outreq.Header.Set("X-Forwarded-For", clientIP)
+ }
+ }
+
+ res, err := transport.RoundTrip(outreq)
+ if err != nil {
+ p.getErrorHandler()(rw, outreq, err)
+ return
+ }
+
+ // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
+ if res.StatusCode == http.StatusSwitchingProtocols {
+ if !p.modifyResponse(rw, res, outreq) {
+ return
+ }
+ p.handleUpgradeResponse(rw, outreq, res)
+ return
+ }
+
+ removeConnectionHeaders(res.Header)
+
+ for _, h := range hopHeaders {
+ res.Header.Del(h)
+ }
+
+ if !p.modifyResponse(rw, res, outreq) {
+ return
+ }
+
+ copyHeader(rw.Header(), res.Header)
+
+ // The "Trailer" header isn't included in the Transport's response,
+ // at least for *http.Transport. Build it up from Trailer.
+ announcedTrailers := len(res.Trailer)
+ if announcedTrailers > 0 {
+ trailerKeys := make([]string, 0, len(res.Trailer))
+ for k := range res.Trailer {
+ trailerKeys = append(trailerKeys, k)
+ }
+ rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
+ }
+
+ rw.WriteHeader(res.StatusCode)
+
+ err = p.copyResponse(rw, res.Body, p.flushInterval(res))
+ if err != nil {
+ defer res.Body.Close()
+ // Since we're streaming the response, if we run into an error all we can do
+ // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
+ // on read error while copying body.
+ if !shouldPanicOnCopyError(req) {
+ p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
+ return
+ }
+ panic(http.ErrAbortHandler)
+ }
+ res.Body.Close() // close now, instead of defer, to populate res.Trailer
+
+ if len(res.Trailer) > 0 {
+ // Force chunking if we saw a response trailer.
+ // This prevents net/http from calculating the length for short
+ // bodies and adding a Content-Length.
+ if fl, ok := rw.(http.Flusher); ok {
+ fl.Flush()
+ }
+ }
+
+ if len(res.Trailer) == announcedTrailers {
+ copyHeader(rw.Header(), res.Trailer)
+ return
+ }
+
+ for k, vv := range res.Trailer {
+ k = http.TrailerPrefix + k
+ for _, v := range vv {
+ rw.Header().Add(k, v)
+ }
+ }
+}
+
+var inOurTests bool // whether we're in our own tests
+
+// shouldPanicOnCopyError reports whether the reverse proxy should
+// panic with http.ErrAbortHandler. This is the right thing to do by
+// default, but Go 1.10 and earlier did not, so existing unit tests
+// weren't expecting panics. Only panic in our own tests, or when
+// running under the HTTP server.
+func shouldPanicOnCopyError(req *http.Request) bool {
+ if inOurTests {
+ // Our tests know to handle this panic.
+ return true
+ }
+ if req.Context().Value(http.ServerContextKey) != nil {
+ // We seem to be running under an HTTP server, so
+ // it'll recover the panic.
+ return true
+ }
+ // Otherwise act like Go 1.10 and earlier to not break
+ // existing tests.
+ return false
+}
+
+// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
+// See RFC 7230, section 6.1
+func removeConnectionHeaders(h http.Header) {
+ for _, f := range h["Connection"] {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = textproto.TrimString(sf); sf != "" {
+ h.Del(sf)
+ }
+ }
+ }
+}
+
+// flushInterval returns the p.FlushInterval value, conditionally
+// overriding its value for a specific request/response.
+func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration {
+ resCT := res.Header.Get("Content-Type")
+
+ // For Server-Sent Events responses, flush immediately.
+ // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
+ if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" {
+ return -1 // negative means immediately
+ }
+
+ // We might have the case of streaming for which Content-Length might be unset.
+ if res.ContentLength == -1 {
+ return -1
+ }
+
+ return p.FlushInterval
+}
+
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
+ if flushInterval != 0 {
+ if wf, ok := dst.(writeFlusher); ok {
+ mlw := &maxLatencyWriter{
+ dst: wf,
+ latency: flushInterval,
+ }
+ defer mlw.stop()
+
+ // set up initial timer so headers get flushed even if body writes are delayed
+ mlw.flushPending = true
+ mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
+
+ dst = mlw
+ }
+ }
+
+ var buf []byte
+ if p.BufferPool != nil {
+ buf = p.BufferPool.Get()
+ defer p.BufferPool.Put(buf)
+ }
+ _, err := p.copyBuffer(dst, src, buf)
+ return err
+}
+
+// copyBuffer returns any write errors or non-EOF read errors, and the amount
+// of bytes written.
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
+ p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ if rerr == io.EOF {
+ rerr = nil
+ }
+ return written, rerr
+ }
+ }
+}
+
+func (p *ReverseProxy) logf(format string, args ...any) {
+ if p.ErrorLog != nil {
+ p.ErrorLog.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+type writeFlusher interface {
+ io.Writer
+ http.Flusher
+}
+
+type maxLatencyWriter struct {
+ dst writeFlusher
+ latency time.Duration // non-zero; negative means to flush immediately
+
+ mu sync.Mutex // protects t, flushPending, and dst.Flush
+ t *time.Timer
+ flushPending bool
+}
+
+func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ n, err = m.dst.Write(p)
+ if m.latency < 0 {
+ m.dst.Flush()
+ return
+ }
+ if m.flushPending {
+ return
+ }
+ if m.t == nil {
+ m.t = time.AfterFunc(m.latency, m.delayedFlush)
+ } else {
+ m.t.Reset(m.latency)
+ }
+ m.flushPending = true
+ return
+}
+
+func (m *maxLatencyWriter) delayedFlush() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
+ return
+ }
+ m.dst.Flush()
+ m.flushPending = false
+}
+
+func (m *maxLatencyWriter) stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.flushPending = false
+ if m.t != nil {
+ m.t.Stop()
+ }
+}
+
+func upgradeType(h http.Header) string {
+ if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
+ return ""
+ }
+ return h.Get("Upgrade")
+}
+
+func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
+ reqUpType := upgradeType(req.Header)
+ resUpType := upgradeType(res.Header)
+ if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller.
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType))
+ }
+ if !ascii.EqualFold(reqUpType, resUpType) {
+ p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
+ return
+ }
+
+ hj, ok := rw.(http.Hijacker)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
+ return
+ }
+ backConn, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
+ return
+ }
+
+ backConnCloseCh := make(chan bool)
+ go func() {
+ // Ensure that the cancellation of a request closes the backend.
+ // See issue https://golang.org/issue/35559.
+ select {
+ case <-req.Context().Done():
+ case <-backConnCloseCh:
+ }
+ backConn.Close()
+ }()
+
+ defer close(backConnCloseCh)
+
+ conn, brw, err := hj.Hijack()
+ if err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
+ return
+ }
+ defer conn.Close()
+
+ copyHeader(rw.Header(), res.Header)
+
+ res.Header = rw.Header()
+ res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
+ if err := res.Write(brw); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
+ return
+ }
+ if err := brw.Flush(); err != nil {
+ p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
+ return
+ }
+ errc := make(chan error, 1)
+ spc := switchProtocolCopier{user: conn, backend: backConn}
+ go spc.copyToBackend(errc)
+ go spc.copyFromBackend(errc)
+ <-errc
+}
+
+// switchProtocolCopier exists so goroutines proxying data back and
+// forth have nice names in stacks.
+type switchProtocolCopier struct {
+ user, backend io.ReadWriter
+}
+
+func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
+ _, err := io.Copy(c.user, c.backend)
+ errc <- err
+}
+
+func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
+ _, err := io.Copy(c.backend, c.user)
+ errc <- err
+}
+
+func cleanQueryParams(s string) string {
+ reencode := func(s string) string {
+ v, _ := url.ParseQuery(s)
+ return v.Encode()
+ }
+ for i := 0; i < len(s); {
+ switch s[i] {
+ case ';':
+ return reencode(s)
+ case '%':
+ if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) {
+ return reencode(s)
+ }
+ i += 3
+ default:
+ i++
+ }
+ }
+ return s
+}
+
+func ishex(c byte) bool {
+ switch {
+ case '0' <= c && c <= '9':
+ return true
+ case 'a' <= c && c <= 'f':
+ return true
+ case 'A' <= c && c <= 'F':
+ return true
+ }
+ return false
+}
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
new file mode 100644
index 0000000..33b1ade
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -0,0 +1,1613 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/internal/ascii"
+ "net/url"
+ "os"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
+
+func init() {
+ inOurTests = true
+ hopHeaders = append(hopHeaders, fakeHopHeader)
+}
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" && r.FormValue("mode") == "hangup" {
+ c, _, _ := w.(http.Hijacker).Hijack()
+ c.Close()
+ return
+ }
+ if len(r.TransferEncoding) > 0 {
+ t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
+ }
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got Connection header value %q", c)
+ }
+ if c := r.Header.Get("Te"); c != "trailers" {
+ t.Errorf("handler got Te header value %q; want 'trailers'", c)
+ }
+ if c := r.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got Upgrade header value %q", c)
+ }
+ if c := r.Header.Get("Proxy-Connection"); c != "" {
+ t.Errorf("handler got Proxy-Connection header value %q", c)
+ }
+ if g, e := r.Host, "some-name"; g != e {
+ t.Errorf("backend got Host header %q, want %q", g, e)
+ }
+ w.Header().Set("Trailers", "not a special header field name")
+ w.Header().Set("Trailer", "X-Trailer")
+ w.Header().Set("X-Foo", "bar")
+ w.Header().Set("Upgrade", "foo")
+ w.Header().Set(fakeHopHeader, "foo")
+ w.Header().Add("X-Multi-Value", "foo")
+ w.Header().Add("X-Multi-Value", "bar")
+ http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ w.Header().Set("X-Trailer", "trailer_value")
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close, TE")
+ getReq.Header.Add("Te", "foo")
+ getReq.Header.Add("Te", "bar, trailers")
+ getReq.Header.Set("Proxy-Connection", "should be deleted")
+ getReq.Header.Set("Upgrade", "foo")
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ if c := res.Header.Get(fakeHopHeader); c != "" {
+ t.Errorf("got %s header value %q", fakeHopHeader, c)
+ }
+ if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
+ t.Errorf("header Trailers = %q; want %q", g, e)
+ }
+ if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
+ t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
+ }
+ if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
+ t.Fatalf("got %d SetCookies, want %d", g, e)
+ }
+ if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
+ t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
+ }
+ if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
+ t.Errorf("unexpected cookie %q", cookie.Name)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
+ t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
+ }
+
+ // Test that a backend failing to be reached or one which doesn't return
+ // a response results in a StatusBadGateway.
+ getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
+ getReq.Close = true
+ res, err = frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != http.StatusBadGateway {
+ t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
+ }
+
+}
+
+// Issue 16875: remove any proxied headers mentioned in the "Connection"
+// header value.
+func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
+ const fakeConnectionToken = "X-Fake-Connection-Token"
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ w.Header().Set(fakeConnectionToken, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
+ }
+ c := r.Header["Connection"]
+ var cf []string
+ for _, f := range c {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = strings.TrimSpace(sf); sf != "" {
+ cf = append(cf, sf)
+ }
+ }
+ }
+ sort.Strings(cf)
+ expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
+ sort.Strings(expectedValues)
+ if !reflect.DeepEqual(cf, expectedValues) {
+ t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ getReq.Header.Set(fakeConnectionToken, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ if c := res.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+}
+
+func TestReverseProxyStripEmptyConnection(t *testing.T) {
+ // See Issue 46313.
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Values("Connection"); len(c) != 0 {
+ t.Errorf("handler got header %q = %v; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "")
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "")
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+}
+
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+// Issue 38079: don't append to X-Forwarded-For if it's present but nil
+func TestXForwardedFor_Omit(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if v := r.Header.Get("X-Forwarded-For"); v != "" {
+ t.Errorf("got X-Forwarded-For header: %q", v)
+ }
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ r.Header["X-Forwarded-For"] = nil
+ oldDirector(r)
+ }
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+var proxyQueryTests = []struct {
+ baseSuffix string // suffix to add to backend URL
+ reqSuffix string // suffix to add to frontend's request URL
+ want string // what backend should see for final request URL (without ?)
+}{
+ {"", "", ""},
+ {"?sta=tic", "?us=er", "sta=tic&us=er"},
+ {"", "?us=er", "us=er"},
+ {"?sta=tic", "", "sta=tic"},
+}
+
+func TestReverseProxyQuery(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Got-Query", r.URL.RawQuery)
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ for i, tt := range proxyQueryTests {
+ backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
+ req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("%d. Get: %v", i, err)
+ }
+ if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
+ t.Errorf("%d. got query %q; expected %q", i, g, e)
+ }
+ res.Body.Close()
+ frontend.Close()
+ }
+}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+}
+
+func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
+ const expected = "hi"
+ stopCh := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("MyHeader", expected)
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ <-stopCh
+ }))
+ defer backend.Close()
+ defer close(stopCh)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+
+ ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
+ defer cancel()
+ req = req.WithContext(ctx)
+
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ if res.Header.Get("MyHeader") != expected {
+ t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
+ }
+}
+
+func TestReverseProxyCancellation(t *testing.T) {
+ const backendResponse = "I am the backend"
+
+ reqInFlight := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ close(reqInFlight) // cause the client to cancel its request
+
+ select {
+ case <-time.After(10 * time.Second):
+ // Note: this should only happen in broken implementations, and the
+ // closenotify case should be instantaneous.
+ t.Error("Handler never saw CloseNotify")
+ return
+ case <-w.(http.CloseNotifier).CloseNotify():
+ }
+
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(backendResponse))
+ }))
+
+ defer backend.Close()
+
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+
+ // Discards errors of the form:
+ // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ go func() {
+ <-reqInFlight
+ frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
+ }()
+ res, err := frontendClient.Do(getReq)
+ if res != nil {
+ t.Errorf("got response %v; want nil", res.Status)
+ }
+ if err == nil {
+ // This should be an error like:
+ // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
+ // use of closed network connection
+ t.Error("Server.Client().Do() returned nil error; want non-nil error")
+ }
+}
+
+func req(t *testing.T, v string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
+ if err != nil {
+ t.Fatal(err)
+ }
+ return req
+}
+
+// Issue 12344
+func TestNilBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ backURL, _ := url.Parse(backend.URL)
+ rp := NewSingleHostReverseProxy(backURL)
+ r := req(t, "GET / HTTP/1.0\r\n\r\n")
+ r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
+ rp.ServeHTTP(w, r)
+ }))
+ defer frontend.Close()
+
+ res, err := http.Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "hi" {
+ t.Errorf("Got %q; want %q", slurp, "hi")
+ }
+}
+
+// Issue 15524
+func TestUserAgentHeader(t *testing.T) {
+ const explicitUA = "explicit UA"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/noua" {
+ if c := r.Header.Get("User-Agent"); c != "" {
+ t.Errorf("handler got non-empty User-Agent header %q", c)
+ }
+ return
+ }
+ if c := r.Header.Get("User-Agent"); c != explicitUA {
+ t.Errorf("handler got unexpected User-Agent header %q", c)
+ }
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("User-Agent", explicitUA)
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+
+ getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil)
+ getReq.Header.Set("User-Agent", "")
+ getReq.Close = true
+ res, err = frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+type bufferPool struct {
+ get func() []byte
+ put func([]byte)
+}
+
+func (bp bufferPool) Get() []byte { return bp.get() }
+func (bp bufferPool) Put(v []byte) { bp.put(v) }
+
+func TestReverseProxyGetPutBuffer(t *testing.T) {
+ const msg = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, msg)
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ mu sync.Mutex
+ log []string
+ )
+ addLog := func(event string) {
+ mu.Lock()
+ defer mu.Unlock()
+ log = append(log, event)
+ }
+ rp := NewSingleHostReverseProxy(backendURL)
+ const size = 1234
+ rp.BufferPool = bufferPool{
+ get: func() []byte {
+ addLog("getBuf")
+ return make([]byte, size)
+ },
+ put: func(p []byte) {
+ addLog("putBuf-" + strconv.Itoa(len(p)))
+ },
+ }
+ frontend := httptest.NewServer(rp)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if string(slurp) != msg {
+ t.Errorf("msg = %q; want %q", slurp, msg)
+ }
+ wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
+ mu.Lock()
+ defer mu.Unlock()
+ if !reflect.DeepEqual(log, wantLog) {
+ t.Errorf("Log events = %q; want %q", log, wantLog)
+ }
+}
+
+func TestReverseProxy_Post(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 200
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Backend body read = %v", err)
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ }
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
+ res, err := frontend.Client().Do(postReq)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+type RoundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+// Issue 16036: send a Request with a nil Body when possible
+func TestReverseProxy_NilBody(t *testing.T) {
+ backendURL, _ := url.Parse("http://fake.tld/")
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Body != nil {
+ t.Error("Body != nil; want a nil Body")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 502 {
+ t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
+ }
+}
+
+// Issue 33142: always allocate the request headers
+func TestReverseProxy_AllocatedHeader(t *testing.T) {
+ proxyHandler := new(ReverseProxy)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Director = func(*http.Request) {} // noop
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Header == nil {
+ t.Error("Header == nil; want a non-nil Header")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+
+ proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
+ Method: "GET",
+ URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ })
+}
+
+// Issue 14237. Test ModifyResponse and that an error from it
+// causes the proxy to return StatusBadGateway, or StatusOK otherwise.
+func TestReverseProxyModifyResponse(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
+ }))
+ defer backendServer.Close()
+
+ rpURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(resp *http.Response) error {
+ if resp.Header.Get("X-Hit-Mod") != "true" {
+ return fmt.Errorf("tried to by-pass proxy")
+ }
+ return nil
+ }
+
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ tests := []struct {
+ url string
+ wantCode int
+ }{
+ {frontendProxy.URL + "/mod", http.StatusOK},
+ {frontendProxy.URL + "/schedule", http.StatusBadGateway},
+ }
+
+ for i, tt := range tests {
+ resp, err := http.Get(tt.url)
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
+ }
+ resp.Body.Close()
+ }
+}
+
+type failingRoundTripper struct{}
+
+func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return nil, errors.New("some error")
+}
+
+type staticResponseRoundTripper struct{ res *http.Response }
+
+func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return rt.res, nil
+}
+
+func TestReverseProxyErrorHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ wantCode int
+ errorHandler func(http.ResponseWriter, *http.Request, error)
+ transport http.RoundTripper // defaults to failingRoundTripper
+ modifyResponse func(*http.Response) error
+ }{
+ {
+ name: "default",
+ wantCode: http.StatusBadGateway,
+ },
+ {
+ name: "errorhandler",
+ wantCode: http.StatusTeapot,
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ },
+ {
+ name: "modifyresponse_noerr",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return nil
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: 346,
+ },
+ {
+ name: "modifyresponse_err",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return errors.New("some error to trigger errorHandler")
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: http.StatusTeapot,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ target := &url.URL{
+ Scheme: "http",
+ Host: "dummy.tld",
+ Path: "/",
+ }
+ rproxy := NewSingleHostReverseProxy(target)
+ rproxy.Transport = tt.transport
+ rproxy.ModifyResponse = tt.modifyResponse
+ if rproxy.Transport == nil {
+ rproxy.Transport = failingRoundTripper{}
+ }
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ if tt.errorHandler != nil {
+ rproxy.ErrorHandler = tt.errorHandler
+ }
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL + "/test")
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ resp.Body.Close()
+ })
+ }
+}
+
+// Issue 16659: log errors from short read
+func TestReverseProxy_CopyBuffer(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.UnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var proxyLog bytes.Buffer
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
+ donec := make(chan bool, 1)
+ frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() { donec <- true }()
+ rproxy.ServeHTTP(w, r)
+ }))
+ defer frontendProxy.Close()
+
+ if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
+ t.Fatalf("want non-nil error")
+ }
+ // The race detector complains about the proxyLog usage in logf in copyBuffer
+ // and our usage below with proxyLog.Bytes() so we're explicitly using a
+ // channel to ensure that the ReverseProxy's ServeHTTP is done before we
+ // continue after Get.
+ <-donec
+
+ expected := []string{
+ "EOF",
+ "read",
+ }
+ for _, phrase := range expected {
+ if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
+ t.Errorf("expected log to contain phrase %q", phrase)
+ }
+ }
+}
+
+type staticTransport struct {
+ res *http.Response
+}
+
+func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+ return t.res, nil
+}
+
+func BenchmarkServeHTTP(b *testing.B) {
+ res := &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader("")),
+ }
+ proxy := &ReverseProxy{
+ Director: func(*http.Request) {},
+ Transport: &staticTransport{res},
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("GET", "/", nil)
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ proxy.ServeHTTP(w, r)
+ }
+}
+
+func TestServeHTTPDeepCopy(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("Hello Gopher!"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type result struct {
+ before, after string
+ }
+
+ resultChan := make(chan result, 1)
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ before := r.URL.String()
+ proxyHandler.ServeHTTP(w, r)
+ after := r.URL.String()
+ resultChan <- result{before: before, after: after}
+ }))
+ defer frontend.Close()
+
+ want := result{before: "/", after: "/"}
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ res.Body.Close()
+
+ got := <-resultChan
+ if got != want {
+ t.Errorf("got = %+v; want = %+v", got, want)
+ }
+}
+
+// Issue 18327: verify we always do a deep copy of the Request.Header map
+// before any mutations.
+func TestClonesRequestHeaders(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {
+ req.Header.Set("From-Director", "1")
+ },
+ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if v := req.Header.Get("From-Director"); v != "1" {
+ t.Errorf("From-Directory value = %q; want 1", v)
+ }
+ return nil, io.EOF
+ }),
+ }
+ rp.ServeHTTP(httptest.NewRecorder(), req)
+
+ if req.Header.Get("From-Director") == "1" {
+ t.Error("Director header mutation modified caller's request")
+ }
+ if req.Header.Get("X-Forwarded-For") != "" {
+ t.Error("X-Forward-For header mutation modified caller's request")
+ }
+
+}
+
+type roundTripperFunc func(req *http.Request) (*http.Response, error)
+
+func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+func TestModifyResponseClosesBody(t *testing.T) {
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ closeCheck := new(checkCloser)
+ logBuf := new(bytes.Buffer)
+ outErr := errors.New("ModifyResponse error")
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {},
+ Transport: &staticTransport{&http.Response{
+ StatusCode: 200,
+ Body: closeCheck,
+ }},
+ ErrorLog: log.New(logBuf, "", 0),
+ ModifyResponse: func(*http.Response) error {
+ return outErr
+ },
+ }
+ rec := httptest.NewRecorder()
+ rp.ServeHTTP(rec, req)
+ res := rec.Result()
+ if g, e := res.StatusCode, http.StatusBadGateway; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if !closeCheck.closed {
+ t.Errorf("body should have been closed")
+ }
+ if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
+ t.Errorf("ErrorLog %q does not contain %q", g, e)
+ }
+}
+
+type checkCloser struct {
+ closed bool
+}
+
+func (cc *checkCloser) Close() error {
+ cc.closed = true
+ return nil
+}
+
+func (cc *checkCloser) Read(b []byte) (int, error) {
+ return len(b), nil
+}
+
+// Issue 23643: panic on body copy error
+func TestReverseProxy_PanicBodyError(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rproxy := NewSingleHostReverseProxy(rpURL)
+
+ // Ensure that the handler panics when the body read encounters an
+ // io.ErrUnexpectedEOF
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatal("handler should have panicked")
+ }
+ if err != http.ErrAbortHandler {
+ t.Fatal("expected ErrAbortHandler, got", err)
+ }
+ }()
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ rproxy.ServeHTTP(httptest.NewRecorder(), req)
+}
+
+// Issue #46866: panic without closing incoming request body causes a panic
+func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := frontendClient.Transport.RoundTrip(req)
+ if resp != nil {
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestSelectFlushInterval(t *testing.T) {
+ tests := []struct {
+ name string
+ p *ReverseProxy
+ res *http.Response
+ want time.Duration
+ }{
+ {
+ name: "default",
+ res: &http.Response{},
+ p: &ReverseProxy{FlushInterval: 123},
+ want: 123,
+ },
+ {
+ name: "server-sent events overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides non-zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.flushInterval(tt.res)
+ if got != tt.want {
+ t.Errorf("flushLatency = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestReverseProxyWebSocket(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if upgradeType(r.Header) != "websocket" {
+ t.Error("unexpected backend request")
+ http.Error(w, "unexpected request", 400)
+ return
+ }
+ c, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
+ bs := bufio.NewScanner(c)
+ if !bs.Scan() {
+ t.Errorf("backend failed to read line from client: %v", bs.Err())
+ return
+ }
+ fmt.Fprintf(c, "backend got %q\n", bs.Text())
+ }))
+ defer backendServer.Close()
+
+ backURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(backURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ rproxy.ServeHTTP(rw, req)
+ if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
+ t.Errorf("response writer X-Modified header = %q; want %q", got, want)
+ }
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ c := frontendProxy.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("status = %v; want 101", res.Status)
+ }
+
+ got := res.Header.Get("X-Header")
+ want := "X-Value"
+ if got != want {
+ t.Errorf("Header(XHeader) = %q; want %q", got, want)
+ }
+
+ if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
+ t.Fatalf("not websocket upgrade; got %#v", res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
+ }
+ defer rwc.Close()
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ io.WriteString(rwc, "Hello\n")
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("Scan: %v", bs.Err())
+ }
+ got = bs.Text()
+ want = `backend got "Hello"`
+ if got != want {
+ t.Errorf("got %#q, want %#q", got, want)
+ }
+}
+
+func TestReverseProxyWebSocketCancellation(t *testing.T) {
+ n := 5
+ triggerCancelCh := make(chan bool, n)
+ nthResponse := func(i int) string {
+ return fmt.Sprintf("backend response #%d\n", i)
+ }
+ terminalMsg := "final message"
+
+ cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+ t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
+ http.Error(w, "Unexpected request", 400)
+ return
+ }
+ conn, bufrw, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+ if _, err := io.WriteString(conn, upgradeMsg); err != nil {
+ t.Error(err)
+ return
+ }
+ if _, _, err := bufrw.ReadLine(); err != nil {
+ t.Errorf("Failed to read line from client: %v", err)
+ return
+ }
+
+ for i := 0; i < n; i++ {
+ if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Writing response #%d failed: %v", i, err)
+ }
+ return
+ }
+ bufrw.Flush()
+ time.Sleep(time.Second)
+ }
+ if _, err := bufrw.WriteString(terminalMsg); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Failed to write terminal message: %v", err)
+ }
+ }
+ bufrw.Flush()
+ }))
+ defer cst.Close()
+
+ backendURL, _ := url.Parse(cst.URL)
+ rproxy := NewSingleHostReverseProxy(backendURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ ctx, cancel := context.WithCancel(req.Context())
+ go func() {
+ <-triggerCancelCh
+ cancel()
+ }()
+ rproxy.ServeHTTP(rw, req.WithContext(ctx))
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ res, err := frontendProxy.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Dialing to frontend proxy: %v", err)
+ }
+ defer res.Body.Close()
+ if g, w := res.StatusCode, 101; g != w {
+ t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
+ }
+
+ if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
+ t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
+ t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
+ }
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
+ t.Fatalf("Failed to write first message: %v", err)
+ }
+
+ // Read loop.
+
+ br := bufio.NewReader(rwc)
+ for {
+ line, err := br.ReadString('\n')
+ switch {
+ case line == terminalMsg: // this case before "err == io.EOF"
+ t.Fatalf("The websocket request was not canceled, unfortunately!")
+
+ case err == io.EOF:
+ return
+
+ case err != nil:
+ t.Fatalf("Unexpected error: %v", err)
+
+ case line == nthResponse(0): // We've gotten the first response back
+ // Let's trigger a cancel.
+ close(triggerCancelCh)
+ }
+ }
+}
+
+func TestUnannouncedTrailer(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.(http.Flusher).Flush()
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ res, err := frontendClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ io.ReadAll(res.Body)
+
+ if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
+ }
+
+}
+
+func TestSingleJoinSlash(t *testing.T) {
+ tests := []struct {
+ slasha string
+ slashb string
+ expected string
+ }{
+ {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "", "https://www.google.com/"},
+ {"", "favicon.ico", "/favicon.ico"},
+ }
+ for _, tt := range tests {
+ if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
+ t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
+ tt.slasha,
+ tt.slashb,
+ tt.expected,
+ got)
+ }
+ }
+}
+
+func TestJoinURLPath(t *testing.T) {
+ tests := []struct {
+ a *url.URL
+ b *url.URL
+ wantPath string
+ wantRaw string
+ }{
+ {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
+ {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
+ }
+
+ for _, tt := range tests {
+ p, rp := joinURLPath(tt.a, tt.b)
+ if p != tt.wantPath || rp != tt.wantRaw {
+ t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
+ tt.a.Path, tt.a.RawPath,
+ tt.b.Path, tt.b.RawPath,
+ tt.wantPath, tt.wantRaw,
+ p, rp)
+ }
+ }
+}
+
+const (
+ testWantsCleanQuery = true
+ testWantsRawQuery = false
+)
+
+func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ // Parsing the form causes ReverseProxy to remove unparsable
+ // query parameters before forwarding.
+ r.FormValue("a")
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
+ const content = "response_content"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.URL.RawQuery))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := newProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ // Don't spam output with logs of queries containing semicolons.
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+ frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ for _, test := range []struct {
+ rawQuery string
+ cleanQuery string
+ }{{
+ rawQuery: "a=1&a=2;b=3",
+ cleanQuery: "a=1",
+ }, {
+ rawQuery: "a=1&a=%zz&b=3",
+ cleanQuery: "a=1&b=3",
+ }} {
+ res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ wantQuery := test.rawQuery
+ if wantCleanQuery {
+ wantQuery = test.cleanQuery
+ }
+ if got, want := string(body), wantQuery; got != want {
+ t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
+ }
+ }
+}