diff options
Diffstat (limited to 'src/net/http/httputil/reverseproxy_test.go')
-rw-r--r-- | src/net/http/httputil/reverseproxy_test.go | 1863 |
1 files changed, 1863 insertions, 0 deletions
diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go new file mode 100644 index 0000000..dd3330b --- /dev/null +++ b/src/net/http/httputil/reverseproxy_test.go @@ -0,0 +1,1863 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package httputil + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/http/httptrace" + "net/http/internal/ascii" + "net/textproto" + "net/url" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" + +func init() { + inOurTests = true + hopHeaders = append(hopHeaders, fakeHopHeader) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.FormValue("mode") == "hangup" { + c, _, _ := w.(http.Hijacker).Hijack() + c.Close() + return + } + if len(r.TransferEncoding) > 0 { + t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) + } + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + if c := r.Header.Get("Te"); c != "trailers" { + t.Errorf("handler got Te header value %q; want 'trailers'", c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if c := r.Header.Get("Proxy-Connection"); c != "" { + t.Errorf("handler got Proxy-Connection header value %q", c) + } + if g, e := r.Host, "some-name"; g != e { + t.Errorf("backend got Host header %q, want %q", g, e) + } + w.Header().Set("Trailers", "not a special header field name") + w.Header().Set("Trailer", "X-Trailer") + w.Header().Set("X-Foo", "bar") + w.Header().Set("Upgrade", "foo") + w.Header().Set(fakeHopHeader, "foo") + w.Header().Add("X-Multi-Value", "foo") + w.Header().Add("X-Multi-Value", "bar") + http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"}) + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + w.Header().Set("X-Trailer", "trailer_value") + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Header.Set("Connection", "close, TE") + getReq.Header.Add("Te", "foo") + getReq.Header.Add("Te", "bar, trailers") + getReq.Header.Set("Proxy-Connection", "should be deleted") + getReq.Header.Set("Upgrade", "foo") + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { + t.Errorf("header Trailers = %q; want %q", g, e) + } + if g, e := len(res.Header["X-Multi-Value"]), 2; g != e { + t.Errorf("got %d X-Multi-Value header values; expected %d", g, e) + } + if g, e := len(res.Header["Set-Cookie"]), 1; g != e { + t.Fatalf("got %d SetCookies, want %d", g, e) + } + if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) { + t.Errorf("before reading body, Trailer = %#v; want %#v", g, e) + } + if cookie := res.Cookies()[0]; cookie.Name != "flavor" { + t.Errorf("unexpected cookie %q", cookie.Name) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { + t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) + } + if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) + } + + // Test that a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) + getReq.Close = true + res, err = frontendClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) + } + +} + +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") + } + c := r.Header["Connection"] + var cf []string + for _, f := range c { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + cf = append(cf, sf) + } + } + } + sort.Strings(cf) + expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} + sort.Strings(expectedValues) + if !reflect.DeepEqual(cf, expectedValues) { + t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + +func TestReverseProxyStripEmptyConnection(t *testing.T) { + // See Issue 46313. + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Values("Connection"); len(c) != 0 { + t.Errorf("handler got header %q = %v; want empty", "Connection", c) + } + if c := r.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + w.Header().Add("Connection", "") + w.Header().Add("Connection", someConnHeader) + w.Header().Set(someConnHeader, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Add("Connection", "") + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := res.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("Connection", "close") + getReq.Header.Set("X-Forwarded-For", prevForwardedFor) + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +// Issue 38079: don't append to X-Forwarded-For if it's present but nil +func TestXForwardedFor_Omit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + w.Write([]byte("hi")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + r.Header["X-Forwarded-For"] = nil + oldDirector(r) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +func TestReverseProxyRewriteStripsForwarded(t *testing.T) { + headers := []string{ + "Forwarded", + "X-Forwarded-For", + "X-Forwarded-Host", + "X-Forwarded-Proto", + } + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, h := range headers { + if v := r.Header.Get(h); v != "" { + t.Errorf("got %v header: %q", h, v) + } + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + for _, h := range headers { + getReq.Header.Set(h, "x") + } + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"", "", ""}, + {"?sta=tic", "?us=er", "sta=tic&us=er"}, + {"", "?us=er", "us=er"}, + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Got-Query", r.URL.RawQuery) + w.Write([]byte("hi")) + })) + defer backend.Close() + + for i, tt := range proxyQueryTests { + backendURL, err := url.Parse(backend.URL + tt.baseSuffix) + if err != nil { + t.Fatal(err) + } + frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) + req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("%d. Get: %v", i, err) + } + if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + res.Body.Close() + frontend.Close() + } +} + +func TestReverseProxyFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } +} + +type mockFlusher struct { + http.ResponseWriter + flushed bool +} + +func (m *mockFlusher) Flush() { + m.flushed = true +} + +type wrappedRW struct { + http.ResponseWriter +} + +func (w *wrappedRW) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + +func TestReverseProxyResponseControllerFlushInterval(t *testing.T) { + const expected = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(expected)) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + mf := &mockFlusher{} + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = -1 // flush immediately + proxyWithMiddleware := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mf.ResponseWriter = w + w = &wrappedRW{mf} + proxyHandler.ServeHTTP(w, r) + }) + + frontend := httptest.NewServer(proxyWithMiddleware) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { + t.Errorf("got body %q; expected %q", bodyBytes, expected) + } + if !mf.flushed { + t.Errorf("response writer was not flushed") + } +} + +func TestReverseProxyFlushIntervalHeaders(t *testing.T) { + const expected = "hi" + stopCh := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("MyHeader", expected) + w.WriteHeader(200) + w.(http.Flusher).Flush() + <-stopCh + })) + defer backend.Close() + defer close(stopCh) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.FlushInterval = time.Microsecond + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + + ctx, cancel := context.WithTimeout(req.Context(), 10*time.Second) + defer cancel() + req = req.WithContext(ctx) + + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + if res.Header.Get("MyHeader") != expected { + t.Errorf("got header %q; expected %q", res.Header.Get("MyHeader"), expected) + } +} + +func TestReverseProxyCancellation(t *testing.T) { + const backendResponse = "I am the backend" + + reqInFlight := make(chan struct{}) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(reqInFlight) // cause the client to cancel its request + + select { + case <-time.After(10 * time.Second): + // Note: this should only happen in broken implementations, and the + // closenotify case should be instantaneous. + t.Error("Handler never saw CloseNotify") + return + case <-w.(http.CloseNotifier).CloseNotify(): + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(backendResponse)) + })) + + defer backend.Close() + + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := NewSingleHostReverseProxy(backendURL) + + // Discards errors of the form: + // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) + + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + go func() { + <-reqInFlight + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) + }() + res, err := frontendClient.Do(getReq) + if res != nil { + t.Errorf("got response %v; want nil", res.Status) + } + if err == nil { + // This should be an error like: + // Get "http://127.0.0.1:58079": read tcp 127.0.0.1:58079: + // use of closed network connection + t.Error("Server.Client().Do() returned nil error; want non-nil error") + } +} + +func req(t *testing.T, v string) *http.Request { + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v))) + if err != nil { + t.Fatal(err) + } + return req +} + +// Issue 12344 +func TestNilBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi")) + })) + defer backend.Close() + + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + backURL, _ := url.Parse(backend.URL) + rp := NewSingleHostReverseProxy(backURL) + r := req(t, "GET / HTTP/1.0\r\n\r\n") + r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working + rp.ServeHTTP(w, r) + })) + defer frontend.Close() + + res, err := http.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != "hi" { + t.Errorf("Got %q; want %q", slurp, "hi") + } +} + +// Issue 15524 +func TestUserAgentHeader(t *testing.T) { + var gotUA string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotUA = r.Header.Get("User-Agent") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(req *http.Request) { + req.URL = backendURL + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + for _, sentUA := range []string{"explicit UA", ""} { + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", sentUA) + getReq.Close = true + res, err := frontendClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + if got, want := gotUA, sentUA; got != want { + t.Errorf("got forwarded User-Agent %q, want %q", got, want) + } + } +} + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } + +func TestReverseProxyGetPutBuffer(t *testing.T) { + const msg = "hi" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + })) + defer backend.Close() + + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + var ( + mu sync.Mutex + log []string + ) + addLog := func(event string) { + mu.Lock() + defer mu.Unlock() + log = append(log, event) + } + rp := NewSingleHostReverseProxy(backendURL) + const size = 1234 + rp.BufferPool = bufferPool{ + get: func() []byte { + addLog("getBuf") + return make([]byte, size) + }, + put: func(p []byte) { + addLog("putBuf-" + strconv.Itoa(len(p))) + }, + } + frontend := httptest.NewServer(rp) + defer frontend.Close() + + req, _ := http.NewRequest("GET", frontend.URL, nil) + req.Close = true + res, err := frontend.Client().Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if string(slurp) != msg { + t.Errorf("msg = %q; want %q", slurp, msg) + } + wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)} + mu.Lock() + defer mu.Unlock() + if !reflect.DeepEqual(log, wantLog) { + t.Errorf("Log events = %q; want %q", log, wantLog) + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + var requestBody = bytes.Repeat([]byte("a"), 1<<20) + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + slurp, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Backend body read = %v", err) + } + if len(slurp) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody)) + } + if !bytes.Equal(slurp, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) + res, err := frontend.Client().Do(postReq) + if err != nil { + t.Fatalf("Do: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + bodyBytes, _ := io.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 33142: always allocate the request headers +func TestReverseProxy_AllocatedHeader(t *testing.T) { + proxyHandler := new(ReverseProxy) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Header == nil { + t.Error("Header == nil; want a non-nil Header") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + + proxyHandler.ServeHTTP(httptest.NewRecorder(), &http.Request{ + Method: "GET", + URL: &url.URL{Scheme: "http", Host: "fake.tld", Path: "/"}, + Proto: "HTTP/1.0", + ProtoMajor: 1, + }) +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. +func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + resp.Body.Close() + }) + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + donec := make(chan bool, 1) + frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { donec <- true }() + rproxy.ServeHTTP(w, r) + })) + defer frontendProxy.Close() + + if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil { + t.Fatalf("want non-nil error") + } + // The race detector complains about the proxyLog usage in logf in copyBuffer + // and our usage below with proxyLog.Bytes() so we're explicitly using a + // channel to ensure that the ReverseProxy's ServeHTTP is done before we + // continue after Get. + <-donec + + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} + +type staticTransport struct { + res *http.Response +} + +func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return t.res, nil +} + +func BenchmarkServeHTTP(b *testing.B) { + res := &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + } + proxy := &ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &staticTransport{res}, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + proxy.ServeHTTP(w, r) + } +} + +func TestServeHTTPDeepCopy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello Gopher!")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + type result struct { + before, after string + } + + resultChan := make(chan result, 1) + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + before := r.URL.String() + proxyHandler.ServeHTTP(w, r) + after := r.URL.String() + resultChan <- result{before: before, after: after} + })) + defer frontend.Close() + + want := result{before: "/", after: "/"} + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Do: %v", err) + } + res.Body.Close() + + got := <-resultChan + if got != want { + t.Errorf("got = %+v; want = %+v", got, want) + } +} + +// Issue 18327: verify we always do a deep copy of the Request.Header map +// before any mutations. +func TestClonesRequestHeaders(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + rp := &ReverseProxy{ + Director: func(req *http.Request) { + req.Header.Set("From-Director", "1") + }, + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if v := req.Header.Get("From-Director"); v != "1" { + t.Errorf("From-Directory value = %q; want 1", v) + } + return nil, io.EOF + }), + } + rp.ServeHTTP(httptest.NewRecorder(), req) + + for _, h := range []string{ + "From-Director", + "X-Forwarded-For", + } { + if req.Header.Get(h) != "" { + t.Errorf("%v header mutation modified caller's request", h) + } + } +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +func TestModifyResponseClosesBody(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + closeCheck := new(checkCloser) + logBuf := new(strings.Builder) + outErr := errors.New("ModifyResponse error") + rp := &ReverseProxy{ + Director: func(req *http.Request) {}, + Transport: &staticTransport{&http.Response{ + StatusCode: 200, + Body: closeCheck, + }}, + ErrorLog: log.New(logBuf, "", 0), + ModifyResponse: func(*http.Response) error { + return outErr + }, + } + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, req) + res := rec.Result() + if g, e := res.StatusCode, http.StatusBadGateway; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if !closeCheck.closed { + t.Errorf("body should have been closed") + } + if g, e := logBuf.String(), outErr.Error(); !strings.Contains(g, e) { + t.Errorf("ErrorLog %q does not contain %q", g, e) + } +} + +type checkCloser struct { + closed bool +} + +func (cc *checkCloser) Close() error { + cc.closed = true + return nil +} + +func (cc *checkCloser) Read(b []byte) (int, error) { + return len(b), nil +} + +// Issue 23643: panic on body copy error +func TestReverseProxy_PanicBodyError(t *testing.T) { + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + rproxy := NewSingleHostReverseProxy(rpURL) + + // Ensure that the handler panics when the body read encounters an + // io.ErrUnexpectedEOF + defer func() { + err := recover() + if err == nil { + t.Fatal("handler should have panicked") + } + if err != http.ErrAbortHandler { + t.Fatal("expected ErrAbortHandler, got", err) + } + }() + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + rproxy.ServeHTTP(httptest.NewRecorder(), req) +} + +// Issue #46866: panic without closing incoming request body causes a panic +func TestReverseProxy_PanicClosesIncomingBody(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.ErrUnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := http.NewRequest("POST", frontend.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := frontendClient.Transport.RoundTrip(req) + if resp != nil { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + } + }() + } + wg.Wait() +} + +func TestSelectFlushInterval(t *testing.T) { + tests := []struct { + name string + p *ReverseProxy + res *http.Response + want time.Duration + }{ + { + name: "default", + res: &http.Response{}, + p: &ReverseProxy{FlushInterval: 123}, + want: 123, + }, + { + name: "server-sent events overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.p.flushInterval(tt.res) + if got != tt.want { + t.Errorf("flushLatency = %v; want %v", got, tt.want) + } + }) + } +} + +func TestReverseProxyWebSocket(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if upgradeType(r.Header) != "websocket" { + t.Error("unexpected backend request") + http.Error(w, "unexpected request", 400) + return + } + c, _, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer c.Close() + io.WriteString(c, "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n") + bs := bufio.NewScanner(c) + if !bs.Scan() { + t.Errorf("backend failed to read line from client: %v", bs.Err()) + return + } + fmt.Fprintf(c, "backend got %q\n", bs.Text()) + })) + defer backendServer.Close() + + backURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(backURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + c := frontendProxy.Client() + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 101 { + t.Fatalf("status = %v; want 101", res.Status) + } + + got := res.Header.Get("X-Header") + want := "X-Value" + if got != want { + t.Errorf("Header(XHeader) = %q; want %q", got, want) + } + + if !ascii.EqualFold(upgradeType(res.Header), "websocket") { + t.Fatalf("not websocket upgrade; got %#v", res.Header) + } + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("response body is of type %T; does not implement ReadWriteCloser", res.Body) + } + defer rwc.Close() + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + io.WriteString(rwc, "Hello\n") + bs := bufio.NewScanner(rwc) + if !bs.Scan() { + t.Fatalf("Scan: %v", bs.Err()) + } + got = bs.Text() + want = `backend got "Hello"` + if got != want { + t.Errorf("got %#q, want %#q", got, want) + } +} + +func TestReverseProxyWebSocketCancellation(t *testing.T) { + n := 5 + triggerCancelCh := make(chan bool, n) + nthResponse := func(i int) string { + return fmt.Sprintf("backend response #%d\n", i) + } + terminalMsg := "final message" + + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) + http.Error(w, "Unexpected request", 400) + return + } + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(conn, upgradeMsg); err != nil { + t.Error(err) + return + } + if _, _, err := bufrw.ReadLine(); err != nil { + t.Errorf("Failed to read line from client: %v", err) + return + } + + for i := 0; i < n; i++ { + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Writing response #%d failed: %v", i, err) + } + return + } + bufrw.Flush() + time.Sleep(time.Second) + } + if _, err := bufrw.WriteString(terminalMsg); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Failed to write terminal message: %v", err) + } + } + bufrw.Flush() + })) + defer cst.Close() + + backendURL, _ := url.Parse(cst.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-triggerCancelCh + cancel() + }() + rproxy.ServeHTTP(rw, req.WithContext(ctx)) + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + res, err := frontendProxy.Client().Do(req) + if err != nil { + t.Fatalf("Dialing to frontend proxy: %v", err) + } + defer res.Body.Close() + if g, w := res.StatusCode, 101; g != w { + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) + } + + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + if g, w := upgradeType(res.Header), "websocket"; !ascii.EqualFold(g, w) { + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) + } + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { + t.Fatalf("Failed to write first message: %v", err) + } + + // Read loop. + + br := bufio.NewReader(rwc) + for { + line, err := br.ReadString('\n') + switch { + case line == terminalMsg: // this case before "err == io.EOF" + t.Fatalf("The websocket request was not canceled, unfortunately!") + + case err == io.EOF: + return + + case err != nil: + t.Fatalf("Unexpected error: %v", err) + + case line == nthResponse(0): // We've gotten the first response back + // Let's trigger a cancel. + close(triggerCancelCh) + } + } +} + +func TestUnannouncedTrailer(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + + io.ReadAll(res.Body) + + if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) + } + +} + +func TestSetURL(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.Host)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(backendURL) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + res, err := frontendClient.Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Reading body: %v", err) + } + + if got, want := string(body), backendURL.Host; got != want { + t.Errorf("backend got Host %q, want %q", got, want) + } +} + +func TestSingleJoinSlash(t *testing.T) { + tests := []struct { + slasha string + slashb string + expected string + }{ + {"https://www.google.com/", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "/favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "favicon.ico", "https://www.google.com/favicon.ico"}, + {"https://www.google.com", "", "https://www.google.com/"}, + {"", "favicon.ico", "/favicon.ico"}, + } + for _, tt := range tests { + if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { + t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", + tt.slasha, + tt.slashb, + tt.expected, + got) + } + } +} + +func TestJoinURLPath(t *testing.T) { + tests := []struct { + a *url.URL + b *url.URL + wantPath string + wantRaw string + }{ + {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, + {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, + } + + for _, tt := range tests { + p, rp := joinURLPath(tt.a, tt.b) + if p != tt.wantPath || rp != tt.wantRaw { + t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", + tt.a.Path, tt.a.RawPath, + tt.b.Path, tt.b.RawPath, + tt.wantPath, tt.wantRaw, + p, rp) + } + } +} + +func TestReverseProxyRewriteReplacesOut(t *testing.T) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(content)) + })) + defer backend.Close() + proxyHandler := &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.Out, _ = http.NewRequest("GET", backend.URL, nil) + }, + } + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if got, want := string(body), content; got != want { + t.Errorf("got response %q, want %q", got, want) + } +} + +func Test1xxResponses(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Add("Link", "</style.css>; rel=preload; as=style") + h.Add("Link", "</script.js>; rel=preload; as=script") + w.WriteHeader(http.StatusEarlyHints) + + h.Add("Link", "</foo.js>; rel=preload; as=script") + w.WriteHeader(http.StatusProcessing) + + w.Write([]byte("Hello")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + frontendClient := frontend.Client() + + checkLinkHeaders := func(t *testing.T, expected, got []string) { + t.Helper() + + if len(expected) != len(got) { + t.Errorf("Expected %d link headers; got %d", len(expected), len(got)) + } + + for i := range expected { + if i >= len(got) { + t.Errorf("Expected %q link header; got nothing", expected[i]) + + continue + } + + if expected[i] != got[i] { + t.Errorf("Expected %q link header; got %q", expected[i], got[i]) + } + } + } + + var respCounter uint8 + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + switch code { + case http.StatusEarlyHints: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"]) + case http.StatusProcessing: + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"]) + default: + t.Error("Unexpected 1xx response") + } + + respCounter++ + + return nil + }, + } + req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil) + + res, err := frontendClient.Do(req) + if err != nil { + t.Fatalf("Get: %v", err) + } + + defer res.Body.Close() + + if respCounter != 2 { + t.Errorf("Expected 2 1xx responses; got %d", respCounter) + } + checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"]) + + body, _ := io.ReadAll(res.Body) + if string(body) != "Hello" { + t.Errorf("Read body %q; want Hello", body) + } +} + +const ( + testWantsCleanQuery = true + testWantsRawQuery = false +) + +func TestReverseProxyQueryParameterSmugglingDirectorDoesNotParseForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingDirectorParsesForm(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + proxyHandler := NewSingleHostReverseProxy(u) + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + // Parsing the form causes ReverseProxy to remove unparsable + // query parameters before forwarding. + r.FormValue("a") + oldDirector(r) + } + return proxyHandler + }) +} + +func TestReverseProxyQueryParameterSmugglingRewrite(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsCleanQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + }, + } + }) +} + +func TestReverseProxyQueryParameterSmugglingRewritePreservesRawQuery(t *testing.T) { + testReverseProxyQueryParameterSmuggling(t, testWantsRawQuery, func(u *url.URL) *ReverseProxy { + return &ReverseProxy{ + Rewrite: func(r *ProxyRequest) { + r.SetURL(u) + r.Out.URL.RawQuery = r.In.URL.RawQuery + }, + } + }) +} + +func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, newProxy func(*url.URL) *ReverseProxy) { + const content = "response_content" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.RawQuery)) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := newProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + // Don't spam output with logs of queries containing semicolons. + backend.Config.ErrorLog = log.New(io.Discard, "", 0) + frontend.Config.ErrorLog = log.New(io.Discard, "", 0) + + for _, test := range []struct { + rawQuery string + cleanQuery string + }{{ + rawQuery: "a=1&a=2;b=3", + cleanQuery: "a=1", + }, { + rawQuery: "a=1&a=%zz&b=3", + cleanQuery: "a=1&b=3", + }} { + res, err := frontend.Client().Get(frontend.URL + "?" + test.rawQuery) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + wantQuery := test.rawQuery + if wantCleanQuery { + wantQuery = test.cleanQuery + } + if got, want := string(body), wantQuery; got != want { + t.Errorf("proxy forwarded raw query %q as %q, want %q", test.rawQuery, got, want) + } + } +} |