summary refs log tree commit diff stats
path: root/src/internal/saferio
diff options
context:
space:
mode:
Diffstat (limited to 'src/internal/saferio')
-rw-r--r-- src/internal/saferio/io.go 132
-rw-r--r-- src/internal/saferio/io_test.go 136
2 files changed, 268 insertions, 0 deletions
diff --git a/src/internal/saferio/io.go b/src/internal/saferio/io.go
new file mode 100644
index 0000000..5c428e6
--- /dev/null
+++ b/src/internal/saferio/io.go
@@ -0,0 +1,132 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package saferio provides I/O functions that avoid allocating large
+// amounts of memory unnecessarily. This is intended for packages that
+// read data from an [io.Reader] where the size is part of the input
+// data but the input may be corrupt, or may be provided by an
+// untrustworthy attacker.
+package saferio
+
+import (
+ "io"
+ "unsafe"
+)
+
+// chunk is an arbitrary limit, in bytes, on how much memory we are
+// willing to allocate up front without first seeing matching input.
+const chunk = 10 << 20 // 10 MiB
+
+// ReadData returns exactly n bytes read from r. When n is at least
+// chunk, the result is grown one chunk at a time as data actually
+// arrives, so a bogus length taken from corrupt or hostile input cannot
+// force a huge allocation before any data has been seen.
+//
+// A length that cannot be represented as a non-negative int is reported
+// as io.ErrUnexpectedEOF. Otherwise the error is io.EOF only if no bytes
+// were read at all; hitting end of input after some, but not all, of
+// the bytes were read yields io.ErrUnexpectedEOF. On any error the
+// returned slice is nil.
+func ReadData(r io.Reader, n uint64) ([]byte, error) {
+	// No buffer of this size could ever be allocated, so the input
+	// cannot possibly hold that much data: report a short read.
+	if int64(n) < 0 || uint64(int(n)) != n {
+		return nil, io.ErrUnexpectedEOF
+	}
+
+	// Below the threshold a single allocation is acceptable.
+	if n < chunk {
+		data := make([]byte, n)
+		if _, err := io.ReadFull(r, data); err != nil {
+			return nil, err
+		}
+		return data, nil
+	}
+
+	// At or above the threshold, read through a fixed scratch buffer
+	// and append each filled piece to the result.
+	var data []byte
+	scratch := make([]byte, chunk)
+	for remaining := n; remaining > 0; {
+		want := remaining
+		if want > chunk {
+			want = chunk
+		}
+		piece := scratch[:want]
+		if _, err := io.ReadFull(r, piece); err != nil {
+			// io.ReadFull reports a plain io.EOF when this piece got
+			// no bytes; if earlier pieces were read, the input as a
+			// whole ended early.
+			if err == io.EOF && len(data) > 0 {
+				return nil, io.ErrUnexpectedEOF
+			}
+			return nil, err
+		}
+		data = append(data, piece...)
+		remaining -= want
+	}
+	return data, nil
+}
+
+// ReadDataAt reads n bytes from the input stream at off, but avoids
+// allocating all n bytes if n is large. This avoids crashing the program
+// by allocating all n bytes in cases where n is incorrect.
+//
+// A length that cannot be represented as a non-negative int is reported
+// as io.ErrUnexpectedEOF. An io.EOF returned by r.ReadAt together with a
+// completely filled buffer is treated as success, since io.ReaderAt
+// permits either io.EOF or nil when a read ends exactly at the end of
+// the input; this includes the zero-length read for which
+// io.SectionReader can return io.EOF. Any other error from r.ReadAt is
+// returned unchanged, and the returned slice is nil.
+func ReadDataAt(r io.ReaderAt, n uint64, off int64) ([]byte, error) {
+	if int64(n) < 0 || n != uint64(int(n)) {
+		// n is too large to fit in int, so we can't allocate
+		// a buffer large enough. Treat this as a read failure.
+		return nil, io.ErrUnexpectedEOF
+	}
+
+	// Below the threshold a single allocation is acceptable.
+	if n < chunk {
+		buf := make([]byte, n)
+		if err := readFullAt(r, buf, off); err != nil {
+			return nil, err
+		}
+		return buf, nil
+	}
+
+	// At or above the threshold, read through a fixed scratch buffer
+	// and grow the result only as data actually arrives.
+	var buf []byte
+	buf1 := make([]byte, chunk)
+	for n > 0 {
+		next := n
+		if next > chunk {
+			next = chunk
+		}
+		if err := readFullAt(r, buf1[:next], off); err != nil {
+			return nil, err
+		}
+		buf = append(buf, buf1[:next]...)
+		n -= next
+		off += int64(next)
+	}
+	return buf, nil
+}
+
+// readFullAt fills p from r starting at off. It returns nil when p was
+// filled completely, even if r.ReadAt also reported io.EOF; otherwise it
+// returns the error from r.ReadAt unchanged.
+func readFullAt(r io.ReaderAt, p []byte, off int64) error {
+	got, err := r.ReadAt(p, off)
+	if err == io.EOF && got == len(p) {
+		return nil
+	}
+	return err
+}
+
+// SliceCapWithSize returns the capacity to use when allocating a slice.
+// After the slice is allocated with the capacity, it should be
+// built using append. This will avoid allocating too much memory
+// if the capacity is large and incorrect.
+//
+// A negative result means that the value is always too big.
+func SliceCapWithSize(size, c uint64) int {
+ if int64(c) < 0 || c != uint64(int(c)) {
+ return -1
+ }
+ if size > 0 && c > (1<<64-1)/size {
+ return -1
+ }
+ if c*size > chunk {
+ c = chunk / size
+ if c == 0 {
+ c = 1
+ }
+ }
+ return int(c)
+}
+
+// SliceCap is the generic form of SliceCapWithSize: it returns the
+// capacity to use for a slice of c elements of type E, taking the
+// element size from unsafe.Sizeof applied to a zero value of E.
+func SliceCap[E any](c uint64) int {
+	var zero E
+	return SliceCapWithSize(uint64(unsafe.Sizeof(zero)), c)
+}
diff --git a/src/internal/saferio/io_test.go b/src/internal/saferio/io_test.go
new file mode 100644
index 0000000..696356f
--- /dev/null
+++ b/src/internal/saferio/io_test.go
@@ -0,0 +1,136 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package saferio
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+// TestReadData checks ReadData on a small successful read, on lengths
+// far larger than the available input, and on the io.EOF versus
+// io.ErrUnexpectedEOF distinction on both sides of the chunk threshold.
+func TestReadData(t *testing.T) {
+	const count = 100
+	input := bytes.Repeat([]byte{'a'}, count)
+
+	t.Run("small", func(t *testing.T) {
+		got, err := ReadData(bytes.NewReader(input), count)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(got, input) {
+			t.Errorf("got %v, want %v", got, input)
+		}
+	})
+
+	// Lengths far beyond the available input must fail.
+	tooLarge := []struct {
+		name string
+		n    uint64
+	}{
+		{"large", 10 << 30},
+		{"maxint", 1 << 62},
+	}
+	for _, tc := range tooLarge {
+		n := tc.n
+		t.Run(tc.name, func(t *testing.T) {
+			if _, err := ReadData(bytes.NewReader(input), n); err == nil {
+				t.Error("large read succeeded unexpectedly")
+			}
+		})
+	}
+
+	// An empty input yields a plain io.EOF on either side of the
+	// chunk threshold, because no bytes were read.
+	emptyInput := []struct {
+		name string
+		n    uint64
+	}{
+		{"small-EOF", chunk - 1},
+		{"large-EOF", chunk + 1},
+	}
+	for _, tc := range emptyInput {
+		n := tc.n
+		t.Run(tc.name, func(t *testing.T) {
+			if _, err := ReadData(bytes.NewReader(nil), n); err != io.EOF {
+				t.Errorf("ReadData = %v, want io.EOF", err)
+			}
+		})
+	}
+
+	// One full chunk is available but one more byte is requested.
+	t.Run("large-UnexpectedEOF", func(t *testing.T) {
+		src := bytes.NewReader(make([]byte, chunk))
+		if _, err := ReadData(src, chunk+1); err != io.ErrUnexpectedEOF {
+			t.Errorf("ReadData = %v, want io.ErrUnexpectedEOF", err)
+		}
+	})
+}
+
+// TestReadDataAt checks ReadDataAt on a small successful read, on
+// lengths far larger than the available input, and on a zero-length
+// read from an empty io.SectionReader.
+func TestReadDataAt(t *testing.T) {
+	const count = 100
+	input := bytes.Repeat([]byte{'a'}, count)
+
+	t.Run("small", func(t *testing.T) {
+		got, err := ReadDataAt(bytes.NewReader(input), count, 0)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if !bytes.Equal(got, input) {
+			t.Errorf("got %v, want %v", got, input)
+		}
+	})
+
+	// Lengths far beyond the available input must fail.
+	tooLarge := []struct {
+		name string
+		n    uint64
+	}{
+		{"large", 10 << 30},
+		{"maxint", 1 << 62},
+	}
+	for _, tc := range tooLarge {
+		n := tc.n
+		t.Run(tc.name, func(t *testing.T) {
+			if _, err := ReadDataAt(bytes.NewReader(input), n, 0); err == nil {
+				t.Error("large read succeeded unexpectedly")
+			}
+		})
+	}
+
+	t.Run("SectionReader", func(t *testing.T) {
+		// Reading 0 bytes from an io.SectionReader at the end
+		// of the section will return EOF, but ReadDataAt
+		// should succeed and return 0 bytes.
+		empty := io.NewSectionReader(bytes.NewReader(input), 0, 0)
+		got, err := ReadDataAt(empty, 0, 0)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if n := len(got); n > 0 {
+			t.Errorf("got %d bytes, expected 0", n)
+		}
+	})
+}
+
+func TestSliceCap(t *testing.T) {
+ t.Run("small", func(t *testing.T) {
+ c := SliceCap[int](10)
+ if c != 10 {
+ t.Errorf("got capacity %d, want %d", c, 10)
+ }
+ })
+
+ t.Run("large", func(t *testing.T) {
+ c := SliceCap[byte](1 << 30)
+ if c < 0 {
+ t.Error("SliceCap failed unexpectedly")
+ } else if c == 1<<30 {
+ t.Errorf("got capacity %d which is too high", c)
+ }
+ })
+
+ t.Run("maxint", func(t *testing.T) {
+ c := SliceCap[byte](1 << 63)
+ if c >= 0 {
+ t.Errorf("SliceCap returned %d, expected failure", c)
+ }
+ })
+
+ t.Run("overflow", func(t *testing.T) {
+ c := SliceCap[int64](1 << 62)
+ if c >= 0 {
+ t.Errorf("SliceCap returned %d, expected failure", c)
+ }
+ })
+}