summaryrefslogtreecommitdiffstats
path: root/src/net/http/httputil/reverseproxy_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/httputil/reverseproxy_test.go')
-rw-r--r--src/net/http/httputil/reverseproxy_test.go1863
1 files changed, 1863 insertions, 0 deletions
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go
new file mode 100644
index 0000000..dd3330b
--- /dev/null
+++ b/src/net/http/httputil/reverseproxy_test.go
@@ -0,0 +1,1863 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Reverse proxy tests.
+
+package httputil
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httptrace"
+ "net/http/internal/ascii"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
+
+func init() {
+ inOurTests = true
+ hopHeaders = append(hopHeaders, fakeHopHeader)
+}
+
+func TestReverseProxy(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" && r.FormValue("mode") == "hangup" {
+ c, _, _ := w.(http.Hijacker).Hijack()
+ c.Close()
+ return
+ }
+ if len(r.TransferEncoding) > 0 {
+ t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
+ }
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got Connection header value %q", c)
+ }
+ if c := r.Header.Get("Te"); c != "trailers" {
+ t.Errorf("handler got Te header value %q; want 'trailers'", c)
+ }
+ if c := r.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got Upgrade header value %q", c)
+ }
+ if c := r.Header.Get("Proxy-Connection"); c != "" {
+ t.Errorf("handler got Proxy-Connection header value %q", c)
+ }
+ if g, e := r.Host, "some-name"; g != e {
+ t.Errorf("backend got Host header %q, want %q", g, e)
+ }
+ w.Header().Set("Trailers", "not a special header field name")
+ w.Header().Set("Trailer", "X-Trailer")
+ w.Header().Set("X-Foo", "bar")
+ w.Header().Set("Upgrade", "foo")
+ w.Header().Set(fakeHopHeader, "foo")
+ w.Header().Add("X-Multi-Value", "foo")
+ w.Header().Add("X-Multi-Value", "bar")
+ http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ w.Header().Set("X-Trailer", "trailer_value")
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Header.Set("Connection", "close, TE")
+ getReq.Header.Add("Te", "foo")
+ getReq.Header.Add("Te", "bar, trailers")
+ getReq.Header.Set("Proxy-Connection", "should be deleted")
+ getReq.Header.Set("Upgrade", "foo")
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
+ t.Errorf("got X-Foo %q; expected %q", g, e)
+ }
+ if c := res.Header.Get(fakeHopHeader); c != "" {
+ t.Errorf("got %s header value %q", fakeHopHeader, c)
+ }
+ if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e {
+ t.Errorf("header Trailers = %q; want %q", g, e)
+ }
+ if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
+ t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
+ }
+ if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
+ t.Fatalf("got %d SetCookies, want %d", g, e)
+ }
+ if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
+ t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
+ }
+ if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
+ t.Errorf("unexpected cookie %q", cookie.Name)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
+ t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
+ }
+ if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e)
+ }
+
+ // Test that a backend failing to be reached or one which doesn't return
+ // a response results in a StatusBadGateway.
+ getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil)
+ getReq.Close = true
+ res, err = frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != http.StatusBadGateway {
+ t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status)
+ }
+
+}
+
+// Issue 16875: remove any proxied headers mentioned in the "Connection"
+// header value.
+func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
+ const fakeConnectionToken = "X-Fake-Connection-Token"
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ w.Header().Set(fakeConnectionToken, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
+ }
+ c := r.Header["Connection"]
+ var cf []string
+ for _, f := range c {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = strings.TrimSpace(sf); sf != "" {
+ cf = append(cf, sf)
+ }
+ }
+ }
+ sort.Strings(cf)
+ expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
+ sort.Strings(expectedValues)
+ if !reflect.DeepEqual(cf, expectedValues) {
+ t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ getReq.Header.Set(fakeConnectionToken, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ if c := res.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+}
+
+func TestReverseProxyStripEmptyConnection(t *testing.T) {
+ // See Issue 46313.
+ const backendResponse = "I am the backend"
+
+ // someConnHeader is some arbitrary header to be declared as a hop-by-hop header
+ // in the Request's Connection header.
+ const someConnHeader = "X-Some-Conn-Header"
+
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Values("Connection"); len(c) != 0 {
+ t.Errorf("handler got header %q = %v; want empty", "Connection", c)
+ }
+ if c := r.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+ w.Header().Add("Connection", "")
+ w.Header().Add("Connection", someConnHeader)
+ w.Header().Set(someConnHeader, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Add("Connection", "")
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
+ if c := res.Header.Get(someConnHeader); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
+ }
+}
+
+func TestXForwardedFor(t *testing.T) {
+ const prevForwardedFor = "client ip"
+ const backendResponse = "I am the backend"
+ const backendStatus = 404
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Forwarded-For") == "" {
+ t.Errorf("didn't get X-Forwarded-For header")
+ }
+ if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
+ t.Errorf("X-Forwarded-For didn't contain prior data")
+ }
+ w.WriteHeader(backendStatus)
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+// Issue 38079: don't append to X-Forwarded-For if it's present but nil
+func TestXForwardedFor_Omit(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if v := r.Header.Get("X-Forwarded-For"); v != "" {
+ t.Errorf("got X-Forwarded-For header: %q", v)
+ }
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ r.Header["X-Forwarded-For"] = nil
+ oldDirector(r)
+ }
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Close = true
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+func TestReverseProxyRewriteStripsForwarded(t *testing.T) {
+ headers := []string{
+ "Forwarded",
+ "X-Forwarded-For",
+ "X-Forwarded-Host",
+ "X-Forwarded-Proto",
+ }
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ for _, h := range headers {
+ if v := r.Header.Get(h); v != "" {
+ t.Errorf("got %v header: %q", h, v)
+ }
+ }
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(backendURL)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Host = "some-name"
+ getReq.Close = true
+ for _, h := range headers {
+ getReq.Header.Set(h, "x")
+ }
+ res, err := frontend.Client().Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+}
+
+var proxyQueryTests = []struct {
+ baseSuffix string // suffix to add to backend URL
+ reqSuffix string // suffix to add to frontend's request URL
+ want string // what backend should see for final request URL (without ?)
+}{
+ {"", "", ""},
+ {"?sta=tic", "?us=er", "sta=tic&us=er"},
+ {"", "?us=er", "us=er"},
+ {"?sta=tic", "", "sta=tic"},
+}
+
+func TestReverseProxyQuery(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("X-Got-Query", r.URL.RawQuery)
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ for i, tt := range proxyQueryTests {
+ backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
+ if err != nil {
+ t.Fatal(err)
+ }
+ frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
+ req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("%d. Get: %v", i, err)
+ }
+ if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
+ t.Errorf("%d. got query %q; expected %q", i, g, e)
+ }
+ res.Body.Close()
+ frontend.Close()
+ }
+}
+
+func TestReverseProxyFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+}
+
+type mockFlusher struct {
+ http.ResponseWriter
+ flushed bool
+}
+
+func (m *mockFlusher) Flush() {
+ m.flushed = true
+}
+
+type wrappedRW struct {
+ http.ResponseWriter
+}
+
+func (w *wrappedRW) Unwrap() http.ResponseWriter {
+ return w.ResponseWriter
+}
+
+func TestReverseProxyResponseControllerFlushInterval(t *testing.T) {
+ const expected = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(expected))
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ mf := &mockFlusher{}
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = -1 // flush immediately
+ proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ mf.ResponseWriter = w
+ w = &wrappedRW{mf}
+ proxyHandler.ServeHTTP(w, r)
+ })
+
+ frontend := httptest.NewServer(proxyWithMiddleware)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected {
+ t.Errorf("got body %q; expected %q", bodyBytes, expected)
+ }
+ if !mf.flushed {
+ t.Errorf("response writer was not flushed")
+ }
+}
+
+func TestReverseProxyFlushIntervalHeaders(t *testing.T) {
+ const expected = "hi"
+ stopCh := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("MyHeader", expected)
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ <-stopCh
+ }))
+ defer backend.Close()
+ defer close(stopCh)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.FlushInterval = time.Microsecond
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+
+ ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second)
+ defer cancel()
+ req = req.WithContext(ctx)
+
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ if res.Header.Get("MyHeader") != expected {
+ t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected)
+ }
+}
+
+func TestReverseProxyCancellation(t *testing.T) {
+ const backendResponse = "I am the backend"
+
+ reqInFlight := make(chan struct{})
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ close(reqInFlight) // cause the client to cancel its request
+
+ select {
+ case <-time.After(10 * time.Second):
+ // Note: this should only happen in broken implementations, and the
+ // closenotify case should be instantaneous.
+ t.Error("Handler never saw CloseNotify")
+ return
+ case <-w.(http.CloseNotifier).CloseNotify():
+ }
+
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(backendResponse))
+ }))
+
+ defer backend.Close()
+
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+
+ // Discards errors of the form:
+ // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0)
+
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ go func() {
+ <-reqInFlight
+ frontendClient.Transport.(*http.Transport).CancelRequest(getReq)
+ }()
+ res, err := frontendClient.Do(getReq)
+ if res != nil {
+ t.Errorf("got response %v; want nil", res.Status)
+ }
+ if err == nil {
+ // This should be an error like:
+ // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079:
+ // use of closed network connection
+ t.Error("Server.Client().Do() returned nil error; want non-nil error")
+ }
+}
+
+func req(t *testing.T, v string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
+ if err != nil {
+ t.Fatal(err)
+ }
+ return req
+}
+
+// Issue 12344
+func TestNilBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi"))
+ }))
+ defer backend.Close()
+
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ backURL, _ := url.Parse(backend.URL)
+ rp := NewSingleHostReverseProxy(backURL)
+ r := req(t, "GET / HTTP/1.0\r\n\r\n")
+ r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
+ rp.ServeHTTP(w, r)
+ }))
+ defer frontend.Close()
+
+ res, err := http.Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "hi" {
+ t.Errorf("Got %q; want %q", slurp, "hi")
+ }
+}
+
+// Issue 15524
+func TestUserAgentHeader(t *testing.T) {
+ var gotUA string
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gotUA = r.Header.Get("User-Agent")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ proxyHandler := new(ReverseProxy)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Director = func(req *http.Request) {
+ req.URL = backendURL
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ for _, sentUA := range []string{"explicit UA", ""} {
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("User-Agent", sentUA)
+ getReq.Close = true
+ res, err := frontendClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ res.Body.Close()
+ if got, want := gotUA, sentUA; got != want {
+ t.Errorf("got forwarded User-Agent %q, want %q", got, want)
+ }
+ }
+}
+
+type bufferPool struct {
+ get func() []byte
+ put func([]byte)
+}
+
+func (bp bufferPool) Get() []byte { return bp.get() }
+func (bp bufferPool) Put(v []byte) { bp.put(v) }
+
+func TestReverseProxyGetPutBuffer(t *testing.T) {
+ const msg = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, msg)
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ mu sync.Mutex
+ log []string
+ )
+ addLog := func(event string) {
+ mu.Lock()
+ defer mu.Unlock()
+ log = append(log, event)
+ }
+ rp := NewSingleHostReverseProxy(backendURL)
+ const size = 1234
+ rp.BufferPool = bufferPool{
+ get: func() []byte {
+ addLog("getBuf")
+ return make([]byte, size)
+ },
+ put: func(p []byte) {
+ addLog("putBuf-" + strconv.Itoa(len(p)))
+ },
+ }
+ frontend := httptest.NewServer(rp)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := frontend.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ slurp, err := io.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if string(slurp) != msg {
+ t.Errorf("msg = %q; want %q", slurp, msg)
+ }
+ wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
+ mu.Lock()
+ defer mu.Unlock()
+ if !reflect.DeepEqual(log, wantLog) {
+ t.Errorf("Log events = %q; want %q", log, wantLog)
+ }
+}
+
+func TestReverseProxy_Post(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 200
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Backend body read = %v", err)
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ }
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
+ res, err := frontend.Client().Do(postReq)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := io.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
+
+type RoundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+// Issue 16036: send a Request with a nil Body when possible
+func TestReverseProxy_NilBody(t *testing.T) {
+ backendURL, _ := url.Parse("http://fake.tld/")
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Body != nil {
+ t.Error("Body != nil; want a nil Body")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 502 {
+ t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
+ }
+}
+
+// Issue 33142: always allocate the request headers
+func TestReverseProxy_AllocatedHeader(t *testing.T) {
+ proxyHandler := new(ReverseProxy)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ proxyHandler.Director = func(*http.Request) {} // noop
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Header == nil {
+ t.Error("Header == nil; want a non-nil Header")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+
+ proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{
+ Method: "GET",
+ URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"},
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ })
+}
+
+// Issue 14237. Test ModifyResponse and that an error from it
+// causes the proxy to return StatusBadGateway, or StatusOK otherwise.
+func TestReverseProxyModifyResponse(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
+ }))
+ defer backendServer.Close()
+
+ rpURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(resp *http.Response) error {
+ if resp.Header.Get("X-Hit-Mod") != "true" {
+ return fmt.Errorf("tried to by-pass proxy")
+ }
+ return nil
+ }
+
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ tests := []struct {
+ url string
+ wantCode int
+ }{
+ {frontendProxy.URL + "/mod", http.StatusOK},
+ {frontendProxy.URL + "/schedule", http.StatusBadGateway},
+ }
+
+ for i, tt := range tests {
+ resp, err := http.Get(tt.url)
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
+ }
+ resp.Body.Close()
+ }
+}
+
+type failingRoundTripper struct{}
+
+func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return nil, errors.New("some error")
+}
+
+type staticResponseRoundTripper struct{ res *http.Response }
+
+func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return rt.res, nil
+}
+
+func TestReverseProxyErrorHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ wantCode int
+ errorHandler func(http.ResponseWriter, *http.Request, error)
+ transport http.RoundTripper // defaults to failingRoundTripper
+ modifyResponse func(*http.Response) error
+ }{
+ {
+ name: "default",
+ wantCode: http.StatusBadGateway,
+ },
+ {
+ name: "errorhandler",
+ wantCode: http.StatusTeapot,
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ },
+ {
+ name: "modifyresponse_noerr",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return nil
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: 346,
+ },
+ {
+ name: "modifyresponse_err",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return errors.New("some error to trigger errorHandler")
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: http.StatusTeapot,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ target := &url.URL{
+ Scheme: "http",
+ Host: "dummy.tld",
+ Path: "/",
+ }
+ rproxy := NewSingleHostReverseProxy(target)
+ rproxy.Transport = tt.transport
+ rproxy.ModifyResponse = tt.modifyResponse
+ if rproxy.Transport == nil {
+ rproxy.Transport = failingRoundTripper{}
+ }
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ if tt.errorHandler != nil {
+ rproxy.ErrorHandler = tt.errorHandler
+ }
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL + "/test")
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ resp.Body.Close()
+ })
+ }
+}
+
+// Issue 16659: log errors from short read
+func TestReverseProxy_CopyBuffer(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.UnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var proxyLog bytes.Buffer
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
+ donec := make(chan bool, 1)
+ frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() { donec <- true }()
+ rproxy.ServeHTTP(w, r)
+ }))
+ defer frontendProxy.Close()
+
+ if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
+ t.Fatalf("want non-nil error")
+ }
+ // The race detector complains about the proxyLog usage in logf in copyBuffer
+ // and our usage below with proxyLog.Bytes() so we're explicitly using a
+ // channel to ensure that the ReverseProxy's ServeHTTP is done before we
+ // continue after Get.
+ <-donec
+
+ expected := []string{
+ "EOF",
+ "read",
+ }
+ for _, phrase := range expected {
+ if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
+ t.Errorf("expected log to contain phrase %q", phrase)
+ }
+ }
+}
+
+type staticTransport struct {
+ res *http.Response
+}
+
+func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+ return t.res, nil
+}
+
+func BenchmarkServeHTTP(b *testing.B) {
+ res := &http.Response{
+ StatusCode: 200,
+ Body: io.NopCloser(strings.NewReader("")),
+ }
+ proxy := &ReverseProxy{
+ Director: func(*http.Request) {},
+ Transport: &staticTransport{res},
+ }
+
+ w := httptest.NewRecorder()
+ r := httptest.NewRequest("GET", "/", nil)
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ proxy.ServeHTTP(w, r)
+ }
+}
+
+func TestServeHTTPDeepCopy(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("Hello Gopher!"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ type result struct {
+ before, after string
+ }
+
+ resultChan := make(chan result, 1)
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ before := r.URL.String()
+ proxyHandler.ServeHTTP(w, r)
+ after := r.URL.String()
+ resultChan <- result{before: before, after: after}
+ }))
+ defer frontend.Close()
+
+ want := result{before: "/", after: "/"}
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ res.Body.Close()
+
+ got := <-resultChan
+ if got != want {
+ t.Errorf("got = %+v; want = %+v", got, want)
+ }
+}
+
+// Issue 18327: verify we always do a deep copy of the Request.Header map
+// before any mutations.
+func TestClonesRequestHeaders(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {
+ req.Header.Set("From-Director", "1")
+ },
+ Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if v := req.Header.Get("From-Director"); v != "1" {
+ t.Errorf("From-Directory value = %q; want 1", v)
+ }
+ return nil, io.EOF
+ }),
+ }
+ rp.ServeHTTP(httptest.NewRecorder(), req)
+
+ for _, h := range []string{
+ "From-Director",
+ "X-Forwarded-For",
+ } {
+ if req.Header.Get(h) != "" {
+ t.Errorf("%v header mutation modified caller's request", h)
+ }
+ }
+}
+
+type roundTripperFunc func(req *http.Request) (*http.Response, error)
+
+func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+func TestModifyResponseClosesBody(t *testing.T) {
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ req.RemoteAddr = "1.2.3.4:56789"
+ closeCheck := new(checkCloser)
+ logBuf := new(strings.Builder)
+ outErr := errors.New("ModifyResponse error")
+ rp := &ReverseProxy{
+ Director: func(req *http.Request) {},
+ Transport: &staticTransport{&http.Response{
+ StatusCode: 200,
+ Body: closeCheck,
+ }},
+ ErrorLog: log.New(logBuf, "", 0),
+ ModifyResponse: func(*http.Response) error {
+ return outErr
+ },
+ }
+ rec := httptest.NewRecorder()
+ rp.ServeHTTP(rec, req)
+ res := rec.Result()
+ if g, e := res.StatusCode, http.StatusBadGateway; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ if !closeCheck.closed {
+ t.Errorf("body should have been closed")
+ }
+ if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) {
+ t.Errorf("ErrorLog %q does not contain %q", g, e)
+ }
+}
+
+type checkCloser struct {
+ closed bool
+}
+
+func (cc *checkCloser) Close() error {
+ cc.closed = true
+ return nil
+}
+
+func (cc *checkCloser) Read(b []byte) (int, error) {
+ return len(b), nil
+}
+
+// Issue 23643: panic on body copy error
+func TestReverseProxy_PanicBodyError(t *testing.T) {
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rproxy := NewSingleHostReverseProxy(rpURL)
+
+ // Ensure that the handler panics when the body read encounters an
+ // io.ErrUnexpectedEOF
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatal("handler should have panicked")
+ }
+ if err != http.ErrAbortHandler {
+ t.Fatal("expected ErrAbortHandler, got", err)
+ }
+ }()
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ rproxy.ServeHTTP(httptest.NewRecorder(), req)
+}
+
+// Issue #46866: panic without closing incoming request body causes a panic
+func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ var wg sync.WaitGroup
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 10; j++ {
+ const reqLen = 6 * 1024 * 1024
+ req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
+ req.ContentLength = reqLen
+ resp, _ := frontendClient.Transport.RoundTrip(req)
+ if resp != nil {
+ io.Copy(io.Discard, resp.Body)
+ resp.Body.Close()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func TestSelectFlushInterval(t *testing.T) {
+ tests := []struct {
+ name string
+ p *ReverseProxy
+ res *http.Response
+ want time.Duration
+ }{
+ {
+ name: "default",
+ res: &http.Response{},
+ p: &ReverseProxy{FlushInterval: 123},
+ want: 123,
+ },
+ {
+ name: "server-sent events overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides non-zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "server-sent events with media-type parameters overrides zero",
+ res: &http.Response{
+ Header: http.Header{
+ "Content-Type": {"text/event-stream;charset=utf-8"},
+ },
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides non-zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 123},
+ want: -1,
+ },
+ {
+ name: "Content-Length: -1, overrides zero",
+ res: &http.Response{
+ ContentLength: -1,
+ },
+ p: &ReverseProxy{FlushInterval: 0},
+ want: -1,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := tt.p.flushInterval(tt.res)
+ if got != tt.want {
+ t.Errorf("flushLatency = %v; want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestReverseProxyWebSocket(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if upgradeType(r.Header) != "websocket" {
+ t.Error("unexpected backend request")
+ http.Error(w, "unexpected request", 400)
+ return
+ }
+ c, _, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+ io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n")
+ bs := bufio.NewScanner(c)
+ if !bs.Scan() {
+ t.Errorf("backend failed to read line from client: %v", bs.Err())
+ return
+ }
+ fmt.Fprintf(c, "backend got %q\n", bs.Text())
+ }))
+ defer backendServer.Close()
+
+ backURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(backURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ rproxy.ServeHTTP(rw, req)
+ if got, want := rw.Header().Get("X-Modified"), "true"; got != want {
+ t.Errorf("response writer X-Modified header = %q; want %q", got, want)
+ }
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ c := frontendProxy.Client()
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 101 {
+ t.Fatalf("status = %v; want 101", res.Status)
+ }
+
+ got := res.Header.Get("X-Header")
+ want := "X-Value"
+ if got != want {
+ t.Errorf("Header(XHeader) = %q; want %q", got, want)
+ }
+
+ if !ascii.EqualFold(upgradeType(res.Header), "websocket") {
+ t.Fatalf("not websocket upgrade; got %#v", res.Header)
+ }
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body)
+ }
+ defer rwc.Close()
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ io.WriteString(rwc, "Hello\n")
+ bs := bufio.NewScanner(rwc)
+ if !bs.Scan() {
+ t.Fatalf("Scan: %v", bs.Err())
+ }
+ got = bs.Text()
+ want = `backend got "Hello"`
+ if got != want {
+ t.Errorf("got %#q, want %#q", got, want)
+ }
+}
+
+func TestReverseProxyWebSocketCancellation(t *testing.T) {
+ n := 5
+ triggerCancelCh := make(chan bool, n)
+ nthResponse := func(i int) string {
+ return fmt.Sprintf("backend response #%d\n", i)
+ }
+ terminalMsg := "final message"
+
+ cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if g, ws := upgradeType(r.Header), "websocket"; g != ws {
+ t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
+ http.Error(w, "Unexpected request", 400)
+ return
+ }
+ conn, bufrw, err := w.(http.Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer conn.Close()
+
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
+ if _, err := io.WriteString(conn, upgradeMsg); err != nil {
+ t.Error(err)
+ return
+ }
+ if _, _, err := bufrw.ReadLine(); err != nil {
+ t.Errorf("Failed to read line from client: %v", err)
+ return
+ }
+
+ for i := 0; i < n; i++ {
+ if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Writing response #%d failed: %v", i, err)
+ }
+ return
+ }
+ bufrw.Flush()
+ time.Sleep(time.Second)
+ }
+ if _, err := bufrw.WriteString(terminalMsg); err != nil {
+ select {
+ case <-triggerCancelCh:
+ default:
+ t.Errorf("Failed to write terminal message: %v", err)
+ }
+ }
+ bufrw.Flush()
+ }))
+ defer cst.Close()
+
+ backendURL, _ := url.Parse(cst.URL)
+ rproxy := NewSingleHostReverseProxy(backendURL)
+ rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(res *http.Response) error {
+ res.Header.Add("X-Modified", "true")
+ return nil
+ }
+
+ handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
+ rw.Header().Set("X-Header", "X-Value")
+ ctx, cancel := context.WithCancel(req.Context())
+ go func() {
+ <-triggerCancelCh
+ cancel()
+ }()
+ rproxy.ServeHTTP(rw, req.WithContext(ctx))
+ })
+
+ frontendProxy := httptest.NewServer(handler)
+ defer frontendProxy.Close()
+
+ req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
+ req.Header.Set("Connection", "Upgrade")
+ req.Header.Set("Upgrade", "websocket")
+
+ res, err := frontendProxy.Client().Do(req)
+ if err != nil {
+ t.Fatalf("Dialing to frontend proxy: %v", err)
+ }
+ defer res.Body.Close()
+ if g, w := res.StatusCode, 101; g != w {
+ t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
+ }
+
+ if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
+ t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) {
+ t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ rwc, ok := res.Body.(io.ReadWriteCloser)
+ if !ok {
+ t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
+ }
+
+ if got, want := res.Header.Get("X-Modified"), "true"; got != want {
+ t.Errorf("response X-Modified header = %q; want %q", got, want)
+ }
+
+ if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
+ t.Fatalf("Failed to write first message: %v", err)
+ }
+
+ // Read loop.
+
+ br := bufio.NewReader(rwc)
+ for {
+ line, err := br.ReadString('\n')
+ switch {
+ case line == terminalMsg: // this case before "err == io.EOF"
+ t.Fatalf("The websocket request was not canceled, unfortunately!")
+
+ case err == io.EOF:
+ return
+
+ case err != nil:
+ t.Fatalf("Unexpected error: %v", err)
+
+ case line == nthResponse(0): // We've gotten the first response back
+ // Let's trigger a cancel.
+ close(triggerCancelCh)
+ }
+ }
+}
+
+func TestUnannouncedTrailer(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.(http.Flusher).Flush()
+ w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value")
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ res, err := frontendClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ io.ReadAll(res.Body)
+
+ if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w {
+ t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w)
+ }
+
+}
+
+func TestSetURL(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.Host))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(backendURL)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ res, err := frontendClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("Reading body: %v", err)
+ }
+
+ if got, want := string(body), backendURL.Host; got != want {
+ t.Errorf("backend got Host %q, want %q", got, want)
+ }
+}
+
+func TestSingleJoinSlash(t *testing.T) {
+ tests := []struct {
+ slasha string
+ slashb string
+ expected string
+ }{
+ {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"},
+ {"https://www.google.com", "", "https://www.google.com/"},
+ {"", "favicon.ico", "/favicon.ico"},
+ }
+ for _, tt := range tests {
+ if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected {
+ t.Errorf("singleJoiningSlash(%q,%q) want %q got %q",
+ tt.slasha,
+ tt.slashb,
+ tt.expected,
+ got)
+ }
+ }
+}
+
+func TestJoinURLPath(t *testing.T) {
+ tests := []struct {
+ a *url.URL
+ b *url.URL
+ wantPath string
+ wantRaw string
+ }{
+ {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""},
+ {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"},
+ {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"},
+ }
+
+ for _, tt := range tests {
+ p, rp := joinURLPath(tt.a, tt.b)
+ if p != tt.wantPath || rp != tt.wantRaw {
+ t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)",
+ tt.a.Path, tt.a.RawPath,
+ tt.b.Path, tt.b.RawPath,
+ tt.wantPath, tt.wantRaw,
+ p, rp)
+ }
+ }
+}
+
+func TestReverseProxyRewriteReplacesOut(t *testing.T) {
+ const content = "response_content"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(content))
+ }))
+ defer backend.Close()
+ proxyHandler := &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.Out, _ = http.NewRequest("GET", backend.URL, nil)
+ },
+ }
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := frontend.Client().Get(frontend.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ if got, want := string(body), content; got != want {
+ t.Errorf("got response %q, want %q", got, want)
+ }
+}
+
+func Test1xxResponses(t *testing.T) {
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ h := w.Header()
+ h.Add("Link", "</style.css>; rel=preload; as=style")
+ h.Add("Link", "</script.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusEarlyHints)
+
+ h.Add("Link", "</foo.js>; rel=preload; as=script")
+ w.WriteHeader(http.StatusProcessing)
+
+ w.Write([]byte("Hello"))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+ frontendClient := frontend.Client()
+
+ checkLinkHeaders := func(t *testing.T, expected, got []string) {
+ t.Helper()
+
+ if len(expected) != len(got) {
+ t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
+ }
+
+ for i := range expected {
+ if i >= len(got) {
+ t.Errorf("Expected %q link header; got nothing", expected[i])
+
+ continue
+ }
+
+ if expected[i] != got[i] {
+ t.Errorf("Expected %q link header; got %q", expected[i], got[i])
+ }
+ }
+ }
+
+ var respCounter uint8
+ trace := &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ switch code {
+ case http.StatusEarlyHints:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
+ case http.StatusProcessing:
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
+ default:
+ t.Error("Unexpected 1xx response")
+ }
+
+ respCounter++
+
+ return nil
+ },
+ }
+ req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)
+
+ res, err := frontendClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+
+ defer res.Body.Close()
+
+ if respCounter != 2 {
+ t.Errorf("Expected 2 1xx responses; got %d", respCounter)
+ }
+ checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
+
+ body, _ := io.ReadAll(res.Body)
+ if string(body) != "Hello" {
+ t.Errorf("Read body %q; want Hello", body)
+ }
+}
+
+const (
+ testWantsCleanQuery = true
+ testWantsRawQuery = false
+)
+
+func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
+ proxyHandler := NewSingleHostReverseProxy(u)
+ oldDirector := proxyHandler.Director
+ proxyHandler.Director = func(r *http.Request) {
+ // Parsing the form causes ReverseProxy to remove unparsable
+ // query parameters before forwarding.
+ r.FormValue("a")
+ oldDirector(r)
+ }
+ return proxyHandler
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy {
+ return &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(u)
+ },
+ }
+ })
+}
+
+func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) {
+ testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy {
+ return &ReverseProxy{
+ Rewrite: func(r *ProxyRequest) {
+ r.SetURL(u)
+ r.Out.URL.RawQuery = r.In.URL.RawQuery
+ },
+ }
+ })
+}
+
+func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) {
+ const content = "response_content"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte(r.URL.RawQuery))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := newProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ // Don't spam output with logs of queries containing semicolons.
+ backend.Config.ErrorLog = log.New(io.Discard, "", 0)
+ frontend.Config.ErrorLog = log.New(io.Discard, "", 0)
+
+ for _, test := range []struct {
+ rawQuery string
+ cleanQuery string
+ }{{
+ rawQuery: "a=1&a=2;b=3",
+ cleanQuery: "a=1",
+ }, {
+ rawQuery: "a=1&a=%zz&b=3",
+ cleanQuery: "a=1&b=3",
+ }} {
+ res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ body, _ := io.ReadAll(res.Body)
+ wantQuery := test.rawQuery
+ if wantCleanQuery {
+ wantQuery = test.cleanQuery
+ }
+ if got, want := string(body), wantQuery; got != want {
+ t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want)
+ }
+ }
+}