summaryrefslogtreecommitdiffstats
path: root/pkg/v1/remote/transport/bearer_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/v1/remote/transport/bearer_test.go')
-rw-r--r--pkg/v1/remote/transport/bearer_test.go561
1 files changed, 561 insertions, 0 deletions
diff --git a/pkg/v1/remote/transport/bearer_test.go b/pkg/v1/remote/transport/bearer_test.go
new file mode 100644
index 0000000..a03b1f9
--- /dev/null
+++ b/pkg/v1/remote/transport/bearer_test.go
@@ -0,0 +1,561 @@
+// Copyright 2018 Google LLC All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/google/go-containerregistry/pkg/authn"
+ "github.com/google/go-containerregistry/pkg/name"
+)
+
+func TestBearerRefresh(t *testing.T) {
+ expectedToken := "Sup3rDup3rS3cr3tz"
+ expectedScope := "this-is-your-scope"
+ expectedService := "my-service.io"
+
+ cases := []struct {
+ tokenKey string
+ wantErr bool
+ }{{
+ tokenKey: "token",
+ wantErr: false,
+ }, {
+ tokenKey: "access_token",
+ wantErr: false,
+ }, {
+ tokenKey: "tolkien",
+ wantErr: true,
+ }}
+
+ for _, tc := range cases {
+ t.Run(tc.tokenKey, func(t *testing.T) {
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ hdr := r.Header.Get("Authorization")
+ if !strings.HasPrefix(hdr, "Basic ") {
+ t.Errorf("Header.Get(Authorization); got %v, want Basic prefix", hdr)
+ }
+ if got, want := r.FormValue("scope"), expectedScope; got != want {
+ t.Errorf("FormValue(scope); got %v, want %v", got, want)
+ }
+ if got, want := r.FormValue("service"), expectedService; got != want {
+ t.Errorf("FormValue(service); got %v, want %v", got, want)
+ }
+ w.Write([]byte(fmt.Sprintf(`{%q: %q}`, tc.tokenKey, expectedToken)))
+ }))
+ defer server.Close()
+
+ basic := &authn.Basic{Username: "foo", Password: "bar"}
+ registry, err := name.NewRegistry(expectedService, name.WeakValidation)
+ if err != nil {
+ t.Errorf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ bt := &bearerTransport{
+ inner: http.DefaultTransport,
+ basic: basic,
+ registry: registry,
+ realm: server.URL,
+ scopes: []string{expectedScope},
+ service: expectedService,
+ scheme: "http",
+ }
+
+ if err := bt.refresh(context.Background()); (err != nil) != tc.wantErr {
+ t.Errorf("refresh() = %v", err)
+ }
+ })
+ }
+}
+
+func TestBearerTransport(t *testing.T) {
+ expectedToken := "sdkjhfskjdhfkjshdf"
+
+ blobServer := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // We don't expect the blobServer to receive bearer tokens.
+ if got := r.Header.Get("Authorization"); got != "" {
+ t.Errorf("Header.Get(Authorization); got %v, want empty string", got)
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer blobServer.Close()
+
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if got, want := r.Header.Get("Authorization"), "Bearer "+expectedToken; got != want {
+ t.Errorf("Header.Get(Authorization); got %v, want %v", got, want)
+ }
+ if r.URL.Path == "/v2/auth" {
+ http.Redirect(w, r, "/redirect", http.StatusMovedPermanently)
+ return
+ }
+ if strings.Contains(r.URL.Path, "blobs") {
+ http.Redirect(w, r, blobServer.URL, http.StatusFound)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer server.Close()
+
+ u, err := url.Parse(server.URL)
+ if err != nil {
+ t.Errorf("Unexpected error during url.Parse: %v", err)
+ }
+ registry, err := name.NewRegistry(u.Host, name.WeakValidation)
+ if err != nil {
+ t.Errorf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ client := http.Client{Transport: &bearerTransport{
+ inner: &http.Transport{},
+ bearer: authn.AuthConfig{RegistryToken: expectedToken},
+ registry: registry,
+ scheme: "http",
+ }}
+
+ _, err = client.Get(fmt.Sprintf("http://%s/v2/auth", u.Host))
+ if err != nil {
+ t.Errorf("Unexpected error during Get: %v", err)
+ }
+
+ _, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Errorf("Unexpected error during Get: %v", err)
+ }
+}
+
+func TestBearerTransportTokenRefresh(t *testing.T) {
+ initialToken := "foo"
+ refreshedToken := "bar"
+
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ hdr := r.Header.Get("Authorization")
+ if hdr == "Bearer "+refreshedToken {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+ if strings.HasPrefix(hdr, "Basic ") {
+ w.Write([]byte(fmt.Sprintf(`{"token": %q}`, refreshedToken)))
+ }
+
+ w.Header().Set("WWW-Authenticate", "scope=foo")
+ w.WriteHeader(http.StatusUnauthorized)
+ }))
+ defer server.Close()
+
+ u, err := url.Parse(server.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ registry, err := name.NewRegistry(u.Host, name.WeakValidation)
+ if err != nil {
+ t.Fatalf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ // Pass Username/Password
+ transport := &bearerTransport{
+ inner: http.DefaultTransport,
+ bearer: authn.AuthConfig{RegistryToken: initialToken},
+ basic: &authn.Basic{Username: "foo", Password: "bar"},
+ registry: registry,
+ realm: server.URL,
+ scheme: "http",
+ }
+ client := http.Client{Transport: transport}
+
+ res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Errorf("Unexpected error during client.Get: %v", err)
+ return
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
+ }
+ if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
+ t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
+ }
+
+ // Pass RegistryToken directly
+ transport.bearer = authn.AuthConfig{RegistryToken: initialToken}
+ transport.basic = &authn.Bearer{Token: refreshedToken}
+ client = http.Client{Transport: transport}
+
+ res, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Errorf("Unexpected error during client.Get: %v", err)
+ return
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
+ }
+ if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
+ t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
+ }
+}
+
+func TestBearerTransportOauthRefresh(t *testing.T) {
+ initialToken := "foo"
+ accessToken := "bar"
+ refreshToken := "baz"
+
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodPost {
+ if err := r.ParseForm(); err != nil {
+ t.Fatal(err)
+ }
+ if it := r.FormValue("refresh_token"); it != initialToken {
+ t.Errorf("want %s got %s", initialToken, it)
+ }
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(fmt.Sprintf(`{"access_token": %q, "refresh_token": %q}`, accessToken, refreshToken)))
+ return
+ }
+
+ hdr := r.Header.Get("Authorization")
+ if hdr == "Bearer "+accessToken {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ w.Header().Set("WWW-Authenticate", "scope=foo")
+ w.WriteHeader(http.StatusUnauthorized)
+ }))
+ defer server.Close()
+
+ u, err := url.Parse(server.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ registry, err := name.NewRegistry(u.Host, name.WeakValidation)
+ if err != nil {
+ t.Errorf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ transport := &bearerTransport{
+ inner: http.DefaultTransport,
+ basic: authn.FromConfig(authn.AuthConfig{IdentityToken: initialToken}),
+ registry: registry,
+ realm: server.URL,
+ scheme: "http",
+ scopes: []string{"myscope"},
+ service: u.Host,
+ }
+ client := http.Client{Transport: transport}
+
+ res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Fatalf("Unexpected error during client.Get: %v", err)
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
+ }
+ if want, got := transport.bearer.RegistryToken, accessToken; want != got {
+ t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
+ }
+ basicAuthConfig, err := transport.basic.Authorization()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := basicAuthConfig.IdentityToken, refreshToken; got != want {
+ t.Errorf("Expected Basic IdentityToken to be refreshed, got %v, want %v", got, want)
+ }
+}
+
+func TestBearerTransportOauth404Fallback(t *testing.T) {
+ basicAuth := "basic_auth"
+ identityToken := "identity_token"
+ accessToken := "access_token"
+
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == http.MethodPost {
+ w.WriteHeader(http.StatusNotFound)
+ }
+
+ hdr := r.Header.Get("Authorization")
+ if hdr == "Basic "+basicAuth {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(fmt.Sprintf(`{"access_token": %q}`, accessToken)))
+ }
+ if hdr == "Bearer "+accessToken {
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+
+ w.Header().Set("WWW-Authenticate", "scope=foo")
+ w.WriteHeader(http.StatusUnauthorized)
+ }))
+ defer server.Close()
+
+ u, err := url.Parse(server.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ registry, err := name.NewRegistry(u.Host, name.WeakValidation)
+ if err != nil {
+ t.Errorf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ transport := &bearerTransport{
+ inner: http.DefaultTransport,
+ basic: authn.FromConfig(authn.AuthConfig{
+ IdentityToken: identityToken,
+ Auth: basicAuth,
+ }),
+ registry: registry,
+ realm: server.URL,
+ scheme: "http",
+ scopes: []string{"myscope"},
+ service: u.Host,
+ }
+ client := http.Client{Transport: transport}
+
+ res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Fatalf("Unexpected error during client.Get: %v", err)
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
+ }
+ if got, want := transport.bearer.RegistryToken, accessToken; got != want {
+ t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
+ }
+}
+
+type recorder struct {
+ reqs []*http.Request
+ resp *http.Response
+ err error
+}
+
+func newRecorder(resp *http.Response, err error) *recorder {
+ return &recorder{
+ reqs: []*http.Request{},
+ resp: resp,
+ err: err,
+ }
+}
+
+func (r *recorder) RoundTrip(in *http.Request) (*http.Response, error) {
+ r.reqs = append(r.reqs, in)
+ return r.resp, r.err
+}
+
+func TestSchemeOverride(t *testing.T) {
+ // Record the requests we get in the inner transport.
+ cannedResponse := http.Response{
+ Status: http.StatusText(http.StatusOK),
+ StatusCode: http.StatusOK,
+ }
+ rec := newRecorder(&cannedResponse, nil)
+ registry, err := name.NewRegistry("example.com")
+ if err != nil {
+ t.Fatalf("Unexpected error during NewRegistry: %v", err)
+ }
+ st := &schemeTransport{
+ inner: rec,
+ registry: registry,
+ scheme: "http",
+ }
+
+ // We should see the scheme be overridden to "http" for the registry, but the
+ // scheme for the token server should be unchanged.
+ tests := []struct {
+ url string
+ wantScheme string
+ }{{
+ url: "https://example.com",
+ wantScheme: "http",
+ }, {
+ url: "https://token.example.com",
+ wantScheme: "https",
+ }}
+
+ for i, tt := range tests {
+ req, err := http.NewRequest("GET", tt.url, nil)
+ if err != nil {
+ t.Fatalf("Unexpected error during NewRequest: %v", err)
+ }
+
+ if _, err := st.RoundTrip(req); err != nil {
+ t.Fatalf("Unexpected error during RoundTrip: %v", err)
+ }
+
+ if got, want := rec.reqs[i].URL.Scheme, tt.wantScheme; got != want {
+ t.Errorf("Wrong scheme: wanted %v, got %v", want, got)
+ }
+ }
+}
+
+func TestCanonicalAddressResolution(t *testing.T) {
+ registry, err := name.NewRegistry("does-not-matter", name.WeakValidation)
+ if err != nil {
+ t.Errorf("Unexpected error during NewRegistry: %v", err)
+ }
+
+ tests := []struct {
+ registry name.Registry
+ scheme string
+ address string
+ want string
+ }{{
+ registry: registry,
+ scheme: "http",
+ address: "registry.example.com",
+ want: "registry.example.com:80",
+ }, {
+ registry: registry,
+ scheme: "http",
+ address: "registry.example.com:12345",
+ want: "registry.example.com:12345",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "registry.example.com",
+ want: "registry.example.com:443",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "registry.example.com:12345",
+ want: "registry.example.com:12345",
+ }, {
+ registry: registry,
+ scheme: "http",
+ address: "registry.example.com:",
+ want: "registry.example.com:80",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "registry.example.com:",
+ want: "registry.example.com:443",
+ }, {
+ registry: registry,
+ scheme: "http",
+ address: "2001:db8::1",
+ want: "[2001:db8::1]:80",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "2001:db8::1",
+ want: "[2001:db8::1]:443",
+ }, {
+ registry: registry,
+ scheme: "http",
+ address: "[2001:db8::1]:12345",
+ want: "[2001:db8::1]:12345",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "[2001:db8::1]:12345",
+ want: "[2001:db8::1]:12345",
+ }, {
+ registry: registry,
+ scheme: "http",
+ address: "[2001:db8::1]:",
+ want: "[2001:db8::1]:80",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "[2001:db8::1]:",
+ want: "[2001:db8::1]:443",
+ }, {
+ registry: registry,
+ scheme: "https",
+ address: "something:is::wrong]:",
+ want: "something:is::wrong]:",
+ }}
+
+ for _, tt := range tests {
+ got := canonicalAddress(tt.address, tt.scheme)
+ if got != tt.want {
+ t.Errorf("Wrong canonical host: wanted %v got %v", tt.want, got)
+ }
+ }
+}
+
+func TestInsufficientScope(t *testing.T) {
+ wrong := "the-wrong-scope"
+ right := "the-right-scope"
+ realm := ""
+ expectedService := "my-service.io"
+ passed := false
+
+ server := httptest.NewServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ query := r.URL.Query()
+
+ scopes := query["scope"]
+ switch {
+ case len(scopes) == 0:
+ if !passed {
+ w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=%q,scope=%q", realm, right))
+ w.WriteHeader(http.StatusUnauthorized)
+ }
+ case len(scopes) == 1:
+ w.Write([]byte(`{"token": "arbitrary-token"}`))
+ default:
+ passed = true
+ w.Write([]byte(`{"token": "arbitrary-token-2"}`))
+ }
+ }))
+ defer server.Close()
+
+ basic := &authn.Basic{Username: "foo", Password: "bar"}
+ u, err := url.Parse(server.URL)
+ if err != nil {
+ t.Error("Unexpected error during url.Parse: ", err)
+ }
+ realm = u.Host
+
+ registry, err := name.NewRegistry(expectedService, name.WeakValidation)
+ if err != nil {
+ t.Error("Unexpected error during NewRegistry: ", err)
+ }
+
+ bt := &bearerTransport{
+ inner: http.DefaultTransport,
+ basic: basic,
+ registry: registry,
+ realm: server.URL,
+ scopes: []string{wrong},
+ service: expectedService,
+ scheme: "http",
+ }
+
+ client := http.Client{Transport: bt}
+
+ res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
+ if err != nil {
+ t.Error("Unexpected error during client.Get: ", err)
+ return
+ }
+ if res.StatusCode != http.StatusOK {
+ t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
+ }
+
+ if !passed {
+ t.Error("didn't refresh insufficient scope")
+ }
+}