Diffstat (limited to 'src/net/http/httputil')
-rw-r--r-- | src/net/http/httputil/dump.go              |  337
-rw-r--r-- | src/net/http/httputil/dump_test.go         |  532
-rw-r--r-- | src/net/http/httputil/example_test.go      |  128
-rw-r--r-- | src/net/http/httputil/httputil.go          |   41
-rw-r--r-- | src/net/http/httputil/persist.go           |  431
-rw-r--r-- | src/net/http/httputil/reverseproxy.go      |  831
-rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 1863
7 files changed, 4163 insertions, 0 deletions
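For orientation before the diff body: dump.go adds the request/response dump helpers, httputil.go the chunked-transfer wrappers, persist.go the deprecated ServerConn/ClientConn types, and reverseproxy.go the ReverseProxy handler, plus their tests. As an illustration only (not part of the patch), a minimal client-side sketch of the dump helper added below; the URL is assumed, and DumpRequestOut records the wire format without opening a real network connection:

package main

import (
    "fmt"
    "log"
    "net/http"
    "net/http/httputil"
    "strings"
)

func main() {
    // DumpRequestOut records what the Transport would write on the wire,
    // including the headers the Transport itself adds.
    req, err := http.NewRequest("POST", "http://example.com/upload", strings.NewReader("hello"))
    if err != nil {
        log.Fatal(err)
    }
    dump, err := httputil.DumpRequestOut(req, true)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%s", dump)
}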
diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go new file mode 100644 index 0000000..2edb9bc --- /dev/null +++ b/src/net/http/httputil/dump.go @@ -0,0 +1,337 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + "time" +) + +// drainBody reads all of b to memory and then returns two equivalent +// ReadClosers yielding the same bytes. +// +// It returns an error if the initial slurp of all bytes fails. It does not attempt +// to make the returned ReadClosers have identical error-matching behavior. +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == nil || b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return http.NoBody, http.NoBody, nil + } + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, b, err + } + if err = b.Close(); err != nil { + return nil, b, err + } + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil +} + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +// outgoingLength is a copy of the unexported +// (*http.Request).outgoingLength method. +func outgoingLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + +// DumpRequestOut is like [DumpRequest] but for outgoing client requests. It +// includes any headers that the standard [http.Transport] adds, such as +// User-Agent. +func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { + save := req.Body + dummyBody := false + if !body { + contentLength := outgoingLength(req) + if contentLength != 0 { + req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) + dummyBody = true + } + } else { + var err error + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + // Since we're using the actual Transport code to write the request, + // switch to http so the Transport doesn't try to do an SSL + // negotiation with our dumpConn and its bytes.Buffer & pipe. + // The wire format for https and http are the same, anyway. + reqSend := req + if req.URL.Scheme == "https" { + reqSend = new(http.Request) + *reqSend = *req + reqSend.URL = new(url.URL) + *reqSend.URL = *req.URL + reqSend.URL.Scheme = "http" + } + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. 
+ var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &http.Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + defer t.CloseIdleConnections() + + // We need this channel to ensure that the reader + // goroutine exits if t.RoundTrip returns an error. + // See golang.org/issue/32571. + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + req, err := http.ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(io.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case <-quitReadCh: + // Ensure delegateReader.Read doesn't block forever if we get an error. + close(dr.c) + } + }() + + _, err := t.RoundTrip(reqSend) + + req.Body = save + if err != nil { + pw.Close() + dr.err = err + close(quitReadCh) + return nil, err + } + dump := buf.Bytes() + + // If we used a dummy body above, remove it now. + // TODO: if the req.ContentLength is large, we allocate memory + // unnecessarily just to slice it off here. But this is just + // a debug function, so this is acceptable for now. We could + // discard the body earlier if this matters. + if dummyBody { + if i := bytes.Index(dump, []byte("\r\n\r\n")); i >= 0 { + dump = dump[:i+4] + } + } + return dump, nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + err error // only used if r is nil and c is closed. + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, r.err + } + } + return r.r.Read(p) +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} + +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Transfer-Encoding": true, + "Trailer": true, +} + +// DumpRequest returns the given request in its HTTP/1.x wire +// representation. It should only be used by servers to debug client +// requests. The returned representation is an approximation only; +// some details of the initial request are lost while parsing it into +// an [http.Request]. In particular, the order and case of header field +// names are lost. The order of values in multi-valued headers is kept +// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their +// original binary representations. +// +// If body is true, DumpRequest also returns the body. To do so, it +// consumes req.Body and then replaces it with a new [io.ReadCloser] +// that yields the same bytes. If DumpRequest returns an error, +// the state of req is undefined. +// +// The documentation for [http.Request.Write] details which fields +// of req are included in the dump. 
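(Illustration only, not part of the patch: a minimal server-side sketch of the DumpRequest behavior described above, inside an ordinary handler; the route and listen address are assumptions.)

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
)

func main() {
    http.HandleFunc("/debug-echo", func(w http.ResponseWriter, r *http.Request) {
        // DumpRequest consumes r.Body and replaces it with an equivalent
        // reader, so the body can still be read after dumping.
        dump, err := httputil.DumpRequest(r, true)
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
            return
        }
        log.Printf("incoming request:\n%s", dump)
        w.Write(dump)
    })
    log.Fatal(http.ListenAndServe("localhost:8080", nil))
}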
+func DumpRequest(req *http.Request, body bool) ([]byte, error) { + var err error + save := req.Body + if !body || req.Body == nil { + req.Body = nil + } else { + save, req.Body, err = drainBody(req.Body) + if err != nil { + return nil, err + } + } + + var b bytes.Buffer + + // By default, print out the unmodified req.RequestURI, which + // is always set for incoming server requests. But because we + // previously used req.URL.RequestURI and the docs weren't + // always so clear about when to use DumpRequest vs + // DumpRequestOut, fall back to the old way if the caller + // provides a non-server Request. + reqURI := req.RequestURI + if reqURI == "" { + reqURI = req.URL.RequestURI() + } + + fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), + reqURI, req.ProtoMajor, req.ProtoMinor) + + absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://") + if !absRequestURI { + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(&b, "Host: %s\r\n", host) + } + } + + chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked" + if len(req.TransferEncoding) > 0 { + fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ",")) + } + + err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) + if err != nil { + return nil, err + } + + io.WriteString(&b, "\r\n") + + if req.Body != nil { + var dest io.Writer = &b + if chunked { + dest = NewChunkedWriter(dest) + } + _, err = io.Copy(dest, req.Body) + if chunked { + dest.(io.Closer).Close() + io.WriteString(&b, "\r\n") + } + } + + req.Body = save + if err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// errNoBody is a sentinel error value used by failureToReadBody so we +// can detect that the lack of body was intentional. +var errNoBody = errors.New("sentinel error value") + +// failureToReadBody is an io.ReadCloser that just returns errNoBody on +// Read. It's swapped in when we don't actually want to consume +// the body, but need a non-nil one, and want to distinguish the +// error from reading the dummy body. +type failureToReadBody struct{} + +func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } +func (failureToReadBody) Close() error { return nil } + +// emptyBody is an instance of empty reader. +var emptyBody = io.NopCloser(strings.NewReader("")) + +// DumpResponse is like DumpRequest but dumps a response. +func DumpResponse(resp *http.Response, body bool) ([]byte, error) { + var b bytes.Buffer + var err error + save := resp.Body + savecl := resp.ContentLength + + if !body { + // For content length of zero. Make sure the body is an empty + // reader, instead of returning error through failureToReadBody{}. + if resp.ContentLength == 0 { + resp.Body = emptyBody + } else { + resp.Body = failureToReadBody{} + } + } else if resp.Body == nil { + resp.Body = emptyBody + } else { + save, resp.Body, err = drainBody(resp.Body) + if err != nil { + return nil, err + } + } + err = resp.Write(&b) + if err == errNoBody { + err = nil + } + resp.Body = save + resp.ContentLength = savecl + if err != nil { + return nil, err + } + return b.Bytes(), nil +} diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go new file mode 100644 index 0000000..c20c054 --- /dev/null +++ b/src/net/http/httputil/dump_test.go @@ -0,0 +1,532 @@ +// Copyright 2011 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "math/rand" + "net/http" + "net/url" + "runtime" + "runtime/pprof" + "strings" + "testing" + "time" +) + +type eofReader struct{} + +func (n eofReader) Close() error { return nil } + +func (n eofReader) Read([]byte) (int, error) { return 0, io.EOF } + +type dumpTest struct { + // Either Req or GetReq can be set/nil but not both. + Req *http.Request + GetReq func() *http.Request + + Body any // optional []byte or func() io.ReadCloser to populate Req.Body + + WantDump string + WantDumpOut string + MustError bool // if true, the test is expected to throw an error + NoBody bool // if true, set DumpRequest{,Out} body to false +} + +var dumpTests = []dumpTest{ + // HTTP/1.1 => chunked coding; body; empty trailer + { + Req: &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + TransferEncoding: []string{"chunked"}, + }, + + Body: []byte("abcdef"), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: &http.Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: http.Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, + + { + Req: mustNewRequest("GET", "http://example.com/foo", nil), + + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Test that an https URL doesn't try to do an SSL negotiation + // with a bytes.Buffer and hang with all goroutines not + // runnable. + { + Req: mustNewRequest("GET", "https://example.com/foo", nil), + WantDumpOut: "GET /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Request with Body, but Dump requested without it. 
+ { + Req: &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + ContentLength: 6, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: []byte("abcdef"), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 6\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + + NoBody: true, + }, + + // Request with Body > 8196 (default buffer size) + { + Req: &http.Request{ + Method: "POST", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/", + }, + Header: http.Header{ + "Content-Length": []string{"8193"}, + }, + + ContentLength: 8193, + ProtoMajor: 1, + ProtoMinor: 1, + }, + + Body: bytes.Repeat([]byte("a"), 8193), + + WantDumpOut: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 8193\r\n" + + "Accept-Encoding: gzip\r\n\r\n" + + strings.Repeat("a", 8193), + WantDump: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "Content-Length: 8193\r\n\r\n" + + strings.Repeat("a", 8193), + }, + + { + GetReq: func() *http.Request { + return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n") + }, + NoBody: true, + WantDump: "GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n", + }, + + // Issue #7215. DumpRequest should return the "Content-Length" when set + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey", + }, + // Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n\r\n", + }, + + // Issue #7215. DumpRequest should not return the "Content-Length" if unset + { + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "\r\nkey1=name1&key2=name2") + }, + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n\r\n", + }, + + // Issue 18506: make drainBody recognize NoBody. Otherwise + // this was turning into a chunked request. + { + Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody), + WantDumpOut: "POST /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Issue 34504: a non-nil Body without ContentLength set should be chunked + { + Req: &http.Request{ + Method: "PUT", + URL: &url.URL{ + Scheme: "http", + Host: "post.tld", + Path: "/test", + }, + ContentLength: 0, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: &eofReader{}, + }, + NoBody: true, + WantDumpOut: "PUT /test HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, + + // Issue 54616: request with Connection header doesn't result in duplicate header. 
+ { + GetReq: func() *http.Request { + return mustReadRequest("GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n\r\n") + }, + NoBody: true, + WantDump: "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n\r\n", + }, +} + +func TestDumpRequest(t *testing.T) { + // Make a copy of dumpTests and add 10 new cases with an empty URL + // to test that no goroutines are leaked. See golang.org/issue/32571. + // 10 seems to be a decent number which always triggers the failure. + dumpTests := dumpTests[:] + for i := 0; i < 10; i++ { + dumpTests = append(dumpTests, dumpTest{ + Req: mustNewRequest("GET", "", nil), + MustError: true, + }) + } + numg0 := runtime.NumGoroutine() + for i, tt := range dumpTests { + if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil { + t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq) + continue + } + + freshReq := func(ti dumpTest) *http.Request { + req := ti.Req + if req == nil { + req = ti.GetReq() + } + + if req.Header == nil { + req.Header = make(http.Header) + } + + if ti.Body == nil { + return req + } + switch b := ti.Body.(type) { + case []byte: + req.Body = io.NopCloser(bytes.NewReader(b)) + case func() io.ReadCloser: + req.Body = b() + default: + t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body) + } + return req + } + + if tt.WantDump != "" { + req := freshReq(tt) + dump, err := DumpRequest(req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } + } + + if tt.MustError { + req := freshReq(tt) + _, err := DumpRequestOut(req, !tt.NoBody) + if err == nil { + t.Errorf("DumpRequestOut #%d: expected an error, got nil", i) + } + continue + } + + if tt.WantDumpOut != "" { + req := freshReq(tt) + dump, err := DumpRequestOut(req, !tt.NoBody) + if err != nil { + t.Errorf("DumpRequestOut #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDumpOut { + t.Errorf("DumpRequestOut %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDumpOut, string(dump)) + continue + } + } + } + + // Validate we haven't leaked any goroutines. + var dg int + dl := deadline(t, 5*time.Second, time.Second) + for time.Now().Before(dl) { + if dg = runtime.NumGoroutine() - numg0; dg <= 4 { + // No unexpected goroutines. + return + } + + // Allow goroutines to schedule and die off. + runtime.Gosched() + } + + buf := make([]byte, 4096) + buf = buf[:runtime.Stack(buf, true)] + t.Errorf("Unexpectedly large number of new goroutines: %d new: %s", dg, buf) +} + +// deadline returns the time which is needed before t.Deadline() +// if one is configured and it is s greater than needed in the future, +// otherwise defaultDelay from the current time. +func deadline(t *testing.T, defaultDelay, needed time.Duration) time.Time { + if dl, ok := t.Deadline(); ok { + if dl = dl.Add(-needed); dl.After(time.Now()) { + // Allow an arbitrarily long delay. + return dl + } + } + + // No deadline configured or its closer than needed from now + // so just use the default. 
+ return time.Now().Add(defaultDelay) +} + +func chunk(s string) string { + return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} + +func mustNewRequest(method, url string, body io.Reader) *http.Request { + req, err := http.NewRequest(method, url, body) + if err != nil { + panic(fmt.Sprintf("NewRequest(%q, %q, %p) err = %v", method, url, body, err)) + } + return req +} + +func mustReadRequest(s string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s))) + if err != nil { + panic(err) + } + return req +} + +var dumpResTests = []struct { + res *http.Response + body bool + want string +}{ + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 50, + Header: http.Header{ + "Foo": []string{"Bar"}, + }, + Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used + }, + body: false, // to verify we see 50, not empty or 3. + want: `HTTP/1.1 200 OK +Content-Length: 50 +Foo: Bar`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 3, + Body: io.NopCloser(strings.NewReader("foo")), + }, + body: true, + want: `HTTP/1.1 200 OK +Content-Length: 3 + +foo`, + }, + + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: -1, + Body: io.NopCloser(strings.NewReader("foo")), + TransferEncoding: []string{"chunked"}, + }, + body: true, + want: `HTTP/1.1 200 OK +Transfer-Encoding: chunked + +3 +foo +0`, + }, + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, + Header: http.Header{ + // To verify if headers are not filtered out. + "Foo1": []string{"Bar1"}, + "Foo2": []string{"Bar2"}, + }, + Body: nil, + }, + body: false, // to verify we see 0, not empty. + want: `HTTP/1.1 200 OK +Foo1: Bar1 +Foo2: Bar2 +Content-Length: 0`, + }, +} + +func TestDumpResponse(t *testing.T) { + for i, tt := range dumpResTests { + gotb, err := DumpResponse(tt.res, tt.body) + if err != nil { + t.Errorf("%d. DumpResponse = %v", i, err) + continue + } + got := string(gotb) + got = strings.TrimSpace(got) + got = strings.ReplaceAll(got, "\r", "") + + if got != tt.want { + t.Errorf("%d.\nDumpResponse got:\n%s\n\nWant:\n%s\n", i, got, tt.want) + } + } +} + +// Issue 38352: Check for deadlock on canceled requests. +func TestDumpRequestOutIssue38352(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + + timeout := 10 * time.Second + if deadline, ok := t.Deadline(); ok { + timeout = time.Until(deadline) + timeout -= time.Second * 2 // Leave 2 seconds to report failures. 
+ } + for i := 0; i < 1000; i++ { + delay := time.Duration(rand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + out := make(chan error) + go func() { + _, err = DumpRequestOut(req, true) + out <- err + }() + + select { + case <-out: + case <-time.After(timeout): + b := &strings.Builder{} + fmt.Fprintf(b, "deadlock detected on iteration %d after %s with delay: %v\n", i, timeout, delay) + pprof.Lookup("goroutine").WriteTo(b, 1) + t.Fatal(b.String()) + } + } +} diff --git a/src/net/http/httputil/example_test.go b/src/net/http/httputil/example_test.go new file mode 100644 index 0000000..6c107f8 --- /dev/null +++ b/src/net/http/httputil/example_test.go @@ -0,0 +1,128 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil_test + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" +) + +func ExampleDumpRequest() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dump, err := httputil.DumpRequest(r, true) + if err != nil { + http.Error(w, fmt.Sprint(err), http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "%q", dump) + })) + defer ts.Close() + + const body = "Go is a general-purpose language designed with systems programming in mind." + req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + req.Host = "www.example.org" + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nContent-Length: 75\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpRequestOut() { + const body = "Go is a general-purpose language designed with systems programming in mind." + req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body)) + if err != nil { + log.Fatal(err) + } + + dump, err := httputil.DumpRequestOut(req, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind." +} + +func ExampleDumpResponse() { + const body = "Go is a general-purpose language designed with systems programming in mind." 
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT") + fmt.Fprintln(w, body) + })) + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%q", dump) + + // Output: + // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n" +} + +func ExampleReverseProxy() { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "this call was relayed by the reverse proxy") + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + log.Fatal(err) + } + frontendProxy := httptest.NewServer(&httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetXForwarded() + r.SetURL(rpURL) + }, + }) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + log.Fatal(err) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // this call was relayed by the reverse proxy +} diff --git a/src/net/http/httputil/httputil.go b/src/net/http/httputil/httputil.go new file mode 100644 index 0000000..431930e --- /dev/null +++ b/src/net/http/httputil/httputil.go @@ -0,0 +1,41 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httputil provides HTTP utility functions, complementing the +// more common ones in the net/http package. +package httputil + +import ( + "io" + "net/http/internal" +) + +// NewChunkedReader returns a new chunkedReader that translates the data read from r +// out of HTTP "chunked" format before returning it. +// The chunkedReader returns [io.EOF] when the final 0-length chunk is read. +// +// NewChunkedReader is not needed by normal applications. The http package +// automatically decodes chunking when reading response bodies. +func NewChunkedReader(r io.Reader) io.Reader { + return internal.NewChunkedReader(r) +} + +// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream but does +// not send the final CRLF that appears after trailers; trailers and the last +// CRLF must be written separately. +// +// NewChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using NewChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func NewChunkedWriter(w io.Writer) io.WriteCloser { + return internal.NewChunkedWriter(w) +} + +// ErrLineTooLong is returned when reading malformed chunked data +// with lines that are too long. +var ErrLineTooLong = internal.ErrLineTooLong diff --git a/src/net/http/httputil/persist.go b/src/net/http/httputil/persist.go new file mode 100644 index 0000000..0cbe3eb --- /dev/null +++ b/src/net/http/httputil/persist.go @@ -0,0 +1,431 @@ +// Copyright 2009 The Go Authors. 
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httputil + +import ( + "bufio" + "errors" + "io" + "net" + "net/http" + "net/textproto" + "sync" +) + +var ( + // Deprecated: No longer used. + ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} + + // Deprecated: No longer used. + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + + // Deprecated: No longer used. + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} +) + +// This is an API usage error - the local side is closed. +// ErrPersistEOF (above) reports that the remote side is closed. +var errClosed = errors.New("i/o operation on closed connection") + +// ServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Server in package [net/http] instead. +type ServerConn struct { + mu sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline +} + +// NewServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Server in package [net/http] instead. +func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ServerConn{c: c, r: r, pipereq: make(map[*http.Request]uint)} +} + +// Hijack detaches the [ServerConn] and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before Read has signaled the end of the keep-alive logic. The user +// should not call Hijack while [ServerConn.Read] or [ServerConn.Write] is in progress. +func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) { + sc.mu.Lock() + defer sc.mu.Unlock() + c := sc.c + r := sc.r + sc.c = nil + sc.r = nil + return c, r +} + +// Close calls [ServerConn.Hijack] and then also closes the underlying connection. +func (sc *ServerConn) Close() error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Read returns the next request on the wire. An [ErrPersistEOF] is returned if +// it is gracefully determined that there are no more requests (e.g. after the +// first request on an HTTP/1.0 connection, or after a Connection:close on a +// HTTP/1.1 connection). 
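(Illustration only, not part of the patch: Read and Write below keep pipelined request/response pairs ordered with a textproto.Pipeline; this stripped-down sketch shows just that ordering mechanism, with the request-handling step left as a comment.)

package main

import (
    "fmt"
    "net/textproto"
    "sync"
)

func main() {
    var pipe textproto.Pipeline
    var wg sync.WaitGroup
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            id := pipe.Next()      // claim the next pipeline slot
            pipe.StartRequest(id)  // requests proceed in slot order
            // ... a real server would read request #id here ...
            pipe.EndRequest(id)
            pipe.StartResponse(id) // responses are forced into the same order
            fmt.Println("answering slot", id)
            pipe.EndResponse(id)
        }()
    }
    wg.Wait()
}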
+func (sc *ServerConn) Read() (*http.Request, error) { + var req *http.Request + var err error + + // Ensure ordered execution of Reads and Writes + id := sc.pipe.Next() + sc.pipe.StartRequest(id) + defer func() { + sc.pipe.EndRequest(id) + if req == nil { + sc.pipe.StartResponse(id) + sc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + sc.mu.Lock() + sc.pipereq[req] = id + sc.mu.Unlock() + } + }() + + sc.mu.Lock() + if sc.we != nil { // no point receiving if write-side broken or closed + defer sc.mu.Unlock() + return nil, sc.we + } + if sc.re != nil { + defer sc.mu.Unlock() + return nil, sc.re + } + if sc.r == nil { // connection closed by user in the meantime + defer sc.mu.Unlock() + return nil, errClosed + } + r := sc.r + lastbody := sc.lastbody + sc.lastbody = nil + sc.mu.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.re = err + return nil, err + } + } + + req, err = http.ReadRequest(r) + sc.mu.Lock() + defer sc.mu.Unlock() + if err != nil { + if err == io.ErrUnexpectedEOF { + // A close from the opposing client is treated as a + // graceful close, even if there was some unparse-able + // data before the close. + sc.re = ErrPersistEOF + return nil, sc.re + } else { + sc.re = err + return req, err + } + } + sc.lastbody = req.Body + sc.nread++ + if req.Close { + sc.re = ErrPersistEOF + return req, sc.re + } + return req, err +} + +// Pending returns the number of unanswered requests +// that have been received on the connection. +func (sc *ServerConn) Pending() int { + sc.mu.Lock() + defer sc.mu.Unlock() + return sc.nread - sc.nwritten +} + +// Write writes resp in response to req. To close the connection gracefully, set the +// Response.Close field to true. Write should be considered operational until +// it returns an error, regardless of any errors returned on the [ServerConn.Read] side. +func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { + + // Retrieve the pipeline ID of this request/response pair + sc.mu.Lock() + id, ok := sc.pipereq[req] + delete(sc.pipereq, req) + if !ok { + sc.mu.Unlock() + return ErrPipeline + } + sc.mu.Unlock() + + // Ensure pipeline order + sc.pipe.StartResponse(id) + defer sc.pipe.EndResponse(id) + + sc.mu.Lock() + if sc.we != nil { + defer sc.mu.Unlock() + return sc.we + } + if sc.c == nil { // connection closed by user in the meantime + defer sc.mu.Unlock() + return ErrClosed + } + c := sc.c + if sc.nread <= sc.nwritten { + defer sc.mu.Unlock() + return errors.New("persist server pipe count") + } + if resp.Close { + // After signaling a keep-alive close, any pipelined unread + // requests will be lost. It is up to the user to drain them + // before signaling. + sc.re = ErrPersistEOF + } + sc.mu.Unlock() + + err := resp.Write(c) + sc.mu.Lock() + defer sc.mu.Unlock() + if err != nil { + sc.we = err + return err + } + sc.nwritten++ + + return nil +} + +// ClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use Client or Transport in package [net/http] instead. 
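(Illustration only, not part of the patch: the replacement suggested by the deprecation notice above is a plain http.Client, optionally with an explicit Transport; a hedged sketch with an assumed URL and tuning values.)

package main

import (
    "io"
    "log"
    "net/http"
    "time"
)

func main() {
    client := &http.Client{
        Timeout: 10 * time.Second,
        Transport: &http.Transport{
            MaxIdleConnsPerHost: 4, // keep-alive reuse is handled by the Transport
        },
    }
    resp, err := client.Get("http://example.com/")
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    body, err := io.ReadAll(resp.Body)
    if err != nil {
        log.Fatal(err)
    }
    log.Printf("read %d bytes", len(body))
}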
+type ClientConn struct { + mu sync.Mutex // read-write protects the following fields + c net.Conn + r *bufio.Reader + re, we error // read/write errors + lastbody io.ReadCloser + nread, nwritten int + pipereq map[*http.Request]uint + + pipe textproto.Pipeline + writeReq func(*http.Request, io.Writer) error +} + +// NewClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Client or Transport in package [net/http] instead. +func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + if r == nil { + r = bufio.NewReader(c) + } + return &ClientConn{ + c: c, + r: r, + pipereq: make(map[*http.Request]uint), + writeReq: (*http.Request).Write, + } +} + +// NewProxyClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. +// +// Deprecated: Use the Client or Transport in package [net/http] instead. +func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { + cc := NewClientConn(c, r) + cc.writeReq = (*http.Request).WriteProxy + return cc +} + +// Hijack detaches the [ClientConn] and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be +// called before the user or Read have signaled the end of the keep-alive +// logic. The user should not call Hijack while [ClientConn.Read] or ClientConn.Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { + cc.mu.Lock() + defer cc.mu.Unlock() + c = cc.c + r = cc.r + cc.c = nil + cc.r = nil + return +} + +// Close calls [ClientConn.Hijack] and then also closes the underlying connection. +func (cc *ClientConn) Close() error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + +// Write writes a request. An [ErrPersistEOF] error is returned if the connection +// has been closed in an HTTP keep-alive sense. If req.Close equals true, the +// keep-alive connection is logically closed after this request and the opposing +// server is informed. An ErrUnexpectedEOF indicates the remote closed the +// underlying TCP connection, which is usually considered as graceful close. +func (cc *ClientConn) Write(req *http.Request) error { + var err error + + // Ensure ordered execution of Writes + id := cc.pipe.Next() + cc.pipe.StartRequest(id) + defer func() { + cc.pipe.EndRequest(id) + if err != nil { + cc.pipe.StartResponse(id) + cc.pipe.EndResponse(id) + } else { + // Remember the pipeline id of this request + cc.mu.Lock() + cc.pipereq[req] = id + cc.mu.Unlock() + } + }() + + cc.mu.Lock() + if cc.re != nil { // no point sending if read-side closed or broken + defer cc.mu.Unlock() + return cc.re + } + if cc.we != nil { + defer cc.mu.Unlock() + return cc.we + } + if cc.c == nil { // connection closed by user in the meantime + defer cc.mu.Unlock() + return errClosed + } + c := cc.c + if req.Close { + // We write the EOF to the write-side error, because there + // still might be some pipelined reads + cc.we = ErrPersistEOF + } + cc.mu.Unlock() + + err = cc.writeReq(req, c) + cc.mu.Lock() + defer cc.mu.Unlock() + if err != nil { + cc.we = err + return err + } + cc.nwritten++ + + return nil +} + +// Pending returns the number of unanswered requests +// that have been sent on the connection. 
+func (cc *ClientConn) Pending() int { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.nwritten - cc.nread +} + +// Read reads the next response from the wire. A valid response might be +// returned together with an [ErrPersistEOF], which means that the remote +// requested that this be the last request serviced. Read can be called +// concurrently with [ClientConn.Write], but not with another Read. +func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { + // Retrieve the pipeline ID of this request/response pair + cc.mu.Lock() + id, ok := cc.pipereq[req] + delete(cc.pipereq, req) + if !ok { + cc.mu.Unlock() + return nil, ErrPipeline + } + cc.mu.Unlock() + + // Ensure pipeline order + cc.pipe.StartResponse(id) + defer cc.pipe.EndResponse(id) + + cc.mu.Lock() + if cc.re != nil { + defer cc.mu.Unlock() + return nil, cc.re + } + if cc.r == nil { // connection closed by user in the meantime + defer cc.mu.Unlock() + return nil, errClosed + } + r := cc.r + lastbody := cc.lastbody + cc.lastbody = nil + cc.mu.Unlock() + + // Make sure body is fully consumed, even if user does not call body.Close + if lastbody != nil { + // body.Close is assumed to be idempotent and multiple calls to + // it should return the error that its first invocation + // returned. + err = lastbody.Close() + if err != nil { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.re = err + return nil, err + } + } + + resp, err = http.ReadResponse(r, req) + cc.mu.Lock() + defer cc.mu.Unlock() + if err != nil { + cc.re = err + return resp, err + } + cc.lastbody = resp.Body + + cc.nread++ + + if resp.Close { + cc.re = ErrPersistEOF // don't send any more requests + return resp, cc.re + } + return resp, err +} + +// Do is convenience method that writes a request and reads a response. +func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) { + err := cc.Write(req) + if err != nil { + return nil, err + } + return cc.Read(req) +} diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go new file mode 100644 index 0000000..5c70f0d --- /dev/null +++ b/src/net/http/httputil/reverseproxy.go @@ -0,0 +1,831 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package httputil + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/http/internal/ascii" + "net/textproto" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// A ProxyRequest contains a request to be rewritten by a [ReverseProxy]. +type ProxyRequest struct { + // In is the request received by the proxy. + // The Rewrite function must not modify In. + In *http.Request + + // Out is the request which will be sent by the proxy. + // The Rewrite function may modify or replace this request. + // Hop-by-hop headers are removed from this request + // before Rewrite is called. + Out *http.Request +} + +// SetURL routes the outbound request to the scheme, host, and base path +// provided in target. If the target's path is "/base" and the incoming +// request was for "/dir", the target request will be for "/base/dir". +// +// SetURL rewrites the outbound Host header to match the target's host. 
+// To preserve the inbound request's Host header (the default behavior +// of [NewSingleHostReverseProxy]): +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.SetURL(url) +// r.Out.Host = r.In.Host +// } +func (r *ProxyRequest) SetURL(target *url.URL) { + rewriteRequestURL(r.Out, target) + r.Out.Host = "" +} + +// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and +// X-Forwarded-Proto headers of the outbound request. +// +// - The X-Forwarded-For header is set to the client IP address. +// - The X-Forwarded-Host header is set to the host name requested +// by the client. +// - The X-Forwarded-Proto header is set to "http" or "https", depending +// on whether the inbound request was made on a TLS-enabled connection. +// +// If the outbound request contains an existing X-Forwarded-For header, +// SetXForwarded appends the client IP address to it. To append to the +// inbound request's X-Forwarded-For header (the default behavior of +// [ReverseProxy] when using a Director function), copy the header +// from the inbound request before calling SetXForwarded: +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] +// r.SetXForwarded() +// } +func (r *ProxyRequest) SetXForwarded() { + clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) + if err == nil { + prior := r.Out.Header["X-Forwarded-For"] + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + r.Out.Header.Set("X-Forwarded-For", clientIP) + } else { + r.Out.Header.Del("X-Forwarded-For") + } + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + if r.In.TLS == nil { + r.Out.Header.Set("X-Forwarded-Proto", "http") + } else { + r.Out.Header.Set("X-Forwarded-Proto", "https") + } +} + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. +type ReverseProxy struct { + // Rewrite must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Rewrite must not access the provided ProxyRequest + // or its contents after returning. + // + // The Forwarded, X-Forwarded, X-Forwarded-Host, + // and X-Forwarded-Proto headers are removed from the + // outbound request before Rewrite is called. See also + // the ProxyRequest.SetXForwarded method. + // + // Unparsable query parameters are removed from the + // outbound request before Rewrite is called. + // The Rewrite function may copy the inbound URL's + // RawQuery to the outbound URL to preserve the original + // parameter string. Note that this can lead to security + // issues if the proxy's interpretation of query parameters + // does not match that of the downstream server. + // + // At most one of Rewrite or Director may be set. + Rewrite func(*ProxyRequest) + + // Director is a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + // + // By default, the X-Forwarded-For header is set to the + // value of the client IP address. If an X-Forwarded-For + // header already exists, the client IP is appended to the + // existing values. 
As a special case, if the header + // exists in the Request.Header map but has a nil value + // (such as when set by the Director func), the X-Forwarded-For + // header is not modified. + // + // To prevent IP spoofing, be sure to delete any pre-existing + // X-Forwarded-For header coming from the client or + // an untrusted proxy. + // + // Hop-by-hop headers are removed from the request after + // Director returns, which can remove headers added by + // Director. Use a Rewrite function instead to ensure + // modifications to the request are preserved. + // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // + // At most one of Rewrite or Director may be set. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + // A negative value means to flush immediately + // after each write to the client. + // The FlushInterval is ignored when ReverseProxy + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by [io.CopyBuffer]. 
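(Illustration only, not part of the patch: the BufferPool interface below is commonly backed by a sync.Pool; a minimal sketch, where the 32 KiB size mirrors the default copy buffer and the backend/listen addresses are assumptions.)

package main

import (
    "log"
    "net/http"
    "net/http/httputil"
    "net/url"
    "sync"
)

// bufferPool adapts a sync.Pool of byte slices to the Get/Put shape
// that ReverseProxy.BufferPool expects.
type bufferPool struct{ pool sync.Pool }

func newBufferPool(size int) *bufferPool {
    return &bufferPool{pool: sync.Pool{
        New: func() any { return make([]byte, size) },
    }}
}

func (b *bufferPool) Get() []byte  { return b.pool.Get().([]byte) }
func (b *bufferPool) Put(p []byte) { b.pool.Put(p) }

func main() {
    target, _ := url.Parse("http://127.0.0.1:9000") // assumed backend
    proxy := httputil.NewSingleHostReverseProxy(target)
    proxy.BufferPool = newBufferPool(32 * 1024)
    log.Fatal(http.ListenAndServe("localhost:8080", proxy))
}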
+type BufferPool interface { + Get() []byte + Put([]byte) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +// NewSingleHostReverseProxy returns a new [ReverseProxy] that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// +// NewSingleHostReverseProxy does not rewrite the Host header. +// +// To customize the ReverseProxy behavior beyond what +// NewSingleHostReverseProxy provides, use ReverseProxy directly +// with a Rewrite function. The ProxyRequest SetURL method +// may be used to route the outbound request. (Note that SetURL, +// unlike NewSingleHostReverseProxy, rewrites the Host header +// of the outbound request by default.) +// +// proxy := &ReverseProxy{ +// Rewrite: func(r *ProxyRequest) { +// r.SetURL(target) +// r.Out.Host = r.In.Host // if desired +// }, +// } +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + director := func(req *http.Request) { + rewriteRequestURL(req, target) + } + return &ReverseProxy{Director: director} +} + +func rewriteRequestURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. 
google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if ctx.Done() != nil { + // CloseNotifier predates context.Context, and has been + // entirely superseded by it. If the request contains + // a Context that carries a cancellation signal, don't + // bother spinning up a goroutine to watch the CloseNotify + // channel (if any). + // + // If the request Context has a nil Done channel (which + // means it is either context.Background, or a custom + // Context implementation with no cancellation signal), + // then consult the CloseNotifier if available. + } else if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + if (p.Director != nil) == (p.Rewrite != nil) { + p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set")) + return + } + + if p.Director != nil { + p.Director(outreq) + if outreq.Form != nil { + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + } + } + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + if !ascii.IsPrint(reqUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + removeHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) 
Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if p.Rewrite != nil { + // Strip client-provided forwarding headers. + // The Rewrite func may use SetXForwarded to set new values + // for these or copy the previous values from the inbound request. + outreq.Header.Del("Forwarded") + outreq.Header.Del("X-Forwarded-For") + outreq.Header.Del("X-Forwarded-Host") + outreq.Header.Del("X-Forwarded-Proto") + + // Remove unparsable query parameters from the outbound request. + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + + pr := &ProxyRequest{ + In: req, + Out: outreq, + } + p.Rewrite(pr) + outreq = pr.Out + } else { + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } + } + } + + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + clear(h) + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeHopByHopHeaders(res.Header) + + if !p.modifyResponse(rw, res, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) + if err != nil { + defer res.Body.Close() + // Since we're streaming the response, if we run into an error all we can do + // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // on read error while copying body. 
+ if !shouldPanicOnCopyError(req) { + p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return + } + panic(http.ErrAbortHandler) + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + http.NewResponseController(rw).Flush() + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +var inOurTests bool // whether we're in our own tests + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +func shouldPanicOnCopyError(req *http.Request) bool { + if inOurTests { + // Our tests know to handle this panic. + return true + } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +// removeHopByHopHeaders removes hop-by-hop headers. +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { + return -1 // negative means immediately + } + + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + + return p.FlushInterval +} + +func (p *ReverseProxy) copyResponse(dst http.ResponseWriter, src io.Reader, flushInterval time.Duration) error { + var w io.Writer = dst + + if flushInterval != 0 { + mlw := &maxLatencyWriter{ + dst: dst, + flush: http.NewResponseController(dst).Flush, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + w = mlw + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + defer p.BufferPool.Put(buf) + } + _, err := p.copyBuffer(w, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. 
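// A consumer-side sketch of the two knobs used by the copy path above; the
// backend address, port, and buffer size are assumptions for illustration.
// FlushInterval is turned into a maxLatencyWriter by copyResponse, and
// BufferPool supplies the scratch buffer handed to copyBuffer below. A negative
// interval flushes after every write; flushInterval forces that mode for
// text/event-stream and unknown-length responses regardless of this setting.

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sync"
	"time"
)

// bufPool adapts a sync.Pool to the httputil.BufferPool interface.
type bufPool struct{ p sync.Pool }

func (b *bufPool) Get() []byte  { return b.p.Get().([]byte) }
func (b *bufPool) Put(v []byte) { b.p.Put(v) }

func main() {
	backend, err := url.Parse("http://127.0.0.1:9000") // hypothetical backend
	if err != nil {
		log.Fatal(err)
	}
	proxy := httputil.NewSingleHostReverseProxy(backend)
	proxy.FlushInterval = 100 * time.Millisecond // flush the buffered body periodically
	proxy.BufferPool = &bufPool{p: sync.Pool{
		New: func() any { return make([]byte, 32*1024) },
	}}
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", proxy))
}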
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...any) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type maxLatencyWriter struct { + dst io.Writer + flush func() error + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if !ascii.IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) + } + if !ascii.EqualFold(reqUpType, resUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + + rc := http.NewResponseController(rw) + conn, brw, hijackErr := rc.Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. 
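		// The select below returns either when the inbound request's context is
		// canceled or when the handler finishes and closes backConnCloseCh; in
		// both cases the hijacked backend connection is closed, which unblocks
		// the copy goroutines started further down.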
+ select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + defer close(backConnCloseCh) + + if hijackErr != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", hijackErr)) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +func cleanQueryParams(s string) string { + reencode := func(s string) string { + v, _ := url.ParseQuery(s) + return v.Encode() + } + for i := 0; i < len(s); { + switch s[i] { + case ';': + return reencode(s) + case '%': + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + return reencode(s) + } + i += 3 + default: + i++ + } + } + return s +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go new file mode 100644 index 0000000..dd3330b --- /dev/null +++ b/src/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,1863 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. 
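// A minimal deployment sketch tying the implementation above together; the
// origin URL is an assumption. Rewrite routes the request: SetURL rewrites the
// scheme, host, and path, and SetXForwarded repopulates the forwarding headers
// that ServeHTTP strips whenever Rewrite is set. Exactly one of Director or
// Rewrite may be non-nil, and Upgrade requests (e.g. WebSocket) are forwarded
// through handleUpgradeResponse without extra configuration.

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

func main() {
	origin, err := url.Parse("http://app.internal:8080") // hypothetical origin
	if err != nil {
		log.Fatal(err)
	}
	proxy := &httputil.ReverseProxy{
		Rewrite: func(pr *httputil.ProxyRequest) {
			pr.SetURL(origin)        // route to the origin, joining URL paths
			pr.SetXForwarded()       // set X-Forwarded-For/Host/Proto from pr.In
			pr.Out.Host = pr.In.Host // keep the client-facing Host header
		},
	}
	log.Fatal(http.ListenAndServe(":8080", proxy))
}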
+ +package httputil + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/http/internal/ascii" + "net/textproto" + "net/url" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + inOurTests = true + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.FormValue("mode") == "hangup" { + c, _, _ := w.(http.Hijacker).Hijack() + c.Close() + return + } + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Te"); c != "trailers" { + t.Errorf("handler got Te header value %q; want 'trailers'", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if c := r.Header.Get("Proxy-Connection"); c != "" { + t.Errorf("handler got Proxy-Connection header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("Trailers", "not a special header field name") + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + w.Header().Set("X-Trailer", "trailer_value") + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close, TE") + getReq.Header.Add("Te", "foo") + getReq.Header.Add("Te", "bar, trailers") + getReq.Header.Set("Proxy-Connection", "should be deleted") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { + t.Errorf("header Trailers = %q; want %q", g, e) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d 
SetCookies, want %d", g, e) + } + if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { + t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { + t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) + } + if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) + } + + // Test that a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) + getReq.Close = true + res, err = frontendClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) + } + +} + +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") + } + c := r.Header["Connection"] + var cf []string + for _, f := range c { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + cf = append(cf, sf) + } + } + } + sort.Strings(cf) + expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} + sort.Strings(expectedValues) + if !reflect.DeepEqual(cf, expectedValues) { + t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + 
getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + +func TestReverseProxyStripEmptyConnection(t *testing.T) { + // See Issue 46313. + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Values("Connection"); len(c) != 0 { + t.Errorf("handler got header %q = %v; want empty", "Connection", c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "") + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "") + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } 
+ w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +// Issue 38079: don't append to X-Forwarded-For if it's present but nil +func TestXForwardedFor_Omit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + w.Write([]byte("hi")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + r.Header["X-Forwarded-For"] = nil + oldDirector(r) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +func TestReverseProxyRewriteStripsForwarded(t *testing.T) { + headers := []string{ + "Forwarded", + "X-Forwarded-For", + "X-Forwarded-Host", + "X-Forwarded-Proto", + } + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, h := range headers { + if v := r.Header.Get(h); v != "" { + t.Errorf("got %v header: %q", h, v) + } + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + for _, h := range headers { + getReq.Header.Set(h, "x") + } + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) 
+}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } +} + +type mockFlusher struct { + http.ResponseWriter + flushed bool +} + +func (m *mockFlusher) Flush() { + m.flushed = true +} + +type wrappedRW struct { + http.ResponseWriter +} + +func (w *wrappedRW) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + mf := &mockFlusher{} + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = -1 // flush immediately + proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mf.ResponseWriter = w + w = &wrappedRW{mf} + proxyHandler.ServeHTTP(w, r) + }) + + frontend := httptest.NewServer(proxyWithMiddleware) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + if !mf.flushed { + t.Errorf("response writer was not flushed") + } +} + +func TestReverseProxyFlushIntervalHeaders(t *testing.T) { + const expected = "hi" + stopCh := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("MyHeader", expected) + w.WriteHeader(200) + w.(http.Flusher).Flush() + <-stopCh + })) + defer backend.Close() + defer close(stopCh) + + backendURL, err := url.Parse(backend.URL) + if err != 
nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + + ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + if res.Header.Get("MyHeader") != expected { + t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) + } +} + +func TestReverseProxyCancellation(t *testing.T) { + const backendResponse = "I am the backend" + + reqInFlight := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(reqInFlight) // cause the client to cancel its request + + select { + case <-time.After(10 * time.Second): + // Note: this should only happen in broken implementations, and the + // closenotify case should be instantaneous. + t.Error("Handler never saw CloseNotify") + return + case <-w.(http.CloseNotifier).CloseNotify(): + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(backendResponse)) + })) + + defer backend.Close() + + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + + // Discards errors of the form: + // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + go func() { + <-reqInFlight + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) + }() + res, err := frontendClient.Do(getReq) + if res != nil { + t.Errorf("got response %v; want nil", res.Status) + } + if err == nil { + // This should be an error like: + // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: + // use of closed network connection + t.Error("Server.Client().Do() returned nil error; want non-nil error") + } +} + +func req(t *testing.T, v string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) + if err != nil { + t.Fatal(err) + } + return req +} + +// Issue 12344 +func TestNilBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + defer backend.Close() + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backURL, _ := url.Parse(backend.URL) + rp := NewSingleHostReverseProxy(backURL) + r := req(t, "GET / HTTP/1.0\r\n\r\n") + r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working + rp.ServeHTTP(w, r) + })) + defer frontend.Close() + + res, err := http.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "hi" { + t.Errorf("Got %q; want %q", slurp, "hi") + } +} + +// Issue 15524 +func TestUserAgentHeader(t *testing.T) { + var gotUA string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + })) + defer backend.Close() + backendURL, err := 
url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(req *http.Request) { + req.URL = backendURL + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + for _, sentUA := range []string{"explicit UA", ""} { + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", sentUA) + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + if got, want := gotUA, sentUA; got != want { + t.Errorf("got forwarded User-Agent %q, want %q", got, want) + } + } +} + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } + +func TestReverseProxyGetPutBuffer(t *testing.T) { + const msg = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + var ( + mu sync.Mutex + log []string + ) + addLog := func(event string) { + mu.Lock() + defer mu.Unlock() + log = append(log, event) + } + rp := NewSingleHostReverseProxy(backendURL) + const size = 1234 + rp.BufferPool = bufferPool{ + get: func() []byte { + addLog("getBuf") + return make([]byte, size) + }, + put: func(p []byte) { + addLog("putBuf-" + strconv.Itoa(len(p))) + }, + } + frontend := httptest.NewServer(rp) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if string(slurp) != msg { + t.Errorf("msg = %q; want %q", slurp, msg) + } + wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(log, wantLog) { + t.Errorf("Log events = %q; want %q", log, wantLog) + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurp, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Backend body read = %v", err) + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) + res, err := frontend.Client().Do(postReq) + if err != nil { + t.Fatalf("Do: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected 
%q", g, e) + } +} + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 33142: always allocate the request headers +func TestReverseProxy_AllocatedHeader(t *testing.T) { + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header == nil { + t.Error("Header == nil; want a non-nil Header") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + + proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, + Proto: "HTTP/1.0", + ProtoMajor: 1, + }) +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. 
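// A usage sketch of the two hooks exercised by the test below; the backend
// address is an assumption. ModifyResponse may edit or veto a backend
// response, and an error returned from it is passed to ErrorHandler (or, if
// none is set, to the default handler, which logs and replies with
// 502 Bad Gateway).

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

func newProxy(backend *url.URL) *httputil.ReverseProxy {
	proxy := httputil.NewSingleHostReverseProxy(backend)
	proxy.ModifyResponse = func(res *http.Response) error {
		if res.StatusCode == http.StatusTooManyRequests {
			return fmt.Errorf("backend is rate limiting: %s", res.Status)
		}
		res.Header.Set("X-Proxied", "true")
		return nil
	}
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// Reached for transport errors and for errors from ModifyResponse.
		http.Error(w, "upstream unavailable", http.StatusServiceUnavailable)
	}
	return proxy
}

func main() {
	backend, err := url.Parse("http://127.0.0.1:9000") // hypothetical backend
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(http.ListenAndServe(":8080", newProxy(backend)))
}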
+func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + 
resp.Body.Close() + }) + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + donec := make(chan bool, 1) + frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { donec <- true }() + rproxy.ServeHTTP(w, r) + })) + defer frontendProxy.Close() + + if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { + t.Fatalf("want non-nil error") + } + // The race detector complains about the proxyLog usage in logf in copyBuffer + // and our usage below with proxyLog.Bytes() so we're explicitly using a + // channel to ensure that the ReverseProxy's ServeHTTP is done before we + // continue after Get. + <-donec + + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} + +type staticTransport struct { + res *http.Response +} + +func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return t.res, nil +} + +func BenchmarkServeHTTP(b *testing.B) { + res := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + } + proxy := &ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &staticTransport{res}, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + proxy.ServeHTTP(w, r) + } +} + +func TestServeHTTPDeepCopy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello Gopher!")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + type result struct { + before, after string + } + + resultChan := make(chan result, 1) + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + before := r.URL.String() + proxyHandler.ServeHTTP(w, r) + after := r.URL.String() + resultChan <- result{before: before, after: after} + })) + defer frontend.Close() + + want := result{before: "/", after: "/"} + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Do: %v", err) + } + res.Body.Close() + + got := <-resultChan + if got != want { + t.Errorf("got = %+v; want = %+v", got, want) + } +} + +// Issue 18327: verify we always do a deep copy of the Request.Header map +// before any mutations. 
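// The guarantee tested below, sketched from the consumer side (assuming the
// usual net/http and net/http/httputil imports; the backend host is made up):
// ServeHTTP proxies req.Clone(ctx), and Clone deep-copies the Header map, so a
// Director can add or delete headers without mutating the caller's request.

var proxy = &httputil.ReverseProxy{
	Director: func(req *http.Request) {
		req.URL.Scheme = "http"
		req.URL.Host = "backend.internal:8080" // hypothetical backend
		req.Header.Set("X-Request-Id", "123")  // modifies only the outbound clone
		req.Header.Del("Cookie")               // the inbound request keeps its cookies
	},
}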
+func TestClonesRequestHeaders(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + rp := &ReverseProxy{ + Director: func(req *http.Request) { + req.Header.Set("From-Director", "1") + }, + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if v := req.Header.Get("From-Director"); v != "1" { + t.Errorf("From-Directory value = %q; want 1", v) + } + return nil, io.EOF + }), + } + rp.ServeHTTP(httptest.NewRecorder(), req) + + for _, h := range []string{ + "From-Director", + "X-Forwarded-For", + } { + if req.Header.Get(h) != "" { + t.Errorf("%v header mutation modified caller's request", h) + } + } +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestModifyResponseClosesBody(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + closeCheck := new(checkCloser) + logBuf := new(strings.Builder) + outErr := errors.New("ModifyResponse error") + rp := &ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &staticTransport{&http.Response{ + StatusCode: 200, + Body: closeCheck, + }}, + ErrorLog: log.New(logBuf, "", 0), + ModifyResponse: func(*http.Response) error { + return outErr + }, + } + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, req) + res := rec.Result() + if g, e := res.StatusCode, http.StatusBadGateway; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if !closeCheck.closed { + t.Errorf("body should have been closed") + } + if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { + t.Errorf("ErrorLog %q does not contain %q", g, e) + } +} + +type checkCloser struct { + closed bool +} + +func (cc *checkCloser) Close() error { + cc.closed = true + return nil +} + +func (cc *checkCloser) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 23643: panic on body copy error +func TestReverseProxy_PanicBodyError(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + rproxy := NewSingleHostReverseProxy(rpURL) + + // Ensure that the handler panics when the body read encounters an + // io.ErrUnexpectedEOF + defer func() { + err := recover() + if err == nil { + t.Fatal("handler should have panicked") + } + if err != http.ErrAbortHandler { + t.Fatal("expected ErrAbortHandler, got", err) + } + }() + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + rproxy.ServeHTTP(httptest.NewRecorder(), req) +} + +// Issue #46866: panic without closing incoming request body causes a panic +func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer 
backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := frontendClient.Transport.RoundTrip(req) + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } + }() + } + wg.Wait() +} + +func TestSelectFlushInterval(t *testing.T) { + tests := []struct { + name string + p *ReverseProxy + res *http.Response + want time.Duration + }{ + { + name: "default", + res: &http.Response{}, + p: &ReverseProxy{FlushInterval: 123}, + want: 123, + }, + { + name: "server-sent events overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.flushInterval(tt.res) + if got != tt.want { + t.Errorf("flushLatency = %v; want %v", got, tt.want) + } + }) + } +} + +func TestReverseProxyWebSocket(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upgradeType(r.Header) != "websocket" { + t.Error("unexpected backend request") + http.Error(w, "unexpected request", 400) + return + } + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer c.Close() + io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") + bs := bufio.NewScanner(c) + if !bs.Scan() { + t.Errorf("backend failed to read line from client: %v", bs.Err()) + return + } + fmt.Fprintf(c, "backend got %q\n", bs.Text()) + })) + defer backendServer.Close() + + backURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", 
"X-Value") + rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + c := frontendProxy.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("status = %v; want 101", res.Status) + } + + got := res.Header.Get("X-Header") + want := "X-Value" + if got != want { + t.Errorf("Header(XHeader) = %q; want %q", got, want) + } + + if !ascii.EqualFold(upgradeType(res.Header), "websocket") { + t.Fatalf("not websocket upgrade; got %#v", res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) + } + defer rwc.Close() + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + io.WriteString(rwc, "Hello\n") + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("Scan: %v", bs.Err()) + } + got = bs.Text() + want = `backend got "Hello"` + if got != want { + t.Errorf("got %#q, want %#q", got, want) + } +} + +func TestReverseProxyWebSocketCancellation(t *testing.T) { + n := 5 + triggerCancelCh := make(chan bool, n) + nthResponse := func(i int) string { + return fmt.Sprintf("backend response #%d\n", i) + } + terminalMsg := "final message" + + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) + http.Error(w, "Unexpected request", 400) + return + } + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(conn, upgradeMsg); err != nil { + t.Error(err) + return + } + if _, _, err := bufrw.ReadLine(); err != nil { + t.Errorf("Failed to read line from client: %v", err) + return + } + + for i := 0; i < n; i++ { + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Writing response #%d failed: %v", i, err) + } + return + } + bufrw.Flush() + time.Sleep(time.Second) + } + if _, err := bufrw.WriteString(terminalMsg); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Failed to write terminal message: %v", err) + } + } + bufrw.Flush() + })) + defer cst.Close() + + backendURL, _ := url.Parse(cst.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-triggerCancelCh + cancel() + }() + rproxy.ServeHTTP(rw, req.WithContext(ctx)) + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", 
"websocket") + + res, err := frontendProxy.Client().Do(req) + if err != nil { + t.Fatalf("Dialing to frontend proxy: %v", err) + } + defer res.Body.Close() + if g, w := res.StatusCode, 101; g != w { + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) + } + + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) + } + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { + t.Fatalf("Failed to write first message: %v", err) + } + + // Read loop. + + br := bufio.NewReader(rwc) + for { + line, err := br.ReadString('\n') + switch { + case line == terminalMsg: // this case before "err == io.EOF" + t.Fatalf("The websocket request was not canceled, unfortunately!") + + case err == io.EOF: + return + + case err != nil: + t.Fatalf("Unexpected error: %v", err) + + case line == nthResponse(0): // We've gotten the first response back + // Let's trigger a cancel. + close(triggerCancelCh) + } + } +} + +func TestUnannouncedTrailer(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + + io.ReadAll(res.Body) + + if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) + } + +} + +func TestSetURL(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Host)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Reading body: %v", err) + } + + if got, want := string(body), backendURL.Host; got != want { + t.Errorf("backend got Host %q, want %q", got, want) + } +} + +func TestSingleJoinSlash(t *testing.T) { + tests := []struct { + slasha string + slashb string + expected string + }{ + {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, + 
{"https://www.google.com", "", "https://www.google.com/"}, + {"", "favicon.ico", "/favicon.ico"}, + } + for _, tt := range tests { + if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { + t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", + tt.slasha, + tt.slashb, + tt.expected, + got) + } + } +} + +func TestJoinURLPath(t *testing.T) { + tests := []struct { + a *url.URL + b *url.URL + wantPath string + wantRaw string + }{ + {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, + {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, + } + + for _, tt := range tests { + p, rp := joinURLPath(tt.a, tt.b) + if p != tt.wantPath || rp != tt.wantRaw { + t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", + tt.a.Path, tt.a.RawPath, + tt.b.Path, tt.b.RawPath, + tt.wantPath, tt.wantRaw, + p, rp) + } + } +} + +func TestReverseProxyRewriteReplacesOut(t *testing.T) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(content)) + })) + defer backend.Close() + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.Out, _ = http.NewRequest("GET", backend.URL, nil) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if got, want := string(body), content; got != want { + t.Errorf("got response %q, want %q", got, want) + } +} + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "</style.css>; rel=preload; as=style") + h.Add("Link", "</script.js>; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "</foo.js>; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) + case 
http.StatusProcessing: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} + +const ( + testWantsCleanQuery = true + testWantsRawQuery = false +) + +func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + // Parsing the form causes ReverseProxy to remove unparsable + // query parameters before forwarding. + r.FormValue("a") + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + }, + } + }) +} + +func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + r.Out.URL.RawQuery = r.In.URL.RawQuery + }, + } + }) +} + +func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.RawQuery)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := newProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + // Don't spam output with logs of queries containing semicolons. + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + frontend.Config.ErrorLog = log.New(io.Discard, "", 0) + + for _, test := range []struct { + rawQuery string + cleanQuery string + }{{ + rawQuery: "a=1&a=2;b=3", + cleanQuery: "a=1", + }, { + rawQuery: "a=1&a=%zz&b=3", + cleanQuery: "a=1&b=3", + }} { + res, err := frontend.Client().Get(frontend.URL + "?" 
+ test.rawQuery) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + wantQuery := test.rawQuery + if wantCleanQuery { + wantQuery = test.cleanQuery + } + if got, want := string(body), wantQuery; got != want { + t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) + } + } +} |
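
A brief illustrative sketch (not part of the diff above) of the Rewrite hook exercised by TestSetURL and the query-parameter-smuggling tests: a single-backend proxy built on httputil.ReverseProxy. The backend URL and listen address below are assumed placeholders.

// Sketch only: a single-backend proxy using the Rewrite hook tested above.
// The backend URL and listen address are placeholder assumptions.
package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

func main() {
	backend, err := url.Parse("http://127.0.0.1:8081") // assumed backend address
	if err != nil {
		log.Fatal(err)
	}
	proxy := &httputil.ReverseProxy{
		Rewrite: func(r *httputil.ProxyRequest) {
			r.SetURL(backend) // route the outbound request to the backend (as in TestSetURL)
			r.SetXForwarded() // set X-Forwarded-For, X-Forwarded-Host and X-Forwarded-Proto
		},
	}
	log.Fatal(http.ListenAndServe(":8080", proxy)) // assumed listen address
}

As TestReverseProxyQueryParameterSmugglingRewrite and TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery verify, a Rewrite-based proxy forwards a sanitized query string by default; copying r.In.URL.RawQuery to r.Out.URL.RawQuery opts back into forwarding it verbatim.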