diff options
Diffstat (limited to 'src/internal/saferio')
-rw-r--r-- | src/internal/saferio/io.go | 132 | ||||
-rw-r--r-- | src/internal/saferio/io_test.go | 136 |
2 files changed, 268 insertions, 0 deletions
diff --git a/src/internal/saferio/io.go b/src/internal/saferio/io.go new file mode 100644 index 0000000..5c428e6 --- /dev/null +++ b/src/internal/saferio/io.go @@ -0,0 +1,132 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package saferio provides I/O functions that avoid allocating large +// amounts of memory unnecessarily. This is intended for packages that +// read data from an [io.Reader] where the size is part of the input +// data but the input may be corrupt, or may be provided by an +// untrustworthy attacker. +package saferio + +import ( + "io" + "unsafe" +) + +// chunk is an arbitrary limit on how much memory we are willing +// to allocate without concern. +const chunk = 10 << 20 // 10M + +// ReadData reads n bytes from the input stream, but avoids allocating +// all n bytes if n is large. This avoids crashing the program by +// allocating all n bytes in cases where n is incorrect. +// +// The error is io.EOF only if no bytes were read. +// If an io.EOF happens after reading some but not all the bytes, +// ReadData returns io.ErrUnexpectedEOF. +func ReadData(r io.Reader, n uint64) ([]byte, error) { + if int64(n) < 0 || n != uint64(int(n)) { + // n is too large to fit in int, so we can't allocate + // a buffer large enough. Treat this as a read failure. + return nil, io.ErrUnexpectedEOF + } + + if n < chunk { + buf := make([]byte, n) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + return buf, nil + } + + var buf []byte + buf1 := make([]byte, chunk) + for n > 0 { + next := n + if next > chunk { + next = chunk + } + _, err := io.ReadFull(r, buf1[:next]) + if err != nil { + if len(buf) > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + buf = append(buf, buf1[:next]...) + n -= next + } + return buf, nil +} + +// ReadDataAt reads n bytes from the input stream at off, but avoids +// allocating all n bytes if n is large. This avoids crashing the program +// by allocating all n bytes in cases where n is incorrect. +func ReadDataAt(r io.ReaderAt, n uint64, off int64) ([]byte, error) { + if int64(n) < 0 || n != uint64(int(n)) { + // n is too large to fit in int, so we can't allocate + // a buffer large enough. Treat this as a read failure. + return nil, io.ErrUnexpectedEOF + } + + if n < chunk { + buf := make([]byte, n) + _, err := r.ReadAt(buf, off) + if err != nil { + // io.SectionReader can return EOF for n == 0, + // but for our purposes that is a success. + if err != io.EOF || n > 0 { + return nil, err + } + } + return buf, nil + } + + var buf []byte + buf1 := make([]byte, chunk) + for n > 0 { + next := n + if next > chunk { + next = chunk + } + _, err := r.ReadAt(buf1[:next], off) + if err != nil { + return nil, err + } + buf = append(buf, buf1[:next]...) + n -= next + off += int64(next) + } + return buf, nil +} + +// SliceCapWithSize returns the capacity to use when allocating a slice. +// After the slice is allocated with the capacity, it should be +// built using append. This will avoid allocating too much memory +// if the capacity is large and incorrect. +// +// A negative result means that the value is always too big. +func SliceCapWithSize(size, c uint64) int { + if int64(c) < 0 || c != uint64(int(c)) { + return -1 + } + if size > 0 && c > (1<<64-1)/size { + return -1 + } + if c*size > chunk { + c = chunk / size + if c == 0 { + c = 1 + } + } + return int(c) +} + +// SliceCap is like SliceCapWithSize but using generics. +func SliceCap[E any](c uint64) int { + var v E + size := uint64(unsafe.Sizeof(v)) + return SliceCapWithSize(size, c) +} diff --git a/src/internal/saferio/io_test.go b/src/internal/saferio/io_test.go new file mode 100644 index 0000000..696356f --- /dev/null +++ b/src/internal/saferio/io_test.go @@ -0,0 +1,136 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package saferio + +import ( + "bytes" + "io" + "testing" +) + +func TestReadData(t *testing.T) { + const count = 100 + input := bytes.Repeat([]byte{'a'}, count) + + t.Run("small", func(t *testing.T) { + got, err := ReadData(bytes.NewReader(input), count) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, input) { + t.Errorf("got %v, want %v", got, input) + } + }) + + t.Run("large", func(t *testing.T) { + _, err := ReadData(bytes.NewReader(input), 10<<30) + if err == nil { + t.Error("large read succeeded unexpectedly") + } + }) + + t.Run("maxint", func(t *testing.T) { + _, err := ReadData(bytes.NewReader(input), 1<<62) + if err == nil { + t.Error("large read succeeded unexpectedly") + } + }) + + t.Run("small-EOF", func(t *testing.T) { + _, err := ReadData(bytes.NewReader(nil), chunk-1) + if err != io.EOF { + t.Errorf("ReadData = %v, want io.EOF", err) + } + }) + + t.Run("large-EOF", func(t *testing.T) { + _, err := ReadData(bytes.NewReader(nil), chunk+1) + if err != io.EOF { + t.Errorf("ReadData = %v, want io.EOF", err) + } + }) + + t.Run("large-UnexpectedEOF", func(t *testing.T) { + _, err := ReadData(bytes.NewReader(make([]byte, chunk)), chunk+1) + if err != io.ErrUnexpectedEOF { + t.Errorf("ReadData = %v, want io.ErrUnexpectedEOF", err) + } + }) +} + +func TestReadDataAt(t *testing.T) { + const count = 100 + input := bytes.Repeat([]byte{'a'}, count) + + t.Run("small", func(t *testing.T) { + got, err := ReadDataAt(bytes.NewReader(input), count, 0) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, input) { + t.Errorf("got %v, want %v", got, input) + } + }) + + t.Run("large", func(t *testing.T) { + _, err := ReadDataAt(bytes.NewReader(input), 10<<30, 0) + if err == nil { + t.Error("large read succeeded unexpectedly") + } + }) + + t.Run("maxint", func(t *testing.T) { + _, err := ReadDataAt(bytes.NewReader(input), 1<<62, 0) + if err == nil { + t.Error("large read succeeded unexpectedly") + } + }) + + t.Run("SectionReader", func(t *testing.T) { + // Reading 0 bytes from an io.SectionReader at the end + // of the section will return EOF, but ReadDataAt + // should succeed and return 0 bytes. + sr := io.NewSectionReader(bytes.NewReader(input), 0, 0) + got, err := ReadDataAt(sr, 0, 0) + if err != nil { + t.Fatal(err) + } + if len(got) > 0 { + t.Errorf("got %d bytes, expected 0", len(got)) + } + }) +} + +func TestSliceCap(t *testing.T) { + t.Run("small", func(t *testing.T) { + c := SliceCap[int](10) + if c != 10 { + t.Errorf("got capacity %d, want %d", c, 10) + } + }) + + t.Run("large", func(t *testing.T) { + c := SliceCap[byte](1 << 30) + if c < 0 { + t.Error("SliceCap failed unexpectedly") + } else if c == 1<<30 { + t.Errorf("got capacity %d which is too high", c) + } + }) + + t.Run("maxint", func(t *testing.T) { + c := SliceCap[byte](1 << 63) + if c >= 0 { + t.Errorf("SliceCap returned %d, expected failure", c) + } + }) + + t.Run("overflow", func(t *testing.T) { + c := SliceCap[int64](1 << 62) + if c >= 0 { + t.Errorf("SliceCap returned %d, expected failure", c) + } + }) +} |