diff options
Diffstat (limited to 'src/encoding/base64/base64_test.go')
-rw-r--r-- | src/encoding/base64/base64_test.go | 530 |
1 files changed, 530 insertions, 0 deletions
diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go new file mode 100644 index 0000000..57256a3 --- /dev/null +++ b/src/encoding/base64/base64_test.go @@ -0,0 +1,530 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package base64 + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "runtime/debug" + "strings" + "testing" + "time" +) + +type testpair struct { + decoded, encoded string +} + +var pairs = []testpair{ + // RFC 3548 examples + {"\x14\xfb\x9c\x03\xd9\x7e", "FPucA9l+"}, + {"\x14\xfb\x9c\x03\xd9", "FPucA9k="}, + {"\x14\xfb\x9c\x03", "FPucAw=="}, + + // RFC 4648 examples + {"", ""}, + {"f", "Zg=="}, + {"fo", "Zm8="}, + {"foo", "Zm9v"}, + {"foob", "Zm9vYg=="}, + {"fooba", "Zm9vYmE="}, + {"foobar", "Zm9vYmFy"}, + + // Wikipedia examples + {"sure.", "c3VyZS4="}, + {"sure", "c3VyZQ=="}, + {"sur", "c3Vy"}, + {"su", "c3U="}, + {"leasure.", "bGVhc3VyZS4="}, + {"easure.", "ZWFzdXJlLg=="}, + {"asure.", "YXN1cmUu"}, + {"sure.", "c3VyZS4="}, +} + +// Do nothing to a reference base64 string (leave in standard format) +func stdRef(ref string) string { + return ref +} + +// Convert a reference string to URL-encoding +func urlRef(ref string) string { + ref = strings.ReplaceAll(ref, "+", "-") + ref = strings.ReplaceAll(ref, "/", "_") + return ref +} + +// Convert a reference string to raw, unpadded format +func rawRef(ref string) string { + return strings.TrimRight(ref, "=") +} + +// Both URL and unpadding conversions +func rawURLRef(ref string) string { + return rawRef(urlRef(ref)) +} + +// A nonstandard encoding with a funny padding character, for testing +var funnyEncoding = NewEncoding(encodeStd).WithPadding(rune('@')) + +func funnyRef(ref string) string { + return strings.ReplaceAll(ref, "=", "@") +} + +type encodingTest struct { + enc *Encoding // Encoding to test + conv func(string) string // Reference string converter +} + +var encodingTests = []encodingTest{ + {StdEncoding, stdRef}, + {URLEncoding, urlRef}, + {RawStdEncoding, rawRef}, + {RawURLEncoding, rawURLRef}, + {funnyEncoding, funnyRef}, + {StdEncoding.Strict(), stdRef}, + {URLEncoding.Strict(), urlRef}, + {RawStdEncoding.Strict(), rawRef}, + {RawURLEncoding.Strict(), rawURLRef}, + {funnyEncoding.Strict(), funnyRef}, +} + +var bigtest = testpair{ + "Twas brillig, and the slithy toves", + "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", +} + +func testEqual(t *testing.T, msg string, args ...any) bool { + t.Helper() + if args[len(args)-2] != args[len(args)-1] { + t.Errorf(msg, args...) + return false + } + return true +} + +func TestEncode(t *testing.T) { + for _, p := range pairs { + for _, tt := range encodingTests { + got := tt.enc.EncodeToString([]byte(p.decoded)) + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, + got, tt.conv(p.encoded)) + } + } +} + +func TestEncoder(t *testing.T) { + for _, p := range pairs { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + encoder.Write([]byte(p.decoded)) + encoder.Close() + testEqual(t, "Encode(%q) = %q, want %q", p.decoded, bb.String(), p.encoded) + } +} + +func TestEncoderBuffering(t *testing.T) { + input := []byte(bigtest.decoded) + for bs := 1; bs <= 12; bs++ { + bb := &bytes.Buffer{} + encoder := NewEncoder(StdEncoding, bb) + for pos := 0; pos < len(input); pos += bs { + end := pos + bs + if end > len(input) { + end = len(input) + } + n, err := encoder.Write(input[pos:end]) + testEqual(t, "Write(%q) gave error %v, want %v", input[pos:end], err, error(nil)) + testEqual(t, "Write(%q) gave length %v, want %v", input[pos:end], n, end-pos) + } + err := encoder.Close() + testEqual(t, "Close gave error %v, want %v", err, error(nil)) + testEqual(t, "Encoding/%d of %q = %q, want %q", bs, bigtest.decoded, bb.String(), bigtest.encoded) + } +} + +func TestDecode(t *testing.T) { + for _, p := range pairs { + for _, tt := range encodingTests { + encoded := tt.conv(p.encoded) + dbuf := make([]byte, tt.enc.DecodedLen(len(encoded))) + count, err := tt.enc.Decode(dbuf, []byte(encoded)) + testEqual(t, "Decode(%q) = error %v, want %v", encoded, err, error(nil)) + testEqual(t, "Decode(%q) = length %v, want %v", encoded, count, len(p.decoded)) + testEqual(t, "Decode(%q) = %q, want %q", encoded, string(dbuf[0:count]), p.decoded) + + dbuf, err = tt.enc.DecodeString(encoded) + testEqual(t, "DecodeString(%q) = error %v, want %v", encoded, err, error(nil)) + testEqual(t, "DecodeString(%q) = %q, want %q", encoded, string(dbuf), p.decoded) + } + } +} + +func TestDecoder(t *testing.T) { + for _, p := range pairs { + decoder := NewDecoder(StdEncoding, strings.NewReader(p.encoded)) + dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded))) + count, err := decoder.Read(dbuf) + if err != nil && err != io.EOF { + t.Fatal("Read failed", err) + } + testEqual(t, "Read from %q = length %v, want %v", p.encoded, count, len(p.decoded)) + testEqual(t, "Decoding of %q = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded) + if err != io.EOF { + _, err = decoder.Read(dbuf) + } + testEqual(t, "Read from %q = %v, want %v", p.encoded, err, io.EOF) + } +} + +func TestDecoderBuffering(t *testing.T) { + for bs := 1; bs <= 12; bs++ { + decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded)) + buf := make([]byte, len(bigtest.decoded)+12) + var total int + var n int + var err error + for total = 0; total < len(bigtest.decoded) && err == nil; { + n, err = decoder.Read(buf[total : total+bs]) + total += n + } + if err != nil && err != io.EOF { + t.Errorf("Read from %q at pos %d = %d, unexpected error %v", bigtest.encoded, total, n, err) + } + testEqual(t, "Decoding/%d of %q = %q, want %q", bs, bigtest.encoded, string(buf[0:total]), bigtest.decoded) + } +} + +func TestDecodeCorrupt(t *testing.T) { + testCases := []struct { + input string + offset int // -1 means no corruption. + }{ + {"", -1}, + {"\n", -1}, + {"AAA=\n", -1}, + {"AAAA\n", -1}, + {"!!!!", 0}, + {"====", 0}, + {"x===", 1}, + {"=AAA", 0}, + {"A=AA", 1}, + {"AA=A", 2}, + {"AA==A", 4}, + {"AAA=AAAA", 4}, + {"AAAAA", 4}, + {"AAAAAA", 4}, + {"A=", 1}, + {"A==", 1}, + {"AA=", 3}, + {"AA==", -1}, + {"AAA=", -1}, + {"AAAA", -1}, + {"AAAAAA=", 7}, + {"YWJjZA=====", 8}, + {"A!\n", 1}, + {"A=\n", 1}, + } + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected corruption in", tc.input) + } + continue + } + switch err := err.(type) { + case CorruptInputError: + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) + default: + t.Error("Decoder failed to detect corruption in", tc) + } + } +} + +func TestDecodeBounds(t *testing.T) { + var buf [32]byte + s := StdEncoding.EncodeToString(buf[:]) + defer func() { + if err := recover(); err != nil { + t.Fatalf("Decode panicked unexpectedly: %v\n%s", err, debug.Stack()) + } + }() + n, err := StdEncoding.Decode(buf[:], []byte(s)) + if n != len(buf) || err != nil { + t.Fatalf("StdEncoding.Decode = %d, %v, want %d, nil", n, err, len(buf)) + } +} + +func TestEncodedLen(t *testing.T) { + for _, tt := range []struct { + enc *Encoding + n int + want int + }{ + {RawStdEncoding, 0, 0}, + {RawStdEncoding, 1, 2}, + {RawStdEncoding, 2, 3}, + {RawStdEncoding, 3, 4}, + {RawStdEncoding, 7, 10}, + {StdEncoding, 0, 0}, + {StdEncoding, 1, 4}, + {StdEncoding, 2, 4}, + {StdEncoding, 3, 4}, + {StdEncoding, 4, 8}, + {StdEncoding, 7, 12}, + } { + if got := tt.enc.EncodedLen(tt.n); got != tt.want { + t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } + } +} + +func TestDecodedLen(t *testing.T) { + for _, tt := range []struct { + enc *Encoding + n int + want int + }{ + {RawStdEncoding, 0, 0}, + {RawStdEncoding, 2, 1}, + {RawStdEncoding, 3, 2}, + {RawStdEncoding, 4, 3}, + {RawStdEncoding, 10, 7}, + {StdEncoding, 0, 0}, + {StdEncoding, 4, 3}, + {StdEncoding, 8, 6}, + } { + if got := tt.enc.DecodedLen(tt.n); got != tt.want { + t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } + } +} + +func TestBig(t *testing.T) { + n := 3*1000 + 1 + raw := make([]byte, n) + const alpha = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + for i := 0; i < n; i++ { + raw[i] = alpha[i%len(alpha)] + } + encoded := new(bytes.Buffer) + w := NewEncoder(StdEncoding, encoded) + nn, err := w.Write(raw) + if nn != n || err != nil { + t.Fatalf("Encoder.Write(raw) = %d, %v want %d, nil", nn, err, n) + } + err = w.Close() + if err != nil { + t.Fatalf("Encoder.Close() = %v want nil", err) + } + decoded, err := io.ReadAll(NewDecoder(StdEncoding, encoded)) + if err != nil { + t.Fatalf("io.ReadAll(NewDecoder(...)): %v", err) + } + + if !bytes.Equal(raw, decoded) { + var i int + for i = 0; i < len(decoded) && i < len(raw); i++ { + if decoded[i] != raw[i] { + break + } + } + t.Errorf("Decode(Encode(%d-byte string)) failed at offset %d", n, i) + } +} + +func TestNewLineCharacters(t *testing.T) { + // Each of these should decode to the string "sure", without errors. + const expected = "sure" + examples := []string{ + "c3VyZQ==", + "c3VyZQ==\r", + "c3VyZQ==\n", + "c3VyZQ==\r\n", + "c3VyZ\r\nQ==", + "c3V\ryZ\nQ==", + "c3V\nyZ\rQ==", + "c3VyZ\nQ==", + "c3VyZQ\n==", + "c3VyZQ=\n=", + "c3VyZQ=\r\n\r\n=", + } + for _, e := range examples { + buf, err := StdEncoding.DecodeString(e) + if err != nil { + t.Errorf("Decode(%q) failed: %v", e, err) + continue + } + if s := string(buf); s != expected { + t.Errorf("Decode(%q) = %q, want %q", e, s, expected) + } + } +} + +type nextRead struct { + n int // bytes to return + err error // error to return +} + +// faultInjectReader returns data from source, rate-limited +// and with the errors as written to nextc. +type faultInjectReader struct { + source string + nextc <-chan nextRead +} + +func (r *faultInjectReader) Read(p []byte) (int, error) { + nr := <-r.nextc + if len(p) > nr.n { + p = p[:nr.n] + } + n := copy(p, r.source) + r.source = r.source[n:] + return n, nr.err +} + +// tests that we don't ignore errors from our underlying reader +func TestDecoderIssue3577(t *testing.T) { + next := make(chan nextRead, 10) + wantErr := errors.New("my error") + next <- nextRead{5, nil} + next <- nextRead{10, wantErr} + next <- nextRead{0, wantErr} + d := NewDecoder(StdEncoding, &faultInjectReader{ + source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", // twas brillig... + nextc: next, + }) + errc := make(chan error, 1) + go func() { + _, err := io.ReadAll(d) + errc <- err + }() + select { + case err := <-errc: + if err != wantErr { + t.Errorf("got error %v; want %v", err, wantErr) + } + case <-time.After(5 * time.Second): + t.Errorf("timeout; Decoder blocked without returning an error") + } +} + +func TestDecoderIssue4779(t *testing.T) { + encoded := `CP/EAT8AAAEF +AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB +BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx +Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm +9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS +0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0 +pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj +1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N +ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf ++gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM +NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn +EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy +j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv +2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2 +bqbPb06551Y4 +` + encodedShort := strings.ReplaceAll(encoded, "\n", "") + + dec := NewDecoder(StdEncoding, strings.NewReader(encoded)) + res1, err := io.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, strings.NewReader(encodedShort)) + var res2 []byte + res2, err = io.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") + } +} + +func TestDecoderIssue7733(t *testing.T) { + s, err := StdEncoding.DecodeString("YWJjZA=====") + want := CorruptInputError(8) + if !reflect.DeepEqual(want, err) { + t.Errorf("Error = %v; want CorruptInputError(8)", err) + } + if string(s) != "abcd" { + t.Errorf("DecodeString = %q; want abcd", s) + } +} + +func TestDecoderIssue15656(t *testing.T) { + _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + want := CorruptInputError(22) + if !reflect.DeepEqual(want, err) { + t.Errorf("Error = %v; want CorruptInputError(22)", err) + } + _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } + _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==") + if err != nil { + t.Errorf("Error = %v; want nil", err) + } +} + +func BenchmarkEncodeToString(b *testing.B) { + data := make([]byte, 8192) + b.SetBytes(int64(len(data))) + for i := 0; i < b.N; i++ { + StdEncoding.EncodeToString(data) + } +} + +func BenchmarkDecodeString(b *testing.B) { + sizes := []int{2, 4, 8, 64, 8192} + benchFunc := func(b *testing.B, benchSize int) { + data := StdEncoding.EncodeToString(make([]byte, benchSize)) + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + StdEncoding.DecodeString(data) + } + } + for _, size := range sizes { + b.Run(fmt.Sprintf("%d", size), func(b *testing.B) { + benchFunc(b, size) + }) + } +} + +func TestDecoderRaw(t *testing.T) { + source := "AAAAAA" + want := []byte{0, 0, 0, 0} + + // Direct. + dec1, err := RawURLEncoding.DecodeString(source) + if err != nil || !bytes.Equal(dec1, want) { + t.Errorf("RawURLEncoding.DecodeString(%q) = %x, %v, want %x, nil", source, dec1, err, want) + } + + // Through reader. Used to fail. + r := NewDecoder(RawURLEncoding, bytes.NewReader([]byte(source))) + dec2, err := io.ReadAll(io.LimitReader(r, 100)) + if err != nil || !bytes.Equal(dec2, want) { + t.Errorf("reading NewDecoder(RawURLEncoding, %q) = %x, %v, want %x, nil", source, dec2, err, want) + } + + // Should work with padding. + r = NewDecoder(URLEncoding, bytes.NewReader([]byte(source+"=="))) + dec3, err := io.ReadAll(r) + if err != nil || !bytes.Equal(dec3, want) { + t.Errorf("reading NewDecoder(URLEncoding, %q) = %x, %v, want %x, nil", source+"==", dec3, err, want) + } +} |