diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:25:22 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:25:22 +0000 |
commit | f6ad4dcef54c5ce997a4bad5a6d86de229015700 (patch) | |
tree | 7cfa4e31ace5c2bd95c72b154d15af494b2bcbef /src/io | |
parent | Initial commit. (diff) | |
download | golang-1.22-f6ad4dcef54c5ce997a4bad5a6d86de229015700.tar.xz golang-1.22-f6ad4dcef54c5ce997a4bad5a6d86de229015700.zip |
Adding upstream version 1.22.1.upstream/1.22.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
31 files changed, 5088 insertions, 0 deletions
diff --git a/src/io/example_test.go b/src/io/example_test.go new file mode 100644 index 0000000..818020e --- /dev/null +++ b/src/io/example_test.go @@ -0,0 +1,284 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io_test + +import ( + "fmt" + "io" + "log" + "os" + "strings" +) + +func ExampleCopy() { + r := strings.NewReader("some io.Reader stream to be read\n") + + if _, err := io.Copy(os.Stdout, r); err != nil { + log.Fatal(err) + } + + // Output: + // some io.Reader stream to be read +} + +func ExampleCopyBuffer() { + r1 := strings.NewReader("first reader\n") + r2 := strings.NewReader("second reader\n") + buf := make([]byte, 8) + + // buf is used here... + if _, err := io.CopyBuffer(os.Stdout, r1, buf); err != nil { + log.Fatal(err) + } + + // ... reused here also. No need to allocate an extra buffer. + if _, err := io.CopyBuffer(os.Stdout, r2, buf); err != nil { + log.Fatal(err) + } + + // Output: + // first reader + // second reader +} + +func ExampleCopyN() { + r := strings.NewReader("some io.Reader stream to be read") + + if _, err := io.CopyN(os.Stdout, r, 4); err != nil { + log.Fatal(err) + } + + // Output: + // some +} + +func ExampleReadAtLeast() { + r := strings.NewReader("some io.Reader stream to be read\n") + + buf := make([]byte, 14) + if _, err := io.ReadAtLeast(r, buf, 4); err != nil { + log.Fatal(err) + } + fmt.Printf("%s\n", buf) + + // buffer smaller than minimal read size. + shortBuf := make([]byte, 3) + if _, err := io.ReadAtLeast(r, shortBuf, 4); err != nil { + fmt.Println("error:", err) + } + + // minimal read size bigger than io.Reader stream + longBuf := make([]byte, 64) + if _, err := io.ReadAtLeast(r, longBuf, 64); err != nil { + fmt.Println("error:", err) + } + + // Output: + // some io.Reader + // error: short buffer + // error: unexpected EOF +} + +func ExampleReadFull() { + r := strings.NewReader("some io.Reader stream to be read\n") + + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + log.Fatal(err) + } + fmt.Printf("%s\n", buf) + + // minimal read size bigger than io.Reader stream + longBuf := make([]byte, 64) + if _, err := io.ReadFull(r, longBuf); err != nil { + fmt.Println("error:", err) + } + + // Output: + // some + // error: unexpected EOF +} + +func ExampleWriteString() { + if _, err := io.WriteString(os.Stdout, "Hello World"); err != nil { + log.Fatal(err) + } + + // Output: Hello World +} + +func ExampleLimitReader() { + r := strings.NewReader("some io.Reader stream to be read\n") + lr := io.LimitReader(r, 4) + + if _, err := io.Copy(os.Stdout, lr); err != nil { + log.Fatal(err) + } + + // Output: + // some +} + +func ExampleMultiReader() { + r1 := strings.NewReader("first reader ") + r2 := strings.NewReader("second reader ") + r3 := strings.NewReader("third reader\n") + r := io.MultiReader(r1, r2, r3) + + if _, err := io.Copy(os.Stdout, r); err != nil { + log.Fatal(err) + } + + // Output: + // first reader second reader third reader +} + +func ExampleTeeReader() { + var r io.Reader = strings.NewReader("some io.Reader stream to be read\n") + + r = io.TeeReader(r, os.Stdout) + + // Everything read from r will be copied to stdout. + if _, err := io.ReadAll(r); err != nil { + log.Fatal(err) + } + + // Output: + // some io.Reader stream to be read +} + +func ExampleSectionReader() { + r := strings.NewReader("some io.Reader stream to be read\n") + s := io.NewSectionReader(r, 5, 17) + + if _, err := io.Copy(os.Stdout, s); err != nil { + log.Fatal(err) + } + + // Output: + // io.Reader stream +} + +func ExampleSectionReader_Read() { + r := strings.NewReader("some io.Reader stream to be read\n") + s := io.NewSectionReader(r, 5, 17) + + buf := make([]byte, 9) + if _, err := s.Read(buf); err != nil { + log.Fatal(err) + } + + fmt.Printf("%s\n", buf) + + // Output: + // io.Reader +} + +func ExampleSectionReader_ReadAt() { + r := strings.NewReader("some io.Reader stream to be read\n") + s := io.NewSectionReader(r, 5, 17) + + buf := make([]byte, 6) + if _, err := s.ReadAt(buf, 10); err != nil { + log.Fatal(err) + } + + fmt.Printf("%s\n", buf) + + // Output: + // stream +} + +func ExampleSectionReader_Seek() { + r := strings.NewReader("some io.Reader stream to be read\n") + s := io.NewSectionReader(r, 5, 17) + + if _, err := s.Seek(10, io.SeekStart); err != nil { + log.Fatal(err) + } + + if _, err := io.Copy(os.Stdout, s); err != nil { + log.Fatal(err) + } + + // Output: + // stream +} + +func ExampleSectionReader_Size() { + r := strings.NewReader("some io.Reader stream to be read\n") + s := io.NewSectionReader(r, 5, 17) + + fmt.Println(s.Size()) + + // Output: + // 17 +} + +func ExampleSeeker_Seek() { + r := strings.NewReader("some io.Reader stream to be read\n") + + r.Seek(5, io.SeekStart) // move to the 5th char from the start + if _, err := io.Copy(os.Stdout, r); err != nil { + log.Fatal(err) + } + + r.Seek(-5, io.SeekEnd) + if _, err := io.Copy(os.Stdout, r); err != nil { + log.Fatal(err) + } + + // Output: + // io.Reader stream to be read + // read +} + +func ExampleMultiWriter() { + r := strings.NewReader("some io.Reader stream to be read\n") + + var buf1, buf2 strings.Builder + w := io.MultiWriter(&buf1, &buf2) + + if _, err := io.Copy(w, r); err != nil { + log.Fatal(err) + } + + fmt.Print(buf1.String()) + fmt.Print(buf2.String()) + + // Output: + // some io.Reader stream to be read + // some io.Reader stream to be read +} + +func ExamplePipe() { + r, w := io.Pipe() + + go func() { + fmt.Fprint(w, "some io.Reader stream to be read\n") + w.Close() + }() + + if _, err := io.Copy(os.Stdout, r); err != nil { + log.Fatal(err) + } + + // Output: + // some io.Reader stream to be read +} + +func ExampleReadAll() { + r := strings.NewReader("Go is a general-purpose language designed with systems programming in mind.") + + b, err := io.ReadAll(r) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // Go is a general-purpose language designed with systems programming in mind. +} diff --git a/src/io/export_test.go b/src/io/export_test.go new file mode 100644 index 0000000..06853f9 --- /dev/null +++ b/src/io/export_test.go @@ -0,0 +1,10 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io + +// exported for test +var ErrInvalidWrite = errInvalidWrite +var ErrWhence = errWhence +var ErrOffset = errOffset diff --git a/src/io/fs/example_test.go b/src/io/fs/example_test.go new file mode 100644 index 0000000..c902703 --- /dev/null +++ b/src/io/fs/example_test.go @@ -0,0 +1,25 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + "fmt" + "io/fs" + "log" + "os" +) + +func ExampleWalkDir() { + root := "/usr/local/go/bin" + fileSystem := os.DirFS(root) + + fs.WalkDir(fileSystem, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + log.Fatal(err) + } + fmt.Println(path) + return nil + }) +} diff --git a/src/io/fs/format.go b/src/io/fs/format.go new file mode 100644 index 0000000..60b40df --- /dev/null +++ b/src/io/fs/format.go @@ -0,0 +1,76 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import ( + "time" +) + +// FormatFileInfo returns a formatted version of info for human readability. +// Implementations of [FileInfo] can call this from a String method. +// The output for a file named "hello.go", 100 bytes, mode 0o644, created +// January 1, 1970 at noon is +// +// -rw-r--r-- 100 1970-01-01 12:00:00 hello.go +func FormatFileInfo(info FileInfo) string { + name := info.Name() + b := make([]byte, 0, 40+len(name)) + b = append(b, info.Mode().String()...) + b = append(b, ' ') + + size := info.Size() + var usize uint64 + if size >= 0 { + usize = uint64(size) + } else { + b = append(b, '-') + usize = uint64(-size) + } + var buf [20]byte + i := len(buf) - 1 + for usize >= 10 { + q := usize / 10 + buf[i] = byte('0' + usize - q*10) + i-- + usize = q + } + buf[i] = byte('0' + usize) + b = append(b, buf[i:]...) + b = append(b, ' ') + + b = append(b, info.ModTime().Format(time.DateTime)...) + b = append(b, ' ') + + b = append(b, name...) + if info.IsDir() { + b = append(b, '/') + } + + return string(b) +} + +// FormatDirEntry returns a formatted version of dir for human readability. +// Implementations of [DirEntry] can call this from a String method. +// The outputs for a directory named subdir and a file named hello.go are: +// +// d subdir/ +// - hello.go +func FormatDirEntry(dir DirEntry) string { + name := dir.Name() + b := make([]byte, 0, 5+len(name)) + + // The Type method does not return any permission bits, + // so strip them from the string. + mode := dir.Type().String() + mode = mode[:len(mode)-9] + + b = append(b, mode...) + b = append(b, ' ') + b = append(b, name...) + if dir.IsDir() { + b = append(b, '/') + } + return string(b) +} diff --git a/src/io/fs/format_test.go b/src/io/fs/format_test.go new file mode 100644 index 0000000..a5f5066 --- /dev/null +++ b/src/io/fs/format_test.go @@ -0,0 +1,123 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "testing" + "time" +) + +// formatTest implements FileInfo to test FormatFileInfo, +// and implements DirEntry to test FormatDirEntry. +type formatTest struct { + name string + size int64 + mode FileMode + modTime time.Time + isDir bool +} + +func (fs *formatTest) Name() string { + return fs.name +} + +func (fs *formatTest) Size() int64 { + return fs.size +} + +func (fs *formatTest) Mode() FileMode { + return fs.mode +} + +func (fs *formatTest) ModTime() time.Time { + return fs.modTime +} + +func (fs *formatTest) IsDir() bool { + return fs.isDir +} + +func (fs *formatTest) Sys() any { + return nil +} + +func (fs *formatTest) Type() FileMode { + return fs.mode.Type() +} + +func (fs *formatTest) Info() (FileInfo, error) { + return fs, nil +} + +var formatTests = []struct { + input formatTest + wantFileInfo string + wantDirEntry string +}{ + { + formatTest{ + name: "hello.go", + size: 100, + mode: 0o644, + modTime: time.Date(1970, time.January, 1, 12, 0, 0, 0, time.UTC), + isDir: false, + }, + "-rw-r--r-- 100 1970-01-01 12:00:00 hello.go", + "- hello.go", + }, + { + formatTest{ + name: "home/gopher", + size: 0, + mode: ModeDir | 0o755, + modTime: time.Date(1970, time.January, 1, 12, 0, 0, 0, time.UTC), + isDir: true, + }, + "drwxr-xr-x 0 1970-01-01 12:00:00 home/gopher/", + "d home/gopher/", + }, + { + formatTest{ + name: "big", + size: 0x7fffffffffffffff, + mode: ModeIrregular | 0o644, + modTime: time.Date(1970, time.January, 1, 12, 0, 0, 0, time.UTC), + isDir: false, + }, + "?rw-r--r-- 9223372036854775807 1970-01-01 12:00:00 big", + "? big", + }, + { + formatTest{ + name: "small", + size: -0x8000000000000000, + mode: ModeSocket | ModeSetuid | 0o644, + modTime: time.Date(1970, time.January, 1, 12, 0, 0, 0, time.UTC), + isDir: false, + }, + "Surw-r--r-- -9223372036854775808 1970-01-01 12:00:00 small", + "S small", + }, +} + +func TestFormatFileInfo(t *testing.T) { + for i, test := range formatTests { + got := FormatFileInfo(&test.input) + if got != test.wantFileInfo { + t.Errorf("%d: FormatFileInfo(%#v) = %q, want %q", i, test.input, got, test.wantFileInfo) + } + } +} + +func TestFormatDirEntry(t *testing.T) { + for i, test := range formatTests { + got := FormatDirEntry(&test.input) + if got != test.wantDirEntry { + t.Errorf("%d: FormatDirEntry(%#v) = %q, want %q", i, test.input, got, test.wantDirEntry) + } + } + +} diff --git a/src/io/fs/fs.go b/src/io/fs/fs.go new file mode 100644 index 0000000..6891d75 --- /dev/null +++ b/src/io/fs/fs.go @@ -0,0 +1,264 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fs defines basic interfaces to a file system. +// A file system can be provided by the host operating system +// but also by other packages. +// +// See the [testing/fstest] package for support with testing +// implementations of file systems. +package fs + +import ( + "internal/oserror" + "time" + "unicode/utf8" +) + +// An FS provides access to a hierarchical file system. +// +// The FS interface is the minimum implementation required of the file system. +// A file system may implement additional interfaces, +// such as [ReadFileFS], to provide additional or optimized functionality. +// +// [testing/fstest.TestFS] may be used to test implementations of an FS for +// correctness. +type FS interface { + // Open opens the named file. + // + // When Open returns an error, it should be of type *PathError + // with the Op field set to "open", the Path field set to name, + // and the Err field describing the problem. + // + // Open should reject attempts to open names that do not satisfy + // ValidPath(name), returning a *PathError with Err set to + // ErrInvalid or ErrNotExist. + Open(name string) (File, error) +} + +// ValidPath reports whether the given path name +// is valid for use in a call to Open. +// +// Path names passed to open are UTF-8-encoded, +// unrooted, slash-separated sequences of path elements, like “x/y/z”. +// Path names must not contain an element that is “.” or “..” or the empty string, +// except for the special case that the root directory is named “.”. +// Paths must not start or end with a slash: “/x” and “x/” are invalid. +// +// Note that paths are slash-separated on all systems, even Windows. +// Paths containing other characters such as backslash and colon +// are accepted as valid, but those characters must never be +// interpreted by an [FS] implementation as path element separators. +func ValidPath(name string) bool { + if !utf8.ValidString(name) { + return false + } + + if name == "." { + // special case + return true + } + + // Iterate over elements in name, checking each. + for { + i := 0 + for i < len(name) && name[i] != '/' { + i++ + } + elem := name[:i] + if elem == "" || elem == "." || elem == ".." { + return false + } + if i == len(name) { + return true // reached clean ending + } + name = name[i+1:] + } +} + +// A File provides access to a single file. +// The File interface is the minimum implementation required of the file. +// Directory files should also implement [ReadDirFile]. +// A file may implement [io.ReaderAt] or [io.Seeker] as optimizations. +type File interface { + Stat() (FileInfo, error) + Read([]byte) (int, error) + Close() error +} + +// A DirEntry is an entry read from a directory +// (using the [ReadDir] function or a [ReadDirFile]'s ReadDir method). +type DirEntry interface { + // Name returns the name of the file (or subdirectory) described by the entry. + // This name is only the final element of the path (the base name), not the entire path. + // For example, Name would return "hello.go" not "home/gopher/hello.go". + Name() string + + // IsDir reports whether the entry describes a directory. + IsDir() bool + + // Type returns the type bits for the entry. + // The type bits are a subset of the usual FileMode bits, those returned by the FileMode.Type method. + Type() FileMode + + // Info returns the FileInfo for the file or subdirectory described by the entry. + // The returned FileInfo may be from the time of the original directory read + // or from the time of the call to Info. If the file has been removed or renamed + // since the directory read, Info may return an error satisfying errors.Is(err, ErrNotExist). + // If the entry denotes a symbolic link, Info reports the information about the link itself, + // not the link's target. + Info() (FileInfo, error) +} + +// A ReadDirFile is a directory file whose entries can be read with the ReadDir method. +// Every directory file should implement this interface. +// (It is permissible for any file to implement this interface, +// but if so ReadDir should return an error for non-directories.) +type ReadDirFile interface { + File + + // ReadDir reads the contents of the directory and returns + // a slice of up to n DirEntry values in directory order. + // Subsequent calls on the same file will yield further DirEntry values. + // + // If n > 0, ReadDir returns at most n DirEntry structures. + // In this case, if ReadDir returns an empty slice, it will return + // a non-nil error explaining why. + // At the end of a directory, the error is io.EOF. + // (ReadDir must return io.EOF itself, not an error wrapping io.EOF.) + // + // If n <= 0, ReadDir returns all the DirEntry values from the directory + // in a single slice. In this case, if ReadDir succeeds (reads all the way + // to the end of the directory), it returns the slice and a nil error. + // If it encounters an error before the end of the directory, + // ReadDir returns the DirEntry list read until that point and a non-nil error. + ReadDir(n int) ([]DirEntry, error) +} + +// Generic file system errors. +// Errors returned by file systems can be tested against these errors +// using [errors.Is]. +var ( + ErrInvalid = errInvalid() // "invalid argument" + ErrPermission = errPermission() // "permission denied" + ErrExist = errExist() // "file already exists" + ErrNotExist = errNotExist() // "file does not exist" + ErrClosed = errClosed() // "file already closed" +) + +func errInvalid() error { return oserror.ErrInvalid } +func errPermission() error { return oserror.ErrPermission } +func errExist() error { return oserror.ErrExist } +func errNotExist() error { return oserror.ErrNotExist } +func errClosed() error { return oserror.ErrClosed } + +// A FileInfo describes a file and is returned by [Stat]. +type FileInfo interface { + Name() string // base name of the file + Size() int64 // length in bytes for regular files; system-dependent for others + Mode() FileMode // file mode bits + ModTime() time.Time // modification time + IsDir() bool // abbreviation for Mode().IsDir() + Sys() any // underlying data source (can return nil) +} + +// A FileMode represents a file's mode and permission bits. +// The bits have the same definition on all systems, so that +// information about files can be moved from one system +// to another portably. Not all bits apply to all systems. +// The only required bit is [ModeDir] for directories. +type FileMode uint32 + +// The defined file mode bits are the most significant bits of the [FileMode]. +// The nine least-significant bits are the standard Unix rwxrwxrwx permissions. +// The values of these bits should be considered part of the public API and +// may be used in wire protocols or disk representations: they must not be +// changed, although new bits might be added. +const ( + // The single letters are the abbreviations + // used by the String method's formatting. + ModeDir FileMode = 1 << (32 - 1 - iota) // d: is a directory + ModeAppend // a: append-only + ModeExclusive // l: exclusive use + ModeTemporary // T: temporary file; Plan 9 only + ModeSymlink // L: symbolic link + ModeDevice // D: device file + ModeNamedPipe // p: named pipe (FIFO) + ModeSocket // S: Unix domain socket + ModeSetuid // u: setuid + ModeSetgid // g: setgid + ModeCharDevice // c: Unix character device, when ModeDevice is set + ModeSticky // t: sticky + ModeIrregular // ?: non-regular file; nothing else is known about this file + + // Mask for the type bits. For regular files, none will be set. + ModeType = ModeDir | ModeSymlink | ModeNamedPipe | ModeSocket | ModeDevice | ModeCharDevice | ModeIrregular + + ModePerm FileMode = 0777 // Unix permission bits +) + +func (m FileMode) String() string { + const str = "dalTLDpSugct?" + var buf [32]byte // Mode is uint32. + w := 0 + for i, c := range str { + if m&(1<<uint(32-1-i)) != 0 { + buf[w] = byte(c) + w++ + } + } + if w == 0 { + buf[w] = '-' + w++ + } + const rwx = "rwxrwxrwx" + for i, c := range rwx { + if m&(1<<uint(9-1-i)) != 0 { + buf[w] = byte(c) + } else { + buf[w] = '-' + } + w++ + } + return string(buf[:w]) +} + +// IsDir reports whether m describes a directory. +// That is, it tests for the [ModeDir] bit being set in m. +func (m FileMode) IsDir() bool { + return m&ModeDir != 0 +} + +// IsRegular reports whether m describes a regular file. +// That is, it tests that no mode type bits are set. +func (m FileMode) IsRegular() bool { + return m&ModeType == 0 +} + +// Perm returns the Unix permission bits in m (m & [ModePerm]). +func (m FileMode) Perm() FileMode { + return m & ModePerm +} + +// Type returns type bits in m (m & [ModeType]). +func (m FileMode) Type() FileMode { + return m & ModeType +} + +// PathError records an error and the operation and file path that caused it. +type PathError struct { + Op string + Path string + Err error +} + +func (e *PathError) Error() string { return e.Op + " " + e.Path + ": " + e.Err.Error() } + +func (e *PathError) Unwrap() error { return e.Err } + +// Timeout reports whether this error represents a timeout. +func (e *PathError) Timeout() bool { + t, ok := e.Err.(interface{ Timeout() bool }) + return ok && t.Timeout() +} diff --git a/src/io/fs/fs_test.go b/src/io/fs/fs_test.go new file mode 100644 index 0000000..aae1a76 --- /dev/null +++ b/src/io/fs/fs_test.go @@ -0,0 +1,49 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "testing" +) + +var isValidPathTests = []struct { + name string + ok bool +}{ + {".", true}, + {"x", true}, + {"x/y", true}, + + {"", false}, + {"..", false}, + {"/", false}, + {"x/", false}, + {"/x", false}, + {"x/y/", false}, + {"/x/y", false}, + {"./", false}, + {"./x", false}, + {"x/.", false}, + {"x/./y", false}, + {"../", false}, + {"../x", false}, + {"x/..", false}, + {"x/../y", false}, + {"x//y", false}, + {`x\`, true}, + {`x\y`, true}, + {`x:y`, true}, + {`\x`, true}, +} + +func TestValidPath(t *testing.T) { + for _, tt := range isValidPathTests { + ok := ValidPath(tt.name) + if ok != tt.ok { + t.Errorf("ValidPath(%q) = %v, want %v", tt.name, ok, tt.ok) + } + } +} diff --git a/src/io/fs/glob.go b/src/io/fs/glob.go new file mode 100644 index 0000000..db17156 --- /dev/null +++ b/src/io/fs/glob.go @@ -0,0 +1,129 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import ( + "path" +) + +// A GlobFS is a file system with a Glob method. +type GlobFS interface { + FS + + // Glob returns the names of all files matching pattern, + // providing an implementation of the top-level + // Glob function. + Glob(pattern string) ([]string, error) +} + +// Glob returns the names of all files matching pattern or nil +// if there is no matching file. The syntax of patterns is the same +// as in [path.Match]. The pattern may describe hierarchical names such as +// usr/*/bin/ed. +// +// Glob ignores file system errors such as I/O errors reading directories. +// The only possible returned error is [path.ErrBadPattern], reporting that +// the pattern is malformed. +// +// If fs implements [GlobFS], Glob calls fs.Glob. +// Otherwise, Glob uses [ReadDir] to traverse the directory tree +// and look for matches for the pattern. +func Glob(fsys FS, pattern string) (matches []string, err error) { + return globWithLimit(fsys, pattern, 0) +} + +func globWithLimit(fsys FS, pattern string, depth int) (matches []string, err error) { + // This limit is added to prevent stack exhaustion issues. See + // CVE-2022-30630. + const pathSeparatorsLimit = 10000 + if depth > pathSeparatorsLimit { + return nil, path.ErrBadPattern + } + if fsys, ok := fsys.(GlobFS); ok { + return fsys.Glob(pattern) + } + + // Check pattern is well-formed. + if _, err := path.Match(pattern, ""); err != nil { + return nil, err + } + if !hasMeta(pattern) { + if _, err = Stat(fsys, pattern); err != nil { + return nil, nil + } + return []string{pattern}, nil + } + + dir, file := path.Split(pattern) + dir = cleanGlobPath(dir) + + if !hasMeta(dir) { + return glob(fsys, dir, file, nil) + } + + // Prevent infinite recursion. See issue 15879. + if dir == pattern { + return nil, path.ErrBadPattern + } + + var m []string + m, err = globWithLimit(fsys, dir, depth+1) + if err != nil { + return nil, err + } + for _, d := range m { + matches, err = glob(fsys, d, file, matches) + if err != nil { + return + } + } + return +} + +// cleanGlobPath prepares path for glob matching. +func cleanGlobPath(path string) string { + switch path { + case "": + return "." + default: + return path[0 : len(path)-1] // chop off trailing separator + } +} + +// glob searches for files matching pattern in the directory dir +// and appends them to matches, returning the updated slice. +// If the directory cannot be opened, glob returns the existing matches. +// New matches are added in lexicographical order. +func glob(fs FS, dir, pattern string, matches []string) (m []string, e error) { + m = matches + infos, err := ReadDir(fs, dir) + if err != nil { + return // ignore I/O error + } + + for _, info := range infos { + n := info.Name() + matched, err := path.Match(pattern, n) + if err != nil { + return m, err + } + if matched { + m = append(m, path.Join(dir, n)) + } + } + return +} + +// hasMeta reports whether path contains any of the magic characters +// recognized by path.Match. +func hasMeta(path string) bool { + for i := 0; i < len(path); i++ { + switch path[i] { + case '*', '?', '[', '\\': + return true + } + } + return false +} diff --git a/src/io/fs/glob_test.go b/src/io/fs/glob_test.go new file mode 100644 index 0000000..d052eab --- /dev/null +++ b/src/io/fs/glob_test.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "os" + "path" + "strings" + "testing" +) + +var globTests = []struct { + fs FS + pattern, result string +}{ + {os.DirFS("."), "glob.go", "glob.go"}, + {os.DirFS("."), "gl?b.go", "glob.go"}, + {os.DirFS("."), `gl\ob.go`, "glob.go"}, + {os.DirFS("."), "*", "glob.go"}, + {os.DirFS(".."), "*/glob.go", "fs/glob.go"}, +} + +func TestGlob(t *testing.T) { + for _, tt := range globTests { + matches, err := Glob(tt.fs, tt.pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", tt.pattern, err) + continue + } + if !contains(matches, tt.result) { + t.Errorf("Glob(%#q) = %#v want %v", tt.pattern, matches, tt.result) + } + } + for _, pattern := range []string{"no_match", "../*/no_match", `\*`} { + matches, err := Glob(os.DirFS("."), pattern) + if err != nil { + t.Errorf("Glob error for %q: %s", pattern, err) + continue + } + if len(matches) != 0 { + t.Errorf("Glob(%#q) = %#v want []", pattern, matches) + } + } +} + +func TestGlobError(t *testing.T) { + bad := []string{`[]`, `nonexist/[]`} + for _, pattern := range bad { + _, err := Glob(os.DirFS("."), pattern) + if err != path.ErrBadPattern { + t.Errorf("Glob(fs, %#q) returned err=%v, want path.ErrBadPattern", pattern, err) + } + } +} + +func TestCVE202230630(t *testing.T) { + // Prior to CVE-2022-30630, a stack exhaustion would occur given a large + // number of separators. There is now a limit of 10,000. + _, err := Glob(os.DirFS("."), "/*"+strings.Repeat("/", 10001)) + if err != path.ErrBadPattern { + t.Fatalf("Glob returned err=%v, want %v", err, path.ErrBadPattern) + } +} + +// contains reports whether vector contains the string s. +func contains(vector []string, s string) bool { + for _, elem := range vector { + if elem == s { + return true + } + } + return false +} + +type globOnly struct{ GlobFS } + +func (globOnly) Open(name string) (File, error) { return nil, ErrNotExist } + +func TestGlobMethod(t *testing.T) { + check := func(desc string, names []string, err error) { + t.Helper() + if err != nil || len(names) != 1 || names[0] != "hello.txt" { + t.Errorf("Glob(%s) = %v, %v, want %v, nil", desc, names, err, []string{"hello.txt"}) + } + } + + // Test that ReadDir uses the method when present. + names, err := Glob(globOnly{testFsys}, "*.txt") + check("readDirOnly", names, err) + + // Test that ReadDir uses Open when the method is not present. + names, err = Glob(openOnly{testFsys}, "*.txt") + check("openOnly", names, err) +} diff --git a/src/io/fs/readdir.go b/src/io/fs/readdir.go new file mode 100644 index 0000000..22ced48 --- /dev/null +++ b/src/io/fs/readdir.go @@ -0,0 +1,81 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import ( + "errors" + "sort" +) + +// ReadDirFS is the interface implemented by a file system +// that provides an optimized implementation of [ReadDir]. +type ReadDirFS interface { + FS + + // ReadDir reads the named directory + // and returns a list of directory entries sorted by filename. + ReadDir(name string) ([]DirEntry, error) +} + +// ReadDir reads the named directory +// and returns a list of directory entries sorted by filename. +// +// If fs implements [ReadDirFS], ReadDir calls fs.ReadDir. +// Otherwise ReadDir calls fs.Open and uses ReadDir and Close +// on the returned file. +func ReadDir(fsys FS, name string) ([]DirEntry, error) { + if fsys, ok := fsys.(ReadDirFS); ok { + return fsys.ReadDir(name) + } + + file, err := fsys.Open(name) + if err != nil { + return nil, err + } + defer file.Close() + + dir, ok := file.(ReadDirFile) + if !ok { + return nil, &PathError{Op: "readdir", Path: name, Err: errors.New("not implemented")} + } + + list, err := dir.ReadDir(-1) + sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() }) + return list, err +} + +// dirInfo is a DirEntry based on a FileInfo. +type dirInfo struct { + fileInfo FileInfo +} + +func (di dirInfo) IsDir() bool { + return di.fileInfo.IsDir() +} + +func (di dirInfo) Type() FileMode { + return di.fileInfo.Mode().Type() +} + +func (di dirInfo) Info() (FileInfo, error) { + return di.fileInfo, nil +} + +func (di dirInfo) Name() string { + return di.fileInfo.Name() +} + +func (di dirInfo) String() string { + return FormatDirEntry(di) +} + +// FileInfoToDirEntry returns a [DirEntry] that returns information from info. +// If info is nil, FileInfoToDirEntry returns nil. +func FileInfoToDirEntry(info FileInfo) DirEntry { + if info == nil { + return nil + } + return dirInfo{fileInfo: info} +} diff --git a/src/io/fs/readdir_test.go b/src/io/fs/readdir_test.go new file mode 100644 index 0000000..4c409ae --- /dev/null +++ b/src/io/fs/readdir_test.go @@ -0,0 +1,111 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + "errors" + . "io/fs" + "os" + "testing" + "testing/fstest" + "time" +) + +type readDirOnly struct{ ReadDirFS } + +func (readDirOnly) Open(name string) (File, error) { return nil, ErrNotExist } + +func TestReadDir(t *testing.T) { + check := func(desc string, dirs []DirEntry, err error) { + t.Helper() + if err != nil || len(dirs) != 2 || dirs[0].Name() != "hello.txt" || dirs[1].Name() != "sub" { + var names []string + for _, d := range dirs { + names = append(names, d.Name()) + } + t.Errorf("ReadDir(%s) = %v, %v, want %v, nil", desc, names, err, []string{"hello.txt", "sub"}) + } + } + + // Test that ReadDir uses the method when present. + dirs, err := ReadDir(readDirOnly{testFsys}, ".") + check("readDirOnly", dirs, err) + + // Test that ReadDir uses Open when the method is not present. + dirs, err = ReadDir(openOnly{testFsys}, ".") + check("openOnly", dirs, err) + + // Test that ReadDir on Sub of . works (sub_test checks non-trivial subs). + sub, err := Sub(testFsys, ".") + if err != nil { + t.Fatal(err) + } + dirs, err = ReadDir(sub, ".") + check("sub(.)", dirs, err) +} + +func TestFileInfoToDirEntry(t *testing.T) { + testFs := fstest.MapFS{ + "notadir.txt": { + Data: []byte("hello, world"), + Mode: 0, + ModTime: time.Now(), + Sys: &sysValue, + }, + "adir": { + Data: nil, + Mode: os.ModeDir, + ModTime: time.Now(), + Sys: &sysValue, + }, + } + + tests := []struct { + path string + wantMode FileMode + wantDir bool + }{ + {path: "notadir.txt", wantMode: 0, wantDir: false}, + {path: "adir", wantMode: os.ModeDir, wantDir: true}, + } + + for _, test := range tests { + test := test + t.Run(test.path, func(t *testing.T) { + fi, err := Stat(testFs, test.path) + if err != nil { + t.Fatal(err) + } + + dirEntry := FileInfoToDirEntry(fi) + if g, w := dirEntry.Type(), test.wantMode; g != w { + t.Errorf("FileMode mismatch: got=%v, want=%v", g, w) + } + if g, w := dirEntry.Name(), test.path; g != w { + t.Errorf("Name mismatch: got=%v, want=%v", g, w) + } + if g, w := dirEntry.IsDir(), test.wantDir; g != w { + t.Errorf("IsDir mismatch: got=%v, want=%v", g, w) + } + }) + } +} + +func errorPath(err error) string { + var perr *PathError + if !errors.As(err, &perr) { + return "" + } + return perr.Path +} + +func TestReadDirPath(t *testing.T) { + fsys := os.DirFS(t.TempDir()) + _, err1 := ReadDir(fsys, "non-existent") + _, err2 := ReadDir(struct{ FS }{fsys}, "non-existent") + if s1, s2 := errorPath(err1), errorPath(err2); s1 != s2 { + t.Fatalf("s1: %s != s2: %s", s1, s2) + } +} diff --git a/src/io/fs/readfile.go b/src/io/fs/readfile.go new file mode 100644 index 0000000..41ca5bf --- /dev/null +++ b/src/io/fs/readfile.go @@ -0,0 +1,66 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import "io" + +// ReadFileFS is the interface implemented by a file system +// that provides an optimized implementation of [ReadFile]. +type ReadFileFS interface { + FS + + // ReadFile reads the named file and returns its contents. + // A successful call returns a nil error, not io.EOF. + // (Because ReadFile reads the whole file, the expected EOF + // from the final Read is not treated as an error to be reported.) + // + // The caller is permitted to modify the returned byte slice. + // This method should return a copy of the underlying data. + ReadFile(name string) ([]byte, error) +} + +// ReadFile reads the named file from the file system fs and returns its contents. +// A successful call returns a nil error, not [io.EOF]. +// (Because ReadFile reads the whole file, the expected EOF +// from the final Read is not treated as an error to be reported.) +// +// If fs implements [ReadFileFS], ReadFile calls fs.ReadFile. +// Otherwise ReadFile calls fs.Open and uses Read and Close +// on the returned [File]. +func ReadFile(fsys FS, name string) ([]byte, error) { + if fsys, ok := fsys.(ReadFileFS); ok { + return fsys.ReadFile(name) + } + + file, err := fsys.Open(name) + if err != nil { + return nil, err + } + defer file.Close() + + var size int + if info, err := file.Stat(); err == nil { + size64 := info.Size() + if int64(int(size64)) == size64 { + size = int(size64) + } + } + + data := make([]byte, 0, size+1) + for { + if len(data) >= cap(data) { + d := append(data[:cap(data)], 0) + data = d[:len(data)] + } + n, err := file.Read(data[len(data):cap(data)]) + data = data[:len(data)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return data, err + } + } +} diff --git a/src/io/fs/readfile_test.go b/src/io/fs/readfile_test.go new file mode 100644 index 0000000..3c521f6 --- /dev/null +++ b/src/io/fs/readfile_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "os" + "testing" + "testing/fstest" + "time" +) + +var testFsys = fstest.MapFS{ + "hello.txt": { + Data: []byte("hello, world"), + Mode: 0456, + ModTime: time.Now(), + Sys: &sysValue, + }, + "sub/goodbye.txt": { + Data: []byte("goodbye, world"), + Mode: 0456, + ModTime: time.Now(), + Sys: &sysValue, + }, +} + +var sysValue int + +type readFileOnly struct{ ReadFileFS } + +func (readFileOnly) Open(name string) (File, error) { return nil, ErrNotExist } + +type openOnly struct{ FS } + +func TestReadFile(t *testing.T) { + // Test that ReadFile uses the method when present. + data, err := ReadFile(readFileOnly{testFsys}, "hello.txt") + if string(data) != "hello, world" || err != nil { + t.Fatalf(`ReadFile(readFileOnly, "hello.txt") = %q, %v, want %q, nil`, data, err, "hello, world") + } + + // Test that ReadFile uses Open when the method is not present. + data, err = ReadFile(openOnly{testFsys}, "hello.txt") + if string(data) != "hello, world" || err != nil { + t.Fatalf(`ReadFile(openOnly, "hello.txt") = %q, %v, want %q, nil`, data, err, "hello, world") + } + + // Test that ReadFile on Sub of . works (sub_test checks non-trivial subs). + sub, err := Sub(testFsys, ".") + if err != nil { + t.Fatal(err) + } + data, err = ReadFile(sub, "hello.txt") + if string(data) != "hello, world" || err != nil { + t.Fatalf(`ReadFile(sub(.), "hello.txt") = %q, %v, want %q, nil`, data, err, "hello, world") + } +} + +func TestReadFilePath(t *testing.T) { + fsys := os.DirFS(t.TempDir()) + _, err1 := ReadFile(fsys, "non-existent") + _, err2 := ReadFile(struct{ FS }{fsys}, "non-existent") + if s1, s2 := errorPath(err1), errorPath(err2); s1 != s2 { + t.Fatalf("s1: %s != s2: %s", s1, s2) + } +} diff --git a/src/io/fs/stat.go b/src/io/fs/stat.go new file mode 100644 index 0000000..bbb91c2 --- /dev/null +++ b/src/io/fs/stat.go @@ -0,0 +1,31 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +// A StatFS is a file system with a Stat method. +type StatFS interface { + FS + + // Stat returns a FileInfo describing the file. + // If there is an error, it should be of type *PathError. + Stat(name string) (FileInfo, error) +} + +// Stat returns a [FileInfo] describing the named file from the file system. +// +// If fs implements [StatFS], Stat calls fs.Stat. +// Otherwise, Stat opens the [File] to stat it. +func Stat(fsys FS, name string) (FileInfo, error) { + if fsys, ok := fsys.(StatFS); ok { + return fsys.Stat(name) + } + + file, err := fsys.Open(name) + if err != nil { + return nil, err + } + defer file.Close() + return file.Stat() +} diff --git a/src/io/fs/stat_test.go b/src/io/fs/stat_test.go new file mode 100644 index 0000000..e312b6f --- /dev/null +++ b/src/io/fs/stat_test.go @@ -0,0 +1,36 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + "fmt" + . "io/fs" + "testing" +) + +type statOnly struct{ StatFS } + +func (statOnly) Open(name string) (File, error) { return nil, ErrNotExist } + +func TestStat(t *testing.T) { + check := func(desc string, info FileInfo, err error) { + t.Helper() + if err != nil || info == nil || info.Mode() != 0456 { + infoStr := "<nil>" + if info != nil { + infoStr = fmt.Sprintf("FileInfo(Mode: %#o)", info.Mode()) + } + t.Fatalf("Stat(%s) = %v, %v, want Mode:0456, nil", desc, infoStr, err) + } + } + + // Test that Stat uses the method when present. + info, err := Stat(statOnly{testFsys}, "hello.txt") + check("statOnly", info, err) + + // Test that Stat uses Open when the method is not present. + info, err = Stat(openOnly{testFsys}, "hello.txt") + check("openOnly", info, err) +} diff --git a/src/io/fs/sub.go b/src/io/fs/sub.go new file mode 100644 index 0000000..9999e63 --- /dev/null +++ b/src/io/fs/sub.go @@ -0,0 +1,138 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import ( + "errors" + "path" +) + +// A SubFS is a file system with a Sub method. +type SubFS interface { + FS + + // Sub returns an FS corresponding to the subtree rooted at dir. + Sub(dir string) (FS, error) +} + +// Sub returns an [FS] corresponding to the subtree rooted at fsys's dir. +// +// If dir is ".", Sub returns fsys unchanged. +// Otherwise, if fs implements [SubFS], Sub returns fsys.Sub(dir). +// Otherwise, Sub returns a new [FS] implementation sub that, +// in effect, implements sub.Open(name) as fsys.Open(path.Join(dir, name)). +// The implementation also translates calls to ReadDir, ReadFile, and Glob appropriately. +// +// Note that Sub(os.DirFS("/"), "prefix") is equivalent to os.DirFS("/prefix") +// and that neither of them guarantees to avoid operating system +// accesses outside "/prefix", because the implementation of [os.DirFS] +// does not check for symbolic links inside "/prefix" that point to +// other directories. That is, [os.DirFS] is not a general substitute for a +// chroot-style security mechanism, and Sub does not change that fact. +func Sub(fsys FS, dir string) (FS, error) { + if !ValidPath(dir) { + return nil, &PathError{Op: "sub", Path: dir, Err: errors.New("invalid name")} + } + if dir == "." { + return fsys, nil + } + if fsys, ok := fsys.(SubFS); ok { + return fsys.Sub(dir) + } + return &subFS{fsys, dir}, nil +} + +type subFS struct { + fsys FS + dir string +} + +// fullName maps name to the fully-qualified name dir/name. +func (f *subFS) fullName(op string, name string) (string, error) { + if !ValidPath(name) { + return "", &PathError{Op: op, Path: name, Err: errors.New("invalid name")} + } + return path.Join(f.dir, name), nil +} + +// shorten maps name, which should start with f.dir, back to the suffix after f.dir. +func (f *subFS) shorten(name string) (rel string, ok bool) { + if name == f.dir { + return ".", true + } + if len(name) >= len(f.dir)+2 && name[len(f.dir)] == '/' && name[:len(f.dir)] == f.dir { + return name[len(f.dir)+1:], true + } + return "", false +} + +// fixErr shortens any reported names in PathErrors by stripping f.dir. +func (f *subFS) fixErr(err error) error { + if e, ok := err.(*PathError); ok { + if short, ok := f.shorten(e.Path); ok { + e.Path = short + } + } + return err +} + +func (f *subFS) Open(name string) (File, error) { + full, err := f.fullName("open", name) + if err != nil { + return nil, err + } + file, err := f.fsys.Open(full) + return file, f.fixErr(err) +} + +func (f *subFS) ReadDir(name string) ([]DirEntry, error) { + full, err := f.fullName("read", name) + if err != nil { + return nil, err + } + dir, err := ReadDir(f.fsys, full) + return dir, f.fixErr(err) +} + +func (f *subFS) ReadFile(name string) ([]byte, error) { + full, err := f.fullName("read", name) + if err != nil { + return nil, err + } + data, err := ReadFile(f.fsys, full) + return data, f.fixErr(err) +} + +func (f *subFS) Glob(pattern string) ([]string, error) { + // Check pattern is well-formed. + if _, err := path.Match(pattern, ""); err != nil { + return nil, err + } + if pattern == "." { + return []string{"."}, nil + } + + full := f.dir + "/" + pattern + list, err := Glob(f.fsys, full) + for i, name := range list { + name, ok := f.shorten(name) + if !ok { + return nil, errors.New("invalid result from inner fsys Glob: " + name + " not in " + f.dir) // can't use fmt in this package + } + list[i] = name + } + return list, f.fixErr(err) +} + +func (f *subFS) Sub(dir string) (FS, error) { + if dir == "." { + return f, nil + } + full, err := f.fullName("sub", dir) + if err != nil { + return nil, err + } + return &subFS{f.fsys, full}, nil +} diff --git a/src/io/fs/sub_test.go b/src/io/fs/sub_test.go new file mode 100644 index 0000000..451b0ef --- /dev/null +++ b/src/io/fs/sub_test.go @@ -0,0 +1,57 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "testing" +) + +type subOnly struct{ SubFS } + +func (subOnly) Open(name string) (File, error) { return nil, ErrNotExist } + +func TestSub(t *testing.T) { + check := func(desc string, sub FS, err error) { + t.Helper() + if err != nil { + t.Errorf("Sub(sub): %v", err) + return + } + data, err := ReadFile(sub, "goodbye.txt") + if string(data) != "goodbye, world" || err != nil { + t.Errorf(`ReadFile(%s, "goodbye.txt" = %q, %v, want %q, nil`, desc, string(data), err, "goodbye, world") + } + + dirs, err := ReadDir(sub, ".") + if err != nil || len(dirs) != 1 || dirs[0].Name() != "goodbye.txt" { + var names []string + for _, d := range dirs { + names = append(names, d.Name()) + } + t.Errorf(`ReadDir(%s, ".") = %v, %v, want %v, nil`, desc, names, err, []string{"goodbye.txt"}) + } + } + + // Test that Sub uses the method when present. + sub, err := Sub(subOnly{testFsys}, "sub") + check("subOnly", sub, err) + + // Test that Sub uses Open when the method is not present. + sub, err = Sub(openOnly{testFsys}, "sub") + check("openOnly", sub, err) + + _, err = sub.Open("nonexist") + if err == nil { + t.Fatal("Open(nonexist): succeeded") + } + pe, ok := err.(*PathError) + if !ok { + t.Fatalf("Open(nonexist): error is %T, want *PathError", err) + } + if pe.Path != "nonexist" { + t.Fatalf("Open(nonexist): err.Path = %q, want %q", pe.Path, "nonexist") + } +} diff --git a/src/io/fs/walk.go b/src/io/fs/walk.go new file mode 100644 index 0000000..2e8a8db --- /dev/null +++ b/src/io/fs/walk.go @@ -0,0 +1,128 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs + +import ( + "errors" + "path" +) + +// SkipDir is used as a return value from [WalkDirFunc] to indicate that +// the directory named in the call is to be skipped. It is not returned +// as an error by any function. +var SkipDir = errors.New("skip this directory") + +// SkipAll is used as a return value from [WalkDirFunc] to indicate that +// all remaining files and directories are to be skipped. It is not returned +// as an error by any function. +var SkipAll = errors.New("skip everything and stop the walk") + +// WalkDirFunc is the type of the function called by [WalkDir] to visit +// each file or directory. +// +// The path argument contains the argument to [WalkDir] as a prefix. +// That is, if WalkDir is called with root argument "dir" and finds a file +// named "a" in that directory, the walk function will be called with +// argument "dir/a". +// +// The d argument is the [DirEntry] for the named path. +// +// The error result returned by the function controls how [WalkDir] +// continues. If the function returns the special value [SkipDir], WalkDir +// skips the current directory (path if d.IsDir() is true, otherwise +// path's parent directory). If the function returns the special value +// [SkipAll], WalkDir skips all remaining files and directories. Otherwise, +// if the function returns a non-nil error, WalkDir stops entirely and +// returns that error. +// +// The err argument reports an error related to path, signaling that +// [WalkDir] will not walk into that directory. The function can decide how +// to handle that error; as described earlier, returning the error will +// cause WalkDir to stop walking the entire tree. +// +// [WalkDir] calls the function with a non-nil err argument in two cases. +// +// First, if the initial [Stat] on the root directory fails, WalkDir +// calls the function with path set to root, d set to nil, and err set to +// the error from [fs.Stat]. +// +// Second, if a directory's ReadDir method (see [ReadDirFile]) fails, WalkDir calls the +// function with path set to the directory's path, d set to an +// [DirEntry] describing the directory, and err set to the error from +// ReadDir. In this second case, the function is called twice with the +// path of the directory: the first call is before the directory read is +// attempted and has err set to nil, giving the function a chance to +// return [SkipDir] or [SkipAll] and avoid the ReadDir entirely. The second call +// is after a failed ReadDir and reports the error from ReadDir. +// (If ReadDir succeeds, there is no second call.) +// +// The differences between WalkDirFunc compared to [path/filepath.WalkFunc] are: +// +// - The second argument has type [DirEntry] instead of [FileInfo]. +// - The function is called before reading a directory, to allow [SkipDir] +// or [SkipAll] to bypass the directory read entirely or skip all remaining +// files and directories respectively. +// - If a directory read fails, the function is called a second time +// for that directory to report the error. +type WalkDirFunc func(path string, d DirEntry, err error) error + +// walkDir recursively descends path, calling walkDirFn. +func walkDir(fsys FS, name string, d DirEntry, walkDirFn WalkDirFunc) error { + if err := walkDirFn(name, d, nil); err != nil || !d.IsDir() { + if err == SkipDir && d.IsDir() { + // Successfully skipped directory. + err = nil + } + return err + } + + dirs, err := ReadDir(fsys, name) + if err != nil { + // Second call, to report ReadDir error. + err = walkDirFn(name, d, err) + if err != nil { + if err == SkipDir && d.IsDir() { + err = nil + } + return err + } + } + + for _, d1 := range dirs { + name1 := path.Join(name, d1.Name()) + if err := walkDir(fsys, name1, d1, walkDirFn); err != nil { + if err == SkipDir { + break + } + return err + } + } + return nil +} + +// WalkDir walks the file tree rooted at root, calling fn for each file or +// directory in the tree, including root. +// +// All errors that arise visiting files and directories are filtered by fn: +// see the [fs.WalkDirFunc] documentation for details. +// +// The files are walked in lexical order, which makes the output deterministic +// but requires WalkDir to read an entire directory into memory before proceeding +// to walk that directory. +// +// WalkDir does not follow symbolic links found in directories, +// but if root itself is a symbolic link, its target will be walked. +func WalkDir(fsys FS, root string, fn WalkDirFunc) error { + info, err := Stat(fsys, root) + if err != nil { + err = fn(root, nil, err) + } else { + err = walkDir(fsys, root, FileInfoToDirEntry(info), fn) + } + if err == SkipDir || err == SkipAll { + return nil + } + return err +} diff --git a/src/io/fs/walk_test.go b/src/io/fs/walk_test.go new file mode 100644 index 0000000..40f4e1a --- /dev/null +++ b/src/io/fs/walk_test.go @@ -0,0 +1,151 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fs_test + +import ( + . "io/fs" + "os" + pathpkg "path" + "path/filepath" + "reflect" + "testing" + "testing/fstest" +) + +type Node struct { + name string + entries []*Node // nil if the entry is a file + mark int +} + +var tree = &Node{ + "testdata", + []*Node{ + {"a", nil, 0}, + {"b", []*Node{}, 0}, + {"c", nil, 0}, + { + "d", + []*Node{ + {"x", nil, 0}, + {"y", []*Node{}, 0}, + { + "z", + []*Node{ + {"u", nil, 0}, + {"v", nil, 0}, + }, + 0, + }, + }, + 0, + }, + }, + 0, +} + +func walkTree(n *Node, path string, f func(path string, n *Node)) { + f(path, n) + for _, e := range n.entries { + walkTree(e, pathpkg.Join(path, e.name), f) + } +} + +func makeTree() FS { + fsys := fstest.MapFS{} + walkTree(tree, tree.name, func(path string, n *Node) { + if n.entries == nil { + fsys[path] = &fstest.MapFile{} + } else { + fsys[path] = &fstest.MapFile{Mode: ModeDir} + } + }) + return fsys +} + +// Assumes that each node name is unique. Good enough for a test. +// If clear is true, any incoming error is cleared before return. The errors +// are always accumulated, though. +func mark(entry DirEntry, err error, errors *[]error, clear bool) error { + name := entry.Name() + walkTree(tree, tree.name, func(path string, n *Node) { + if n.name == name { + n.mark++ + } + }) + if err != nil { + *errors = append(*errors, err) + if clear { + return nil + } + return err + } + return nil +} + +func TestWalkDir(t *testing.T) { + tmpDir := t.TempDir() + + origDir, err := os.Getwd() + if err != nil { + t.Fatal("finding working dir:", err) + } + if err = os.Chdir(tmpDir); err != nil { + t.Fatal("entering temp dir:", err) + } + defer os.Chdir(origDir) + + fsys := makeTree() + errors := make([]error, 0, 10) + clear := true + markFn := func(path string, entry DirEntry, err error) error { + return mark(entry, err, &errors, clear) + } + // Expect no errors. + err = WalkDir(fsys, ".", markFn) + if err != nil { + t.Fatalf("no error expected, found: %s", err) + } + if len(errors) != 0 { + t.Fatalf("unexpected errors: %s", errors) + } + walkTree(tree, tree.name, func(path string, n *Node) { + if n.mark != 1 { + t.Errorf("node %s mark = %d; expected 1", path, n.mark) + } + n.mark = 0 + }) +} + +func TestIssue51617(t *testing.T) { + dir := t.TempDir() + for _, sub := range []string{"a", filepath.Join("a", "bad"), filepath.Join("a", "next")} { + if err := os.Mkdir(filepath.Join(dir, sub), 0755); err != nil { + t.Fatal(err) + } + } + bad := filepath.Join(dir, "a", "bad") + if err := os.Chmod(bad, 0); err != nil { + t.Fatal(err) + } + defer os.Chmod(bad, 0700) // avoid errors on cleanup + var saw []string + err := WalkDir(os.DirFS(dir), ".", func(path string, d DirEntry, err error) error { + if err != nil { + return filepath.SkipDir + } + if d.IsDir() { + saw = append(saw, path) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + want := []string{".", "a", "a/bad", "a/next"} + if !reflect.DeepEqual(saw, want) { + t.Errorf("got directories %v, want %v", saw, want) + } +} diff --git a/src/io/io.go b/src/io/io.go new file mode 100644 index 0000000..7f16e18 --- /dev/null +++ b/src/io/io.go @@ -0,0 +1,726 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package io provides basic interfaces to I/O primitives. +// Its primary job is to wrap existing implementations of such primitives, +// such as those in package os, into shared public interfaces that +// abstract the functionality, plus some other related primitives. +// +// Because these interfaces and primitives wrap lower-level operations with +// various implementations, unless otherwise informed clients should not +// assume they are safe for parallel execution. +package io + +import ( + "errors" + "sync" +) + +// Seek whence values. +const ( + SeekStart = 0 // seek relative to the origin of the file + SeekCurrent = 1 // seek relative to the current offset + SeekEnd = 2 // seek relative to the end +) + +// ErrShortWrite means that a write accepted fewer bytes than requested +// but failed to return an explicit error. +var ErrShortWrite = errors.New("short write") + +// errInvalidWrite means that a write returned an impossible count. +var errInvalidWrite = errors.New("invalid write result") + +// ErrShortBuffer means that a read required a longer buffer than was provided. +var ErrShortBuffer = errors.New("short buffer") + +// EOF is the error returned by Read when no more input is available. +// (Read must return EOF itself, not an error wrapping EOF, +// because callers will test for EOF using ==.) +// Functions should return EOF only to signal a graceful end of input. +// If the EOF occurs unexpectedly in a structured data stream, +// the appropriate error is either [ErrUnexpectedEOF] or some other error +// giving more detail. +var EOF = errors.New("EOF") + +// ErrUnexpectedEOF means that EOF was encountered in the +// middle of reading a fixed-size block or data structure. +var ErrUnexpectedEOF = errors.New("unexpected EOF") + +// ErrNoProgress is returned by some clients of a [Reader] when +// many calls to Read have failed to return any data or error, +// usually the sign of a broken [Reader] implementation. +var ErrNoProgress = errors.New("multiple Read calls return no data or error") + +// Reader is the interface that wraps the basic Read method. +// +// Read reads up to len(p) bytes into p. It returns the number of bytes +// read (0 <= n <= len(p)) and any error encountered. Even if Read +// returns n < len(p), it may use all of p as scratch space during the call. +// If some data is available but not len(p) bytes, Read conventionally +// returns what is available instead of waiting for more. +// +// When Read encounters an error or end-of-file condition after +// successfully reading n > 0 bytes, it returns the number of +// bytes read. It may return the (non-nil) error from the same call +// or return the error (and n == 0) from a subsequent call. +// An instance of this general case is that a Reader returning +// a non-zero number of bytes at the end of the input stream may +// return either err == EOF or err == nil. The next Read should +// return 0, EOF. +// +// Callers should always process the n > 0 bytes returned before +// considering the error err. Doing so correctly handles I/O errors +// that happen after reading some bytes and also both of the +// allowed EOF behaviors. +// +// If len(p) == 0, Read should always return n == 0. It may return a +// non-nil error if some error condition is known, such as EOF. +// +// Implementations of Read are discouraged from returning a +// zero byte count with a nil error, except when len(p) == 0. +// Callers should treat a return of 0 and nil as indicating that +// nothing happened; in particular it does not indicate EOF. +// +// Implementations must not retain p. +type Reader interface { + Read(p []byte) (n int, err error) +} + +// Writer is the interface that wraps the basic Write method. +// +// Write writes len(p) bytes from p to the underlying data stream. +// It returns the number of bytes written from p (0 <= n <= len(p)) +// and any error encountered that caused the write to stop early. +// Write must return a non-nil error if it returns n < len(p). +// Write must not modify the slice data, even temporarily. +// +// Implementations must not retain p. +type Writer interface { + Write(p []byte) (n int, err error) +} + +// Closer is the interface that wraps the basic Close method. +// +// The behavior of Close after the first call is undefined. +// Specific implementations may document their own behavior. +type Closer interface { + Close() error +} + +// Seeker is the interface that wraps the basic Seek method. +// +// Seek sets the offset for the next Read or Write to offset, +// interpreted according to whence: +// [SeekStart] means relative to the start of the file, +// [SeekCurrent] means relative to the current offset, and +// [SeekEnd] means relative to the end +// (for example, offset = -2 specifies the penultimate byte of the file). +// Seek returns the new offset relative to the start of the +// file or an error, if any. +// +// Seeking to an offset before the start of the file is an error. +// Seeking to any positive offset may be allowed, but if the new offset exceeds +// the size of the underlying object the behavior of subsequent I/O operations +// is implementation-dependent. +type Seeker interface { + Seek(offset int64, whence int) (int64, error) +} + +// ReadWriter is the interface that groups the basic Read and Write methods. +type ReadWriter interface { + Reader + Writer +} + +// ReadCloser is the interface that groups the basic Read and Close methods. +type ReadCloser interface { + Reader + Closer +} + +// WriteCloser is the interface that groups the basic Write and Close methods. +type WriteCloser interface { + Writer + Closer +} + +// ReadWriteCloser is the interface that groups the basic Read, Write and Close methods. +type ReadWriteCloser interface { + Reader + Writer + Closer +} + +// ReadSeeker is the interface that groups the basic Read and Seek methods. +type ReadSeeker interface { + Reader + Seeker +} + +// ReadSeekCloser is the interface that groups the basic Read, Seek and Close +// methods. +type ReadSeekCloser interface { + Reader + Seeker + Closer +} + +// WriteSeeker is the interface that groups the basic Write and Seek methods. +type WriteSeeker interface { + Writer + Seeker +} + +// ReadWriteSeeker is the interface that groups the basic Read, Write and Seek methods. +type ReadWriteSeeker interface { + Reader + Writer + Seeker +} + +// ReaderFrom is the interface that wraps the ReadFrom method. +// +// ReadFrom reads data from r until EOF or error. +// The return value n is the number of bytes read. +// Any error except EOF encountered during the read is also returned. +// +// The [Copy] function uses [ReaderFrom] if available. +type ReaderFrom interface { + ReadFrom(r Reader) (n int64, err error) +} + +// WriterTo is the interface that wraps the WriteTo method. +// +// WriteTo writes data to w until there's no more data to write or +// when an error occurs. The return value n is the number of bytes +// written. Any error encountered during the write is also returned. +// +// The Copy function uses WriterTo if available. +type WriterTo interface { + WriteTo(w Writer) (n int64, err error) +} + +// ReaderAt is the interface that wraps the basic ReadAt method. +// +// ReadAt reads len(p) bytes into p starting at offset off in the +// underlying input source. It returns the number of bytes +// read (0 <= n <= len(p)) and any error encountered. +// +// When ReadAt returns n < len(p), it returns a non-nil error +// explaining why more bytes were not returned. In this respect, +// ReadAt is stricter than Read. +// +// Even if ReadAt returns n < len(p), it may use all of p as scratch +// space during the call. If some data is available but not len(p) bytes, +// ReadAt blocks until either all the data is available or an error occurs. +// In this respect ReadAt is different from Read. +// +// If the n = len(p) bytes returned by ReadAt are at the end of the +// input source, ReadAt may return either err == EOF or err == nil. +// +// If ReadAt is reading from an input source with a seek offset, +// ReadAt should not affect nor be affected by the underlying +// seek offset. +// +// Clients of ReadAt can execute parallel ReadAt calls on the +// same input source. +// +// Implementations must not retain p. +type ReaderAt interface { + ReadAt(p []byte, off int64) (n int, err error) +} + +// WriterAt is the interface that wraps the basic WriteAt method. +// +// WriteAt writes len(p) bytes from p to the underlying data stream +// at offset off. It returns the number of bytes written from p (0 <= n <= len(p)) +// and any error encountered that caused the write to stop early. +// WriteAt must return a non-nil error if it returns n < len(p). +// +// If WriteAt is writing to a destination with a seek offset, +// WriteAt should not affect nor be affected by the underlying +// seek offset. +// +// Clients of WriteAt can execute parallel WriteAt calls on the same +// destination if the ranges do not overlap. +// +// Implementations must not retain p. +type WriterAt interface { + WriteAt(p []byte, off int64) (n int, err error) +} + +// ByteReader is the interface that wraps the ReadByte method. +// +// ReadByte reads and returns the next byte from the input or +// any error encountered. If ReadByte returns an error, no input +// byte was consumed, and the returned byte value is undefined. +// +// ReadByte provides an efficient interface for byte-at-time +// processing. A [Reader] that does not implement ByteReader +// can be wrapped using bufio.NewReader to add this method. +type ByteReader interface { + ReadByte() (byte, error) +} + +// ByteScanner is the interface that adds the UnreadByte method to the +// basic ReadByte method. +// +// UnreadByte causes the next call to ReadByte to return the last byte read. +// If the last operation was not a successful call to ReadByte, UnreadByte may +// return an error, unread the last byte read (or the byte prior to the +// last-unread byte), or (in implementations that support the [Seeker] interface) +// seek to one byte before the current offset. +type ByteScanner interface { + ByteReader + UnreadByte() error +} + +// ByteWriter is the interface that wraps the WriteByte method. +type ByteWriter interface { + WriteByte(c byte) error +} + +// RuneReader is the interface that wraps the ReadRune method. +// +// ReadRune reads a single encoded Unicode character +// and returns the rune and its size in bytes. If no character is +// available, err will be set. +type RuneReader interface { + ReadRune() (r rune, size int, err error) +} + +// RuneScanner is the interface that adds the UnreadRune method to the +// basic ReadRune method. +// +// UnreadRune causes the next call to ReadRune to return the last rune read. +// If the last operation was not a successful call to ReadRune, UnreadRune may +// return an error, unread the last rune read (or the rune prior to the +// last-unread rune), or (in implementations that support the [Seeker] interface) +// seek to the start of the rune before the current offset. +type RuneScanner interface { + RuneReader + UnreadRune() error +} + +// StringWriter is the interface that wraps the WriteString method. +type StringWriter interface { + WriteString(s string) (n int, err error) +} + +// WriteString writes the contents of the string s to w, which accepts a slice of bytes. +// If w implements [StringWriter], [StringWriter.WriteString] is invoked directly. +// Otherwise, [Writer.Write] is called exactly once. +func WriteString(w Writer, s string) (n int, err error) { + if sw, ok := w.(StringWriter); ok { + return sw.WriteString(s) + } + return w.Write([]byte(s)) +} + +// ReadAtLeast reads from r into buf until it has read at least min bytes. +// It returns the number of bytes copied and an error if fewer bytes were read. +// The error is EOF only if no bytes were read. +// If an EOF happens after reading fewer than min bytes, +// ReadAtLeast returns [ErrUnexpectedEOF]. +// If min is greater than the length of buf, ReadAtLeast returns [ErrShortBuffer]. +// On return, n >= min if and only if err == nil. +// If r returns an error having read at least min bytes, the error is dropped. +func ReadAtLeast(r Reader, buf []byte, min int) (n int, err error) { + if len(buf) < min { + return 0, ErrShortBuffer + } + for n < min && err == nil { + var nn int + nn, err = r.Read(buf[n:]) + n += nn + } + if n >= min { + err = nil + } else if n > 0 && err == EOF { + err = ErrUnexpectedEOF + } + return +} + +// ReadFull reads exactly len(buf) bytes from r into buf. +// It returns the number of bytes copied and an error if fewer bytes were read. +// The error is EOF only if no bytes were read. +// If an EOF happens after reading some but not all the bytes, +// ReadFull returns [ErrUnexpectedEOF]. +// On return, n == len(buf) if and only if err == nil. +// If r returns an error having read at least len(buf) bytes, the error is dropped. +func ReadFull(r Reader, buf []byte) (n int, err error) { + return ReadAtLeast(r, buf, len(buf)) +} + +// CopyN copies n bytes (or until an error) from src to dst. +// It returns the number of bytes copied and the earliest +// error encountered while copying. +// On return, written == n if and only if err == nil. +// +// If dst implements [ReaderFrom], the copy is implemented using it. +func CopyN(dst Writer, src Reader, n int64) (written int64, err error) { + written, err = Copy(dst, LimitReader(src, n)) + if written == n { + return n, nil + } + if written < n && err == nil { + // src stopped early; must have been EOF. + err = EOF + } + return +} + +// Copy copies from src to dst until either EOF is reached +// on src or an error occurs. It returns the number of bytes +// copied and the first error encountered while copying, if any. +// +// A successful Copy returns err == nil, not err == EOF. +// Because Copy is defined to read from src until EOF, it does +// not treat an EOF from Read as an error to be reported. +// +// If src implements [WriterTo], +// the copy is implemented by calling src.WriteTo(dst). +// Otherwise, if dst implements [ReaderFrom], +// the copy is implemented by calling dst.ReadFrom(src). +func Copy(dst Writer, src Reader) (written int64, err error) { + return copyBuffer(dst, src, nil) +} + +// CopyBuffer is identical to Copy except that it stages through the +// provided buffer (if one is required) rather than allocating a +// temporary one. If buf is nil, one is allocated; otherwise if it has +// zero length, CopyBuffer panics. +// +// If either src implements [WriterTo] or dst implements [ReaderFrom], +// buf will not be used to perform the copy. +func CopyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) { + if buf != nil && len(buf) == 0 { + panic("empty buffer in CopyBuffer") + } + return copyBuffer(dst, src, buf) +} + +// copyBuffer is the actual implementation of Copy and CopyBuffer. +// if buf is nil, one is allocated. +func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(WriterTo); ok { + return wt.WriteTo(dst) + } + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(ReaderFrom); ok { + return rt.ReadFrom(src) + } + if buf == nil { + size := 32 * 1024 + if l, ok := src.(*LimitedReader); ok && int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + buf = make([]byte, size) + } + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errInvalidWrite + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = ErrShortWrite + break + } + } + if er != nil { + if er != EOF { + err = er + } + break + } + } + return written, err +} + +// LimitReader returns a Reader that reads from r +// but stops with EOF after n bytes. +// The underlying implementation is a *LimitedReader. +func LimitReader(r Reader, n int64) Reader { return &LimitedReader{r, n} } + +// A LimitedReader reads from R but limits the amount of +// data returned to just N bytes. Each call to Read +// updates N to reflect the new amount remaining. +// Read returns EOF when N <= 0 or when the underlying R returns EOF. +type LimitedReader struct { + R Reader // underlying reader + N int64 // max bytes remaining +} + +func (l *LimitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, EOF + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + +// NewSectionReader returns a [SectionReader] that reads from r +// starting at offset off and stops with EOF after n bytes. +func NewSectionReader(r ReaderAt, off int64, n int64) *SectionReader { + var remaining int64 + const maxint64 = 1<<63 - 1 + if off <= maxint64-n { + remaining = n + off + } else { + // Overflow, with no way to return error. + // Assume we can read up to an offset of 1<<63 - 1. + remaining = maxint64 + } + return &SectionReader{r, off, off, remaining, n} +} + +// SectionReader implements Read, Seek, and ReadAt on a section +// of an underlying [ReaderAt]. +type SectionReader struct { + r ReaderAt // constant after creation + base int64 // constant after creation + off int64 + limit int64 // constant after creation + n int64 // constant after creation +} + +func (s *SectionReader) Read(p []byte) (n int, err error) { + if s.off >= s.limit { + return 0, EOF + } + if max := s.limit - s.off; int64(len(p)) > max { + p = p[0:max] + } + n, err = s.r.ReadAt(p, s.off) + s.off += int64(n) + return +} + +var errWhence = errors.New("Seek: invalid whence") +var errOffset = errors.New("Seek: invalid offset") + +func (s *SectionReader) Seek(offset int64, whence int) (int64, error) { + switch whence { + default: + return 0, errWhence + case SeekStart: + offset += s.base + case SeekCurrent: + offset += s.off + case SeekEnd: + offset += s.limit + } + if offset < s.base { + return 0, errOffset + } + s.off = offset + return offset - s.base, nil +} + +func (s *SectionReader) ReadAt(p []byte, off int64) (n int, err error) { + if off < 0 || off >= s.Size() { + return 0, EOF + } + off += s.base + if max := s.limit - off; int64(len(p)) > max { + p = p[0:max] + n, err = s.r.ReadAt(p, off) + if err == nil { + err = EOF + } + return n, err + } + return s.r.ReadAt(p, off) +} + +// Size returns the size of the section in bytes. +func (s *SectionReader) Size() int64 { return s.limit - s.base } + +// Outer returns the underlying [ReaderAt] and offsets for the section. +// +// The returned values are the same that were passed to [NewSectionReader] +// when the [SectionReader] was created. +func (s *SectionReader) Outer() (r ReaderAt, off int64, n int64) { + return s.r, s.base, s.n +} + +// An OffsetWriter maps writes at offset base to offset base+off in the underlying writer. +type OffsetWriter struct { + w WriterAt + base int64 // the original offset + off int64 // the current offset +} + +// NewOffsetWriter returns an [OffsetWriter] that writes to w +// starting at offset off. +func NewOffsetWriter(w WriterAt, off int64) *OffsetWriter { + return &OffsetWriter{w, off, off} +} + +func (o *OffsetWriter) Write(p []byte) (n int, err error) { + n, err = o.w.WriteAt(p, o.off) + o.off += int64(n) + return +} + +func (o *OffsetWriter) WriteAt(p []byte, off int64) (n int, err error) { + if off < 0 { + return 0, errOffset + } + + off += o.base + return o.w.WriteAt(p, off) +} + +func (o *OffsetWriter) Seek(offset int64, whence int) (int64, error) { + switch whence { + default: + return 0, errWhence + case SeekStart: + offset += o.base + case SeekCurrent: + offset += o.off + } + if offset < o.base { + return 0, errOffset + } + o.off = offset + return offset - o.base, nil +} + +// TeeReader returns a [Reader] that writes to w what it reads from r. +// All reads from r performed through it are matched with +// corresponding writes to w. There is no internal buffering - +// the write must complete before the read completes. +// Any error encountered while writing is reported as a read error. +func TeeReader(r Reader, w Writer) Reader { + return &teeReader{r, w} +} + +type teeReader struct { + r Reader + w Writer +} + +func (t *teeReader) Read(p []byte) (n int, err error) { + n, err = t.r.Read(p) + if n > 0 { + if n, err := t.w.Write(p[:n]); err != nil { + return n, err + } + } + return +} + +// Discard is a [Writer] on which all Write calls succeed +// without doing anything. +var Discard Writer = discard{} + +type discard struct{} + +// discard implements ReaderFrom as an optimization so Copy to +// io.Discard can avoid doing unnecessary work. +var _ ReaderFrom = discard{} + +func (discard) Write(p []byte) (int, error) { + return len(p), nil +} + +func (discard) WriteString(s string) (int, error) { + return len(s), nil +} + +var blackHolePool = sync.Pool{ + New: func() any { + b := make([]byte, 8192) + return &b + }, +} + +func (discard) ReadFrom(r Reader) (n int64, err error) { + bufp := blackHolePool.Get().(*[]byte) + readSize := 0 + for { + readSize, err = r.Read(*bufp) + n += int64(readSize) + if err != nil { + blackHolePool.Put(bufp) + if err == EOF { + return n, nil + } + return + } + } +} + +// NopCloser returns a [ReadCloser] with a no-op Close method wrapping +// the provided [Reader] r. +// If r implements [WriterTo], the returned [ReadCloser] will implement [WriterTo] +// by forwarding calls to r. +func NopCloser(r Reader) ReadCloser { + if _, ok := r.(WriterTo); ok { + return nopCloserWriterTo{r} + } + return nopCloser{r} +} + +type nopCloser struct { + Reader +} + +func (nopCloser) Close() error { return nil } + +type nopCloserWriterTo struct { + Reader +} + +func (nopCloserWriterTo) Close() error { return nil } + +func (c nopCloserWriterTo) WriteTo(w Writer) (n int64, err error) { + return c.Reader.(WriterTo).WriteTo(w) +} + +// ReadAll reads from r until an error or EOF and returns the data it read. +// A successful call returns err == nil, not err == EOF. Because ReadAll is +// defined to read from src until EOF, it does not treat an EOF from Read +// as an error to be reported. +func ReadAll(r Reader) ([]byte, error) { + b := make([]byte, 0, 512) + for { + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == EOF { + err = nil + } + return b, err + } + + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + } +} diff --git a/src/io/io_test.go b/src/io/io_test.go new file mode 100644 index 0000000..9491ffa --- /dev/null +++ b/src/io/io_test.go @@ -0,0 +1,697 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io_test + +import ( + "bytes" + "errors" + "fmt" + . "io" + "os" + "strings" + "sync" + "sync/atomic" + "testing" +) + +// A version of bytes.Buffer without ReadFrom and WriteTo +type Buffer struct { + bytes.Buffer + ReaderFrom // conflicts with and hides bytes.Buffer's ReaderFrom. + WriterTo // conflicts with and hides bytes.Buffer's WriterTo. +} + +// Simple tests, primarily to verify the ReadFrom and WriteTo callouts inside Copy, CopyBuffer and CopyN. + +func TestCopy(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello, world.") + Copy(wb, rb) + if wb.String() != "hello, world." { + t.Errorf("Copy did not work properly") + } +} + +func TestCopyNegative(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello") + Copy(wb, &LimitedReader{R: rb, N: -1}) + if wb.String() != "" { + t.Errorf("Copy on LimitedReader with N<0 copied data") + } + + CopyN(wb, rb, -1) + if wb.String() != "" { + t.Errorf("CopyN with N<0 copied data") + } +} + +func TestCopyBuffer(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello, world.") + CopyBuffer(wb, rb, make([]byte, 1)) // Tiny buffer to keep it honest. + if wb.String() != "hello, world." { + t.Errorf("CopyBuffer did not work properly") + } +} + +func TestCopyBufferNil(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello, world.") + CopyBuffer(wb, rb, nil) // Should allocate a buffer. + if wb.String() != "hello, world." { + t.Errorf("CopyBuffer did not work properly") + } +} + +func TestCopyReadFrom(t *testing.T) { + rb := new(Buffer) + wb := new(bytes.Buffer) // implements ReadFrom. + rb.WriteString("hello, world.") + Copy(wb, rb) + if wb.String() != "hello, world." { + t.Errorf("Copy did not work properly") + } +} + +func TestCopyWriteTo(t *testing.T) { + rb := new(bytes.Buffer) // implements WriteTo. + wb := new(Buffer) + rb.WriteString("hello, world.") + Copy(wb, rb) + if wb.String() != "hello, world." { + t.Errorf("Copy did not work properly") + } +} + +// Version of bytes.Buffer that checks whether WriteTo was called or not +type writeToChecker struct { + bytes.Buffer + writeToCalled bool +} + +func (wt *writeToChecker) WriteTo(w Writer) (int64, error) { + wt.writeToCalled = true + return wt.Buffer.WriteTo(w) +} + +// It's preferable to choose WriterTo over ReaderFrom, since a WriterTo can issue one large write, +// while the ReaderFrom must read until EOF, potentially allocating when running out of buffer. +// Make sure that we choose WriterTo when both are implemented. +func TestCopyPriority(t *testing.T) { + rb := new(writeToChecker) + wb := new(bytes.Buffer) + rb.WriteString("hello, world.") + Copy(wb, rb) + if wb.String() != "hello, world." { + t.Errorf("Copy did not work properly") + } else if !rb.writeToCalled { + t.Errorf("WriteTo was not prioritized over ReadFrom") + } +} + +type zeroErrReader struct { + err error +} + +func (r zeroErrReader) Read(p []byte) (int, error) { + return copy(p, []byte{0}), r.err +} + +type errWriter struct { + err error +} + +func (w errWriter) Write([]byte) (int, error) { + return 0, w.err +} + +// In case a Read results in an error with non-zero bytes read, and +// the subsequent Write also results in an error, the error from Write +// is returned, as it is the one that prevented progressing further. +func TestCopyReadErrWriteErr(t *testing.T) { + er, ew := errors.New("readError"), errors.New("writeError") + r, w := zeroErrReader{err: er}, errWriter{err: ew} + n, err := Copy(w, r) + if n != 0 || err != ew { + t.Errorf("Copy(zeroErrReader, errWriter) = %d, %v; want 0, writeError", n, err) + } +} + +func TestCopyN(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello, world.") + CopyN(wb, rb, 5) + if wb.String() != "hello" { + t.Errorf("CopyN did not work properly") + } +} + +func TestCopyNReadFrom(t *testing.T) { + rb := new(Buffer) + wb := new(bytes.Buffer) // implements ReadFrom. + rb.WriteString("hello") + CopyN(wb, rb, 5) + if wb.String() != "hello" { + t.Errorf("CopyN did not work properly") + } +} + +func TestCopyNWriteTo(t *testing.T) { + rb := new(bytes.Buffer) // implements WriteTo. + wb := new(Buffer) + rb.WriteString("hello, world.") + CopyN(wb, rb, 5) + if wb.String() != "hello" { + t.Errorf("CopyN did not work properly") + } +} + +func BenchmarkCopyNSmall(b *testing.B) { + bs := bytes.Repeat([]byte{0}, 512+1) + rd := bytes.NewReader(bs) + buf := new(Buffer) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + CopyN(buf, rd, 512) + rd.Reset(bs) + } +} + +func BenchmarkCopyNLarge(b *testing.B) { + bs := bytes.Repeat([]byte{0}, (32*1024)+1) + rd := bytes.NewReader(bs) + buf := new(Buffer) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + CopyN(buf, rd, 32*1024) + rd.Reset(bs) + } +} + +type noReadFrom struct { + w Writer +} + +func (w *noReadFrom) Write(p []byte) (n int, err error) { + return w.w.Write(p) +} + +type wantedAndErrReader struct{} + +func (wantedAndErrReader) Read(p []byte) (int, error) { + return len(p), errors.New("wantedAndErrReader error") +} + +func TestCopyNEOF(t *testing.T) { + // Test that EOF behavior is the same regardless of whether + // argument to CopyN has ReadFrom. + + b := new(bytes.Buffer) + + n, err := CopyN(&noReadFrom{b}, strings.NewReader("foo"), 3) + if n != 3 || err != nil { + t.Errorf("CopyN(noReadFrom, foo, 3) = %d, %v; want 3, nil", n, err) + } + + n, err = CopyN(&noReadFrom{b}, strings.NewReader("foo"), 4) + if n != 3 || err != EOF { + t.Errorf("CopyN(noReadFrom, foo, 4) = %d, %v; want 3, EOF", n, err) + } + + n, err = CopyN(b, strings.NewReader("foo"), 3) // b has read from + if n != 3 || err != nil { + t.Errorf("CopyN(bytes.Buffer, foo, 3) = %d, %v; want 3, nil", n, err) + } + + n, err = CopyN(b, strings.NewReader("foo"), 4) // b has read from + if n != 3 || err != EOF { + t.Errorf("CopyN(bytes.Buffer, foo, 4) = %d, %v; want 3, EOF", n, err) + } + + n, err = CopyN(b, wantedAndErrReader{}, 5) + if n != 5 || err != nil { + t.Errorf("CopyN(bytes.Buffer, wantedAndErrReader, 5) = %d, %v; want 5, nil", n, err) + } + + n, err = CopyN(&noReadFrom{b}, wantedAndErrReader{}, 5) + if n != 5 || err != nil { + t.Errorf("CopyN(noReadFrom, wantedAndErrReader, 5) = %d, %v; want 5, nil", n, err) + } +} + +func TestReadAtLeast(t *testing.T) { + var rb bytes.Buffer + testReadAtLeast(t, &rb) +} + +// A version of bytes.Buffer that returns n > 0, err on Read +// when the input is exhausted. +type dataAndErrorBuffer struct { + err error + bytes.Buffer +} + +func (r *dataAndErrorBuffer) Read(p []byte) (n int, err error) { + n, err = r.Buffer.Read(p) + if n > 0 && r.Buffer.Len() == 0 && err == nil { + err = r.err + } + return +} + +func TestReadAtLeastWithDataAndEOF(t *testing.T) { + var rb dataAndErrorBuffer + rb.err = EOF + testReadAtLeast(t, &rb) +} + +func TestReadAtLeastWithDataAndError(t *testing.T) { + var rb dataAndErrorBuffer + rb.err = fmt.Errorf("fake error") + testReadAtLeast(t, &rb) +} + +func testReadAtLeast(t *testing.T, rb ReadWriter) { + rb.Write([]byte("0123")) + buf := make([]byte, 2) + n, err := ReadAtLeast(rb, buf, 2) + if err != nil { + t.Error(err) + } + if n != 2 { + t.Errorf("expected to have read 2 bytes, got %v", n) + } + n, err = ReadAtLeast(rb, buf, 4) + if err != ErrShortBuffer { + t.Errorf("expected ErrShortBuffer got %v", err) + } + if n != 0 { + t.Errorf("expected to have read 0 bytes, got %v", n) + } + n, err = ReadAtLeast(rb, buf, 1) + if err != nil { + t.Error(err) + } + if n != 2 { + t.Errorf("expected to have read 2 bytes, got %v", n) + } + n, err = ReadAtLeast(rb, buf, 2) + if err != EOF { + t.Errorf("expected EOF, got %v", err) + } + if n != 0 { + t.Errorf("expected to have read 0 bytes, got %v", n) + } + rb.Write([]byte("4")) + n, err = ReadAtLeast(rb, buf, 2) + want := ErrUnexpectedEOF + if rb, ok := rb.(*dataAndErrorBuffer); ok && rb.err != EOF { + want = rb.err + } + if err != want { + t.Errorf("expected %v, got %v", want, err) + } + if n != 1 { + t.Errorf("expected to have read 1 bytes, got %v", n) + } +} + +func TestTeeReader(t *testing.T) { + src := []byte("hello, world") + dst := make([]byte, len(src)) + rb := bytes.NewBuffer(src) + wb := new(bytes.Buffer) + r := TeeReader(rb, wb) + if n, err := ReadFull(r, dst); err != nil || n != len(src) { + t.Fatalf("ReadFull(r, dst) = %d, %v; want %d, nil", n, err, len(src)) + } + if !bytes.Equal(dst, src) { + t.Errorf("bytes read = %q want %q", dst, src) + } + if !bytes.Equal(wb.Bytes(), src) { + t.Errorf("bytes written = %q want %q", wb.Bytes(), src) + } + if n, err := r.Read(dst); n != 0 || err != EOF { + t.Errorf("r.Read at EOF = %d, %v want 0, EOF", n, err) + } + rb = bytes.NewBuffer(src) + pr, pw := Pipe() + pr.Close() + r = TeeReader(rb, pw) + if n, err := ReadFull(r, dst); n != 0 || err != ErrClosedPipe { + t.Errorf("closed tee: ReadFull(r, dst) = %d, %v; want 0, EPIPE", n, err) + } +} + +func TestSectionReader_ReadAt(t *testing.T) { + dat := "a long sample data, 1234567890" + tests := []struct { + data string + off int + n int + bufLen int + at int + exp string + err error + }{ + {data: "", off: 0, n: 10, bufLen: 2, at: 0, exp: "", err: EOF}, + {data: dat, off: 0, n: len(dat), bufLen: 0, at: 0, exp: "", err: nil}, + {data: dat, off: len(dat), n: 1, bufLen: 1, at: 0, exp: "", err: EOF}, + {data: dat, off: 0, n: len(dat) + 2, bufLen: len(dat), at: 0, exp: dat, err: nil}, + {data: dat, off: 0, n: len(dat), bufLen: len(dat) / 2, at: 0, exp: dat[:len(dat)/2], err: nil}, + {data: dat, off: 0, n: len(dat), bufLen: len(dat), at: 0, exp: dat, err: nil}, + {data: dat, off: 0, n: len(dat), bufLen: len(dat) / 2, at: 2, exp: dat[2 : 2+len(dat)/2], err: nil}, + {data: dat, off: 3, n: len(dat), bufLen: len(dat) / 2, at: 2, exp: dat[5 : 5+len(dat)/2], err: nil}, + {data: dat, off: 3, n: len(dat) / 2, bufLen: len(dat)/2 - 2, at: 2, exp: dat[5 : 5+len(dat)/2-2], err: nil}, + {data: dat, off: 3, n: len(dat) / 2, bufLen: len(dat)/2 + 2, at: 2, exp: dat[5 : 5+len(dat)/2-2], err: EOF}, + {data: dat, off: 0, n: 0, bufLen: 0, at: -1, exp: "", err: EOF}, + {data: dat, off: 0, n: 0, bufLen: 0, at: 1, exp: "", err: EOF}, + } + for i, tt := range tests { + r := strings.NewReader(tt.data) + s := NewSectionReader(r, int64(tt.off), int64(tt.n)) + buf := make([]byte, tt.bufLen) + if n, err := s.ReadAt(buf, int64(tt.at)); n != len(tt.exp) || string(buf[:n]) != tt.exp || err != tt.err { + t.Fatalf("%d: ReadAt(%d) = %q, %v; expected %q, %v", i, tt.at, buf[:n], err, tt.exp, tt.err) + } + if _r, off, n := s.Outer(); _r != r || off != int64(tt.off) || n != int64(tt.n) { + t.Fatalf("%d: Outer() = %v, %d, %d; expected %v, %d, %d", i, _r, off, n, r, tt.off, tt.n) + } + } +} + +func TestSectionReader_Seek(t *testing.T) { + // Verifies that NewSectionReader's Seeker behaves like bytes.NewReader (which is like strings.NewReader) + br := bytes.NewReader([]byte("foo")) + sr := NewSectionReader(br, 0, int64(len("foo"))) + + for _, whence := range []int{SeekStart, SeekCurrent, SeekEnd} { + for offset := int64(-3); offset <= 4; offset++ { + brOff, brErr := br.Seek(offset, whence) + srOff, srErr := sr.Seek(offset, whence) + if (brErr != nil) != (srErr != nil) || brOff != srOff { + t.Errorf("For whence %d, offset %d: bytes.Reader.Seek = (%v, %v) != SectionReader.Seek = (%v, %v)", + whence, offset, brOff, brErr, srErr, srOff) + } + } + } + + // And verify we can just seek past the end and get an EOF + got, err := sr.Seek(100, SeekStart) + if err != nil || got != 100 { + t.Errorf("Seek = %v, %v; want 100, nil", got, err) + } + + n, err := sr.Read(make([]byte, 10)) + if n != 0 || err != EOF { + t.Errorf("Read = %v, %v; want 0, EOF", n, err) + } +} + +func TestSectionReader_Size(t *testing.T) { + tests := []struct { + data string + want int64 + }{ + {"a long sample data, 1234567890", 30}, + {"", 0}, + } + + for _, tt := range tests { + r := strings.NewReader(tt.data) + sr := NewSectionReader(r, 0, int64(len(tt.data))) + if got := sr.Size(); got != tt.want { + t.Errorf("Size = %v; want %v", got, tt.want) + } + } +} + +func TestSectionReader_Max(t *testing.T) { + r := strings.NewReader("abcdef") + const maxint64 = 1<<63 - 1 + sr := NewSectionReader(r, 3, maxint64) + n, err := sr.Read(make([]byte, 3)) + if n != 3 || err != nil { + t.Errorf("Read = %v %v, want 3, nil", n, err) + } + n, err = sr.Read(make([]byte, 3)) + if n != 0 || err != EOF { + t.Errorf("Read = %v, %v, want 0, EOF", n, err) + } + if _r, off, n := sr.Outer(); _r != r || off != 3 || n != maxint64 { + t.Fatalf("Outer = %v, %d, %d; expected %v, %d, %d", _r, off, n, r, 3, int64(maxint64)) + } +} + +// largeWriter returns an invalid count that is larger than the number +// of bytes provided (issue 39978). +type largeWriter struct { + err error +} + +func (w largeWriter) Write(p []byte) (int, error) { + return len(p) + 1, w.err +} + +func TestCopyLargeWriter(t *testing.T) { + want := ErrInvalidWrite + rb := new(Buffer) + wb := largeWriter{} + rb.WriteString("hello, world.") + if _, err := Copy(wb, rb); err != want { + t.Errorf("Copy error: got %v, want %v", err, want) + } + + want = errors.New("largeWriterError") + rb = new(Buffer) + wb = largeWriter{err: want} + rb.WriteString("hello, world.") + if _, err := Copy(wb, rb); err != want { + t.Errorf("Copy error: got %v, want %v", err, want) + } +} + +func TestNopCloserWriterToForwarding(t *testing.T) { + for _, tc := range [...]struct { + Name string + r Reader + }{ + {"not a WriterTo", Reader(nil)}, + {"a WriterTo", struct { + Reader + WriterTo + }{}}, + } { + nc := NopCloser(tc.r) + + _, expected := tc.r.(WriterTo) + _, got := nc.(WriterTo) + if expected != got { + t.Errorf("NopCloser incorrectly forwards WriterTo for %s, got %t want %t", tc.Name, got, expected) + } + } +} + +func TestOffsetWriter_Seek(t *testing.T) { + tmpfilename := "TestOffsetWriter_Seek" + tmpfile, err := os.CreateTemp(t.TempDir(), tmpfilename) + if err != nil || tmpfile == nil { + t.Fatalf("CreateTemp(%s) failed: %v", tmpfilename, err) + } + defer tmpfile.Close() + w := NewOffsetWriter(tmpfile, 0) + + // Should throw error errWhence if whence is not valid + t.Run("errWhence", func(t *testing.T) { + for _, whence := range []int{-3, -2, -1, 3, 4, 5} { + var offset int64 = 0 + gotOff, gotErr := w.Seek(offset, whence) + if gotOff != 0 || gotErr != ErrWhence { + t.Errorf("For whence %d, offset %d, OffsetWriter.Seek got: (%d, %v), want: (%d, %v)", + whence, offset, gotOff, gotErr, 0, ErrWhence) + } + } + }) + + // Should throw error errOffset if offset is negative + t.Run("errOffset", func(t *testing.T) { + for _, whence := range []int{SeekStart, SeekCurrent} { + for offset := int64(-3); offset < 0; offset++ { + gotOff, gotErr := w.Seek(offset, whence) + if gotOff != 0 || gotErr != ErrOffset { + t.Errorf("For whence %d, offset %d, OffsetWriter.Seek got: (%d, %v), want: (%d, %v)", + whence, offset, gotOff, gotErr, 0, ErrOffset) + } + } + } + }) + + // Normal tests + t.Run("normal", func(t *testing.T) { + tests := []struct { + offset int64 + whence int + returnOff int64 + }{ + // keep in order + {whence: SeekStart, offset: 1, returnOff: 1}, + {whence: SeekStart, offset: 2, returnOff: 2}, + {whence: SeekStart, offset: 3, returnOff: 3}, + {whence: SeekCurrent, offset: 1, returnOff: 4}, + {whence: SeekCurrent, offset: 2, returnOff: 6}, + {whence: SeekCurrent, offset: 3, returnOff: 9}, + } + for idx, tt := range tests { + gotOff, gotErr := w.Seek(tt.offset, tt.whence) + if gotOff != tt.returnOff || gotErr != nil { + t.Errorf("%d:: For whence %d, offset %d, OffsetWriter.Seek got: (%d, %v), want: (%d, <nil>)", + idx+1, tt.whence, tt.offset, gotOff, gotErr, tt.returnOff) + } + } + }) +} + +func TestOffsetWriter_WriteAt(t *testing.T) { + const content = "0123456789ABCDEF" + contentSize := int64(len(content)) + tmpdir, err := os.MkdirTemp(t.TempDir(), "TestOffsetWriter_WriteAt") + if err != nil { + t.Fatal(err) + } + + work := func(off, at int64) { + position := fmt.Sprintf("off_%d_at_%d", off, at) + tmpfile, err := os.CreateTemp(tmpdir, position) + if err != nil || tmpfile == nil { + t.Fatalf("CreateTemp(%s) failed: %v", position, err) + } + defer tmpfile.Close() + + var writeN int64 + var wg sync.WaitGroup + // Concurrent writes, one byte at a time + for step, value := range []byte(content) { + wg.Add(1) + go func(wg *sync.WaitGroup, tmpfile *os.File, value byte, off, at int64, step int) { + defer wg.Done() + + w := NewOffsetWriter(tmpfile, off) + n, e := w.WriteAt([]byte{value}, at+int64(step)) + if e != nil { + t.Errorf("WriteAt failed. off: %d, at: %d, step: %d\n error: %v", off, at, step, e) + } + atomic.AddInt64(&writeN, int64(n)) + }(&wg, tmpfile, value, off, at, step) + } + wg.Wait() + + // Read one more byte to reach EOF + buf := make([]byte, contentSize+1) + readN, err := tmpfile.ReadAt(buf, off+at) + if err != EOF { + t.Fatalf("ReadAt failed: %v", err) + } + readContent := string(buf[:contentSize]) + if writeN != int64(readN) || writeN != contentSize || readContent != content { + t.Fatalf("%s:: WriteAt(%s, %d) error. \ngot n: %v, content: %s \nexpected n: %v, content: %v", + position, content, at, readN, readContent, contentSize, content) + } + } + for off := int64(0); off < 2; off++ { + for at := int64(0); at < 2; at++ { + work(off, at) + } + } +} + +func TestWriteAt_PositionPriorToBase(t *testing.T) { + tmpdir := t.TempDir() + tmpfilename := "TestOffsetWriter_WriteAt" + tmpfile, err := os.CreateTemp(tmpdir, tmpfilename) + if err != nil { + t.Fatalf("CreateTemp(%s) failed: %v", tmpfilename, err) + } + defer tmpfile.Close() + + // start writing position in OffsetWriter + offset := int64(10) + // position we want to write to the tmpfile + at := int64(-1) + w := NewOffsetWriter(tmpfile, offset) + _, e := w.WriteAt([]byte("hello"), at) + if e == nil { + t.Errorf("error expected to be not nil") + } +} + +func TestOffsetWriter_Write(t *testing.T) { + const content = "0123456789ABCDEF" + contentSize := len(content) + tmpdir := t.TempDir() + + makeOffsetWriter := func(name string) (*OffsetWriter, *os.File) { + tmpfilename := "TestOffsetWriter_Write_" + name + tmpfile, err := os.CreateTemp(tmpdir, tmpfilename) + if err != nil || tmpfile == nil { + t.Fatalf("CreateTemp(%s) failed: %v", tmpfilename, err) + } + return NewOffsetWriter(tmpfile, 0), tmpfile + } + checkContent := func(name string, f *os.File) { + // Read one more byte to reach EOF + buf := make([]byte, contentSize+1) + readN, err := f.ReadAt(buf, 0) + if err != EOF { + t.Fatalf("ReadAt failed, err: %v", err) + } + readContent := string(buf[:contentSize]) + if readN != contentSize || readContent != content { + t.Fatalf("%s error. \ngot n: %v, content: %s \nexpected n: %v, content: %v", + name, readN, readContent, contentSize, content) + } + } + + var name string + name = "Write" + t.Run(name, func(t *testing.T) { + // Write directly (off: 0, at: 0) + // Write content to file + w, f := makeOffsetWriter(name) + defer f.Close() + for _, value := range []byte(content) { + n, err := w.Write([]byte{value}) + if err != nil { + t.Fatalf("Write failed, n: %d, err: %v", n, err) + } + } + checkContent(name, f) + + // Copy -> Write + // Copy file f to file f2 + name = "Copy" + w2, f2 := makeOffsetWriter(name) + defer f2.Close() + Copy(w2, f) + checkContent(name, f2) + }) + + // Copy -> WriteTo -> Write + // Note: strings.Reader implements the io.WriterTo interface. + name = "Write_Of_Copy_WriteTo" + t.Run(name, func(t *testing.T) { + w, f := makeOffsetWriter(name) + defer f.Close() + Copy(w, strings.NewReader(content)) + checkContent(name, f) + }) +} diff --git a/src/io/ioutil/example_test.go b/src/io/ioutil/example_test.go new file mode 100644 index 0000000..78b0730 --- /dev/null +++ b/src/io/ioutil/example_test.go @@ -0,0 +1,132 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ioutil_test + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "strings" +) + +func ExampleReadAll() { + r := strings.NewReader("Go is a general-purpose language designed with systems programming in mind.") + + b, err := ioutil.ReadAll(r) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", b) + + // Output: + // Go is a general-purpose language designed with systems programming in mind. +} + +func ExampleReadDir() { + files, err := ioutil.ReadDir(".") + if err != nil { + log.Fatal(err) + } + + for _, file := range files { + fmt.Println(file.Name()) + } +} + +func ExampleTempDir() { + content := []byte("temporary file's content") + dir, err := ioutil.TempDir("", "example") + if err != nil { + log.Fatal(err) + } + + defer os.RemoveAll(dir) // clean up + + tmpfn := filepath.Join(dir, "tmpfile") + if err := ioutil.WriteFile(tmpfn, content, 0666); err != nil { + log.Fatal(err) + } +} + +func ExampleTempDir_suffix() { + parentDir := os.TempDir() + logsDir, err := ioutil.TempDir(parentDir, "*-logs") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(logsDir) // clean up + + // Logs can be cleaned out earlier if needed by searching + // for all directories whose suffix ends in *-logs. + globPattern := filepath.Join(parentDir, "*-logs") + matches, err := filepath.Glob(globPattern) + if err != nil { + log.Fatalf("Failed to match %q: %v", globPattern, err) + } + + for _, match := range matches { + if err := os.RemoveAll(match); err != nil { + log.Printf("Failed to remove %q: %v", match, err) + } + } +} + +func ExampleTempFile() { + content := []byte("temporary file's content") + tmpfile, err := ioutil.TempFile("", "example") + if err != nil { + log.Fatal(err) + } + + defer os.Remove(tmpfile.Name()) // clean up + + if _, err := tmpfile.Write(content); err != nil { + log.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + log.Fatal(err) + } +} + +func ExampleTempFile_suffix() { + content := []byte("temporary file's content") + tmpfile, err := ioutil.TempFile("", "example.*.txt") + if err != nil { + log.Fatal(err) + } + + defer os.Remove(tmpfile.Name()) // clean up + + if _, err := tmpfile.Write(content); err != nil { + tmpfile.Close() + log.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + log.Fatal(err) + } +} + +func ExampleReadFile() { + content, err := ioutil.ReadFile("testdata/hello") + if err != nil { + log.Fatal(err) + } + + fmt.Printf("File contents: %s", content) + + // Output: + // File contents: Hello, Gophers! +} + +func ExampleWriteFile() { + message := []byte("Hello, Gophers!") + err := ioutil.WriteFile("hello", message, 0644) + if err != nil { + log.Fatal(err) + } +} diff --git a/src/io/ioutil/ioutil.go b/src/io/ioutil/ioutil.go new file mode 100644 index 0000000..67768e5 --- /dev/null +++ b/src/io/ioutil/ioutil.go @@ -0,0 +1,95 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ioutil implements some I/O utility functions. +// +// Deprecated: As of Go 1.16, the same functionality is now provided +// by package [io] or package [os], and those implementations +// should be preferred in new code. +// See the specific function documentation for details. +package ioutil + +import ( + "io" + "io/fs" + "os" + "sort" +) + +// ReadAll reads from r until an error or EOF and returns the data it read. +// A successful call returns err == nil, not err == EOF. Because ReadAll is +// defined to read from src until EOF, it does not treat an EOF from Read +// as an error to be reported. +// +// Deprecated: As of Go 1.16, this function simply calls [io.ReadAll]. +func ReadAll(r io.Reader) ([]byte, error) { + return io.ReadAll(r) +} + +// ReadFile reads the file named by filename and returns the contents. +// A successful call returns err == nil, not err == EOF. Because ReadFile +// reads the whole file, it does not treat an EOF from Read as an error +// to be reported. +// +// Deprecated: As of Go 1.16, this function simply calls [os.ReadFile]. +func ReadFile(filename string) ([]byte, error) { + return os.ReadFile(filename) +} + +// WriteFile writes data to a file named by filename. +// If the file does not exist, WriteFile creates it with permissions perm +// (before umask); otherwise WriteFile truncates it before writing, without changing permissions. +// +// Deprecated: As of Go 1.16, this function simply calls [os.WriteFile]. +func WriteFile(filename string, data []byte, perm fs.FileMode) error { + return os.WriteFile(filename, data, perm) +} + +// ReadDir reads the directory named by dirname and returns +// a list of fs.FileInfo for the directory's contents, +// sorted by filename. If an error occurs reading the directory, +// ReadDir returns no directory entries along with the error. +// +// Deprecated: As of Go 1.16, [os.ReadDir] is a more efficient and correct choice: +// it returns a list of [fs.DirEntry] instead of [fs.FileInfo], +// and it returns partial results in the case of an error +// midway through reading a directory. +// +// If you must continue obtaining a list of [fs.FileInfo], you still can: +// +// entries, err := os.ReadDir(dirname) +// if err != nil { ... } +// infos := make([]fs.FileInfo, 0, len(entries)) +// for _, entry := range entries { +// info, err := entry.Info() +// if err != nil { ... } +// infos = append(infos, info) +// } +func ReadDir(dirname string) ([]fs.FileInfo, error) { + f, err := os.Open(dirname) + if err != nil { + return nil, err + } + list, err := f.Readdir(-1) + f.Close() + if err != nil { + return nil, err + } + sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() }) + return list, nil +} + +// NopCloser returns a ReadCloser with a no-op Close method wrapping +// the provided Reader r. +// +// Deprecated: As of Go 1.16, this function simply calls [io.NopCloser]. +func NopCloser(r io.Reader) io.ReadCloser { + return io.NopCloser(r) +} + +// Discard is an io.Writer on which all Write calls succeed +// without doing anything. +// +// Deprecated: As of Go 1.16, this value is simply [io.Discard]. +var Discard io.Writer = io.Discard diff --git a/src/io/ioutil/ioutil_test.go b/src/io/ioutil/ioutil_test.go new file mode 100644 index 0000000..6bff8c6 --- /dev/null +++ b/src/io/ioutil/ioutil_test.go @@ -0,0 +1,134 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ioutil_test + +import ( + "bytes" + . "io/ioutil" + "os" + "path/filepath" + "runtime" + "testing" +) + +func checkSize(t *testing.T, path string, size int64) { + dir, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat %q (looking for size %d): %s", path, size, err) + } + if dir.Size() != size { + t.Errorf("Stat %q: size %d want %d", path, dir.Size(), size) + } +} + +func TestReadFile(t *testing.T) { + filename := "rumpelstilzchen" + contents, err := ReadFile(filename) + if err == nil { + t.Fatalf("ReadFile %s: error expected, none found", filename) + } + + filename = "ioutil_test.go" + contents, err = ReadFile(filename) + if err != nil { + t.Fatalf("ReadFile %s: %v", filename, err) + } + + checkSize(t, filename, int64(len(contents))) +} + +func TestWriteFile(t *testing.T) { + f, err := TempFile("", "ioutil-test") + if err != nil { + t.Fatal(err) + } + filename := f.Name() + data := "Programming today is a race between software engineers striving to " + + "build bigger and better idiot-proof programs, and the Universe trying " + + "to produce bigger and better idiots. So far, the Universe is winning." + + if err := WriteFile(filename, []byte(data), 0644); err != nil { + t.Fatalf("WriteFile %s: %v", filename, err) + } + + contents, err := ReadFile(filename) + if err != nil { + t.Fatalf("ReadFile %s: %v", filename, err) + } + + if string(contents) != data { + t.Fatalf("contents = %q\nexpected = %q", string(contents), data) + } + + // cleanup + f.Close() + os.Remove(filename) // ignore error +} + +func TestReadOnlyWriteFile(t *testing.T) { + if os.Getuid() == 0 { + t.Skipf("Root can write to read-only files anyway, so skip the read-only test.") + } + if runtime.GOOS == "wasip1" { + t.Skip("file permissions are not supported by wasip1") + } + + // We don't want to use TempFile directly, since that opens a file for us as 0600. + tempDir, err := TempDir("", t.Name()) + if err != nil { + t.Fatalf("TempDir %s: %v", t.Name(), err) + } + defer os.RemoveAll(tempDir) + filename := filepath.Join(tempDir, "blurp.txt") + + shmorp := []byte("shmorp") + florp := []byte("florp") + err = WriteFile(filename, shmorp, 0444) + if err != nil { + t.Fatalf("WriteFile %s: %v", filename, err) + } + err = WriteFile(filename, florp, 0444) + if err == nil { + t.Fatalf("Expected an error when writing to read-only file %s", filename) + } + got, err := ReadFile(filename) + if err != nil { + t.Fatalf("ReadFile %s: %v", filename, err) + } + if !bytes.Equal(got, shmorp) { + t.Fatalf("want %s, got %s", shmorp, got) + } +} + +func TestReadDir(t *testing.T) { + dirname := "rumpelstilzchen" + _, err := ReadDir(dirname) + if err == nil { + t.Fatalf("ReadDir %s: error expected, none found", dirname) + } + + dirname = ".." + list, err := ReadDir(dirname) + if err != nil { + t.Fatalf("ReadDir %s: %v", dirname, err) + } + + foundFile := false + foundSubDir := false + for _, dir := range list { + switch { + case !dir.IsDir() && dir.Name() == "io_test.go": + foundFile = true + case dir.IsDir() && dir.Name() == "ioutil": + foundSubDir = true + } + } + if !foundFile { + t.Fatalf("ReadDir %s: io_test.go file not found", dirname) + } + if !foundSubDir { + t.Fatalf("ReadDir %s: ioutil directory not found", dirname) + } +} diff --git a/src/io/ioutil/tempfile.go b/src/io/ioutil/tempfile.go new file mode 100644 index 0000000..47b2e40 --- /dev/null +++ b/src/io/ioutil/tempfile.go @@ -0,0 +1,41 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ioutil + +import ( + "os" +) + +// TempFile creates a new temporary file in the directory dir, +// opens the file for reading and writing, and returns the resulting *[os.File]. +// The filename is generated by taking pattern and adding a random +// string to the end. If pattern includes a "*", the random string +// replaces the last "*". +// If dir is the empty string, TempFile uses the default directory +// for temporary files (see [os.TempDir]). +// Multiple programs calling TempFile simultaneously +// will not choose the same file. The caller can use f.Name() +// to find the pathname of the file. It is the caller's responsibility +// to remove the file when no longer needed. +// +// Deprecated: As of Go 1.17, this function simply calls [os.CreateTemp]. +func TempFile(dir, pattern string) (f *os.File, err error) { + return os.CreateTemp(dir, pattern) +} + +// TempDir creates a new temporary directory in the directory dir. +// The directory name is generated by taking pattern and applying a +// random string to the end. If pattern includes a "*", the random string +// replaces the last "*". TempDir returns the name of the new directory. +// If dir is the empty string, TempDir uses the +// default directory for temporary files (see [os.TempDir]). +// Multiple programs calling TempDir simultaneously +// will not choose the same directory. It is the caller's responsibility +// to remove the directory when no longer needed. +// +// Deprecated: As of Go 1.17, this function simply calls [os.MkdirTemp]. +func TempDir(dir, pattern string) (name string, err error) { + return os.MkdirTemp(dir, pattern) +} diff --git a/src/io/ioutil/tempfile_test.go b/src/io/ioutil/tempfile_test.go new file mode 100644 index 0000000..818fcda --- /dev/null +++ b/src/io/ioutil/tempfile_test.go @@ -0,0 +1,196 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ioutil_test + +import ( + "io/fs" + . "io/ioutil" + "os" + "path/filepath" + "regexp" + "strings" + "testing" +) + +func TestTempFile(t *testing.T) { + dir, err := TempDir("", "TestTempFile_BadDir") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + nonexistentDir := filepath.Join(dir, "_not_exists_") + f, err := TempFile(nonexistentDir, "foo") + if f != nil || err == nil { + t.Errorf("TempFile(%q, `foo`) = %v, %v", nonexistentDir, f, err) + } +} + +func TestTempFile_pattern(t *testing.T) { + tests := []struct{ pattern, prefix, suffix string }{ + {"ioutil_test", "ioutil_test", ""}, + {"ioutil_test*", "ioutil_test", ""}, + {"ioutil_test*xyz", "ioutil_test", "xyz"}, + } + for _, test := range tests { + f, err := TempFile("", test.pattern) + if err != nil { + t.Errorf("TempFile(..., %q) error: %v", test.pattern, err) + continue + } + defer os.Remove(f.Name()) + base := filepath.Base(f.Name()) + f.Close() + if !(strings.HasPrefix(base, test.prefix) && strings.HasSuffix(base, test.suffix)) { + t.Errorf("TempFile pattern %q created bad name %q; want prefix %q & suffix %q", + test.pattern, base, test.prefix, test.suffix) + } + } +} + +// This string is from os.errPatternHasSeparator. +const patternHasSeparator = "pattern contains path separator" + +func TestTempFile_BadPattern(t *testing.T) { + tmpDir, err := TempDir("", t.Name()) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + const sep = string(os.PathSeparator) + tests := []struct { + pattern string + wantErr bool + }{ + {"ioutil*test", false}, + {"ioutil_test*foo", false}, + {"ioutil_test" + sep + "foo", true}, + {"ioutil_test*" + sep + "foo", true}, + {"ioutil_test" + sep + "*foo", true}, + {sep + "ioutil_test" + sep + "*foo", true}, + {"ioutil_test*foo" + sep, true}, + } + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + tmpfile, err := TempFile(tmpDir, tt.pattern) + defer func() { + if tmpfile != nil { + tmpfile.Close() + } + }() + if tt.wantErr { + if err == nil { + t.Errorf("Expected an error for pattern %q", tt.pattern) + } else if !strings.Contains(err.Error(), patternHasSeparator) { + t.Errorf("Error mismatch: got %#v, want %q for pattern %q", err, patternHasSeparator, tt.pattern) + } + } else if err != nil { + t.Errorf("Unexpected error %v for pattern %q", err, tt.pattern) + } + }) + } +} + +func TestTempDir(t *testing.T) { + name, err := TempDir("/_not_exists_", "foo") + if name != "" || err == nil { + t.Errorf("TempDir(`/_not_exists_`, `foo`) = %v, %v", name, err) + } + + tests := []struct { + pattern string + wantPrefix, wantSuffix string + }{ + {"ioutil_test", "ioutil_test", ""}, + {"ioutil_test*", "ioutil_test", ""}, + {"ioutil_test*xyz", "ioutil_test", "xyz"}, + } + + dir := os.TempDir() + + runTestTempDir := func(t *testing.T, pattern, wantRePat string) { + name, err := TempDir(dir, pattern) + if name == "" || err != nil { + t.Fatalf("TempDir(dir, `ioutil_test`) = %v, %v", name, err) + } + defer os.Remove(name) + + re := regexp.MustCompile(wantRePat) + if !re.MatchString(name) { + t.Errorf("TempDir(%q, %q) created bad name\n\t%q\ndid not match pattern\n\t%q", dir, pattern, name, wantRePat) + } + } + + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + wantRePat := "^" + regexp.QuoteMeta(filepath.Join(dir, tt.wantPrefix)) + "[0-9]+" + regexp.QuoteMeta(tt.wantSuffix) + "$" + runTestTempDir(t, tt.pattern, wantRePat) + }) + } + + // Separately testing "*xyz" (which has no prefix). That is when constructing the + // pattern to assert on, as in the previous loop, using filepath.Join for an empty + // prefix filepath.Join(dir, ""), produces the pattern: + // ^<DIR>[0-9]+xyz$ + // yet we just want to match + // "^<DIR>/[0-9]+xyz" + t.Run("*xyz", func(t *testing.T) { + wantRePat := "^" + regexp.QuoteMeta(filepath.Join(dir)) + regexp.QuoteMeta(string(filepath.Separator)) + "[0-9]+xyz$" + runTestTempDir(t, "*xyz", wantRePat) + }) +} + +// test that we return a nice error message if the dir argument to TempDir doesn't +// exist (or that it's empty and os.TempDir doesn't exist) +func TestTempDir_BadDir(t *testing.T) { + dir, err := TempDir("", "TestTempDir_BadDir") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + badDir := filepath.Join(dir, "not-exist") + _, err = TempDir(badDir, "foo") + if pe, ok := err.(*fs.PathError); !ok || !os.IsNotExist(err) || pe.Path != badDir { + t.Errorf("TempDir error = %#v; want PathError for path %q satisfying os.IsNotExist", err, badDir) + } +} + +func TestTempDir_BadPattern(t *testing.T) { + tmpDir, err := TempDir("", t.Name()) + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + + const sep = string(os.PathSeparator) + tests := []struct { + pattern string + wantErr bool + }{ + {"ioutil*test", false}, + {"ioutil_test*foo", false}, + {"ioutil_test" + sep + "foo", true}, + {"ioutil_test*" + sep + "foo", true}, + {"ioutil_test" + sep + "*foo", true}, + {sep + "ioutil_test" + sep + "*foo", true}, + {"ioutil_test*foo" + sep, true}, + } + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + _, err := TempDir(tmpDir, tt.pattern) + if tt.wantErr { + if err == nil { + t.Errorf("Expected an error for pattern %q", tt.pattern) + } else if !strings.Contains(err.Error(), patternHasSeparator) { + t.Errorf("Error mismatch: got %#v, want %q for pattern %q", err, patternHasSeparator, tt.pattern) + } + } else if err != nil { + t.Errorf("Unexpected error %v for pattern %q", err, tt.pattern) + } + }) + } +} diff --git a/src/io/ioutil/testdata/hello b/src/io/ioutil/testdata/hello new file mode 100644 index 0000000..e47c092 --- /dev/null +++ b/src/io/ioutil/testdata/hello @@ -0,0 +1 @@ +Hello, Gophers! diff --git a/src/io/multi.go b/src/io/multi.go new file mode 100644 index 0000000..07a9aff --- /dev/null +++ b/src/io/multi.go @@ -0,0 +1,137 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io + +type eofReader struct{} + +func (eofReader) Read([]byte) (int, error) { + return 0, EOF +} + +type multiReader struct { + readers []Reader +} + +func (mr *multiReader) Read(p []byte) (n int, err error) { + for len(mr.readers) > 0 { + // Optimization to flatten nested multiReaders (Issue 13558). + if len(mr.readers) == 1 { + if r, ok := mr.readers[0].(*multiReader); ok { + mr.readers = r.readers + continue + } + } + n, err = mr.readers[0].Read(p) + if err == EOF { + // Use eofReader instead of nil to avoid nil panic + // after performing flatten (Issue 18232). + mr.readers[0] = eofReader{} // permit earlier GC + mr.readers = mr.readers[1:] + } + if n > 0 || err != EOF { + if err == EOF && len(mr.readers) > 0 { + // Don't return EOF yet. More readers remain. + err = nil + } + return + } + } + return 0, EOF +} + +func (mr *multiReader) WriteTo(w Writer) (sum int64, err error) { + return mr.writeToWithBuffer(w, make([]byte, 1024*32)) +} + +func (mr *multiReader) writeToWithBuffer(w Writer, buf []byte) (sum int64, err error) { + for i, r := range mr.readers { + var n int64 + if subMr, ok := r.(*multiReader); ok { // reuse buffer with nested multiReaders + n, err = subMr.writeToWithBuffer(w, buf) + } else { + n, err = copyBuffer(w, r, buf) + } + sum += n + if err != nil { + mr.readers = mr.readers[i:] // permit resume / retry after error + return sum, err + } + mr.readers[i] = nil // permit early GC + } + mr.readers = nil + return sum, nil +} + +var _ WriterTo = (*multiReader)(nil) + +// MultiReader returns a Reader that's the logical concatenation of +// the provided input readers. They're read sequentially. Once all +// inputs have returned EOF, Read will return EOF. If any of the readers +// return a non-nil, non-EOF error, Read will return that error. +func MultiReader(readers ...Reader) Reader { + r := make([]Reader, len(readers)) + copy(r, readers) + return &multiReader{r} +} + +type multiWriter struct { + writers []Writer +} + +func (t *multiWriter) Write(p []byte) (n int, err error) { + for _, w := range t.writers { + n, err = w.Write(p) + if err != nil { + return + } + if n != len(p) { + err = ErrShortWrite + return + } + } + return len(p), nil +} + +var _ StringWriter = (*multiWriter)(nil) + +func (t *multiWriter) WriteString(s string) (n int, err error) { + var p []byte // lazily initialized if/when needed + for _, w := range t.writers { + if sw, ok := w.(StringWriter); ok { + n, err = sw.WriteString(s) + } else { + if p == nil { + p = []byte(s) + } + n, err = w.Write(p) + } + if err != nil { + return + } + if n != len(s) { + err = ErrShortWrite + return + } + } + return len(s), nil +} + +// MultiWriter creates a writer that duplicates its writes to all the +// provided writers, similar to the Unix tee(1) command. +// +// Each write is written to each listed writer, one at a time. +// If a listed writer returns an error, that overall write operation +// stops and returns the error; it does not continue down the list. +func MultiWriter(writers ...Writer) Writer { + allWriters := make([]Writer, 0, len(writers)) + for _, w := range writers { + if mw, ok := w.(*multiWriter); ok { + allWriters = append(allWriters, mw.writers...) + } else { + allWriters = append(allWriters, w) + } + } + return &multiWriter{allWriters} +} diff --git a/src/io/multi_test.go b/src/io/multi_test.go new file mode 100644 index 0000000..7a24a8a --- /dev/null +++ b/src/io/multi_test.go @@ -0,0 +1,379 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io_test + +import ( + "bytes" + "crypto/sha1" + "errors" + "fmt" + . "io" + "runtime" + "strings" + "testing" + "time" +) + +func TestMultiReader(t *testing.T) { + var mr Reader + var buf []byte + nread := 0 + withFooBar := func(tests func()) { + r1 := strings.NewReader("foo ") + r2 := strings.NewReader("") + r3 := strings.NewReader("bar") + mr = MultiReader(r1, r2, r3) + buf = make([]byte, 20) + tests() + } + expectRead := func(size int, expected string, eerr error) { + nread++ + n, gerr := mr.Read(buf[0:size]) + if n != len(expected) { + t.Errorf("#%d, expected %d bytes; got %d", + nread, len(expected), n) + } + got := string(buf[0:n]) + if got != expected { + t.Errorf("#%d, expected %q; got %q", + nread, expected, got) + } + if gerr != eerr { + t.Errorf("#%d, expected error %v; got %v", + nread, eerr, gerr) + } + buf = buf[n:] + } + withFooBar(func() { + expectRead(2, "fo", nil) + expectRead(5, "o ", nil) + expectRead(5, "bar", nil) + expectRead(5, "", EOF) + }) + withFooBar(func() { + expectRead(4, "foo ", nil) + expectRead(1, "b", nil) + expectRead(3, "ar", nil) + expectRead(1, "", EOF) + }) + withFooBar(func() { + expectRead(5, "foo ", nil) + }) +} + +func TestMultiReaderAsWriterTo(t *testing.T) { + mr := MultiReader( + strings.NewReader("foo "), + MultiReader( // Tickle the buffer reusing codepath + strings.NewReader(""), + strings.NewReader("bar"), + ), + ) + mrAsWriterTo, ok := mr.(WriterTo) + if !ok { + t.Fatalf("expected cast to WriterTo to succeed") + } + sink := &strings.Builder{} + n, err := mrAsWriterTo.WriteTo(sink) + if err != nil { + t.Fatalf("expected no error; got %v", err) + } + if n != 7 { + t.Errorf("expected read 7 bytes; got %d", n) + } + if result := sink.String(); result != "foo bar" { + t.Errorf(`expected "foo bar"; got %q`, result) + } +} + +func TestMultiWriter(t *testing.T) { + sink := new(bytes.Buffer) + // Hide bytes.Buffer's WriteString method: + testMultiWriter(t, struct { + Writer + fmt.Stringer + }{sink, sink}) +} + +func TestMultiWriter_String(t *testing.T) { + testMultiWriter(t, new(bytes.Buffer)) +} + +// Test that a multiWriter.WriteString calls results in at most 1 allocation, +// even if multiple targets don't support WriteString. +func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) { + var sink1, sink2 bytes.Buffer + type simpleWriter struct { // hide bytes.Buffer's WriteString + Writer + } + mw := MultiWriter(simpleWriter{&sink1}, simpleWriter{&sink2}) + allocs := int(testing.AllocsPerRun(1000, func() { + WriteString(mw, "foo") + })) + if allocs != 1 { + t.Errorf("num allocations = %d; want 1", allocs) + } +} + +type writeStringChecker struct{ called bool } + +func (c *writeStringChecker) WriteString(s string) (n int, err error) { + c.called = true + return len(s), nil +} + +func (c *writeStringChecker) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func TestMultiWriter_StringCheckCall(t *testing.T) { + var c writeStringChecker + mw := MultiWriter(&c) + WriteString(mw, "foo") + if !c.called { + t.Error("did not see WriteString call to writeStringChecker") + } +} + +func testMultiWriter(t *testing.T, sink interface { + Writer + fmt.Stringer +}) { + sha1 := sha1.New() + mw := MultiWriter(sha1, sink) + + sourceString := "My input text." + source := strings.NewReader(sourceString) + written, err := Copy(mw, source) + + if written != int64(len(sourceString)) { + t.Errorf("short write of %d, not %d", written, len(sourceString)) + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + sha1hex := fmt.Sprintf("%x", sha1.Sum(nil)) + if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" { + t.Error("incorrect sha1 value") + } + + if sink.String() != sourceString { + t.Errorf("expected %q; got %q", sourceString, sink.String()) + } +} + +// writerFunc is a Writer implemented by the underlying func. +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// Test that MultiWriter properly flattens chained multiWriters. +func TestMultiWriterSingleChainFlatten(t *testing.T) { + pc := make([]uintptr, 1000) // 1000 should fit the full stack + n := runtime.Callers(0, pc) + var myDepth = callDepth(pc[:n]) + var writeDepth int // will contain the depth from which writerFunc.Writer was called + var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) { + n := runtime.Callers(1, pc) + writeDepth += callDepth(pc[:n]) + return 0, nil + })) + + mw := w + // chain a bunch of multiWriters + for i := 0; i < 100; i++ { + mw = MultiWriter(w) + } + + mw = MultiWriter(w, mw, w, mw) + mw.Write(nil) // don't care about errors, just want to check the call-depth for Write + + if writeDepth != 4*(myDepth+2) { // 2 should be multiWriter.Write and writerFunc.Write + t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d", + 4*(myDepth+2), writeDepth) + } +} + +func TestMultiWriterError(t *testing.T) { + f1 := writerFunc(func(p []byte) (int, error) { + return len(p) / 2, ErrShortWrite + }) + f2 := writerFunc(func(p []byte) (int, error) { + t.Errorf("MultiWriter called f2.Write") + return len(p), nil + }) + w := MultiWriter(f1, f2) + n, err := w.Write(make([]byte, 100)) + if n != 50 || err != ErrShortWrite { + t.Errorf("Write = %d, %v, want 50, ErrShortWrite", n, err) + } +} + +// Test that MultiReader copies the input slice and is insulated from future modification. +func TestMultiReaderCopy(t *testing.T) { + slice := []Reader{strings.NewReader("hello world")} + r := MultiReader(slice...) + slice[0] = nil + data, err := ReadAll(r) + if err != nil || string(data) != "hello world" { + t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, "hello world") + } +} + +// Test that MultiWriter copies the input slice and is insulated from future modification. +func TestMultiWriterCopy(t *testing.T) { + var buf strings.Builder + slice := []Writer{&buf} + w := MultiWriter(slice...) + slice[0] = nil + n, err := w.Write([]byte("hello world")) + if err != nil || n != 11 { + t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err) + } + if buf.String() != "hello world" { + t.Errorf("buf.String() = %q, want %q", buf.String(), "hello world") + } +} + +// readerFunc is a Reader implemented by the underlying func. +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} + +// callDepth returns the logical call depth for the given PCs. +func callDepth(callers []uintptr) (depth int) { + frames := runtime.CallersFrames(callers) + more := true + for more { + _, more = frames.Next() + depth++ + } + return +} + +// Test that MultiReader properly flattens chained multiReaders when Read is called +func TestMultiReaderFlatten(t *testing.T) { + pc := make([]uintptr, 1000) // 1000 should fit the full stack + n := runtime.Callers(0, pc) + var myDepth = callDepth(pc[:n]) + var readDepth int // will contain the depth from which fakeReader.Read was called + var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) { + n := runtime.Callers(1, pc) + readDepth = callDepth(pc[:n]) + return 0, errors.New("irrelevant") + })) + + // chain a bunch of multiReaders + for i := 0; i < 100; i++ { + r = MultiReader(r) + } + + r.Read(nil) // don't care about errors, just want to check the call-depth for Read + + if readDepth != myDepth+2 { // 2 should be multiReader.Read and fakeReader.Read + t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d", + myDepth+2, readDepth) + } +} + +// byteAndEOFReader is a Reader which reads one byte (the underlying +// byte) and EOF at once in its Read call. +type byteAndEOFReader byte + +func (b byteAndEOFReader) Read(p []byte) (n int, err error) { + if len(p) == 0 { + // Read(0 bytes) is useless. We expect no such useless + // calls in this test. + panic("unexpected call") + } + p[0] = byte(b) + return 1, EOF +} + +// This used to yield bytes forever; issue 16795. +func TestMultiReaderSingleByteWithEOF(t *testing.T) { + got, err := ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10)) + if err != nil { + t.Fatal(err) + } + const want = "ab" + if string(got) != want { + t.Errorf("got %q; want %q", got, want) + } +} + +// Test that a reader returning (n, EOF) at the end of a MultiReader +// chain continues to return EOF on its final read, rather than +// yielding a (0, EOF). +func TestMultiReaderFinalEOF(t *testing.T) { + r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a')) + buf := make([]byte, 2) + n, err := r.Read(buf) + if n != 1 || err != EOF { + t.Errorf("got %v, %v; want 1, EOF", n, err) + } +} + +func TestMultiReaderFreesExhaustedReaders(t *testing.T) { + var mr Reader + closed := make(chan struct{}) + // The closure ensures that we don't have a live reference to buf1 + // on our stack after MultiReader is inlined (Issue 18819). This + // is a work around for a limitation in liveness analysis. + func() { + buf1 := bytes.NewReader([]byte("foo")) + buf2 := bytes.NewReader([]byte("bar")) + mr = MultiReader(buf1, buf2) + runtime.SetFinalizer(buf1, func(*bytes.Reader) { + close(closed) + }) + }() + + buf := make([]byte, 4) + if n, err := ReadFull(mr, buf); err != nil || string(buf) != "foob" { + t.Fatalf(`ReadFull = %d (%q), %v; want 3, "foo", nil`, n, buf[:n], err) + } + + runtime.GC() + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for collection of buf1") + } + + if n, err := ReadFull(mr, buf[:2]); err != nil || string(buf[:2]) != "ar" { + t.Fatalf(`ReadFull = %d (%q), %v; want 2, "ar", nil`, n, buf[:n], err) + } +} + +func TestInterleavedMultiReader(t *testing.T) { + r1 := strings.NewReader("123") + r2 := strings.NewReader("45678") + + mr1 := MultiReader(r1, r2) + mr2 := MultiReader(mr1) + + buf := make([]byte, 4) + + // Have mr2 use mr1's []Readers. + // Consume r1 (and clear it for GC to handle) and consume part of r2. + n, err := ReadFull(mr2, buf) + if got := string(buf[:n]); got != "1234" || err != nil { + t.Errorf(`ReadFull(mr2) = (%q, %v), want ("1234", nil)`, got, err) + } + + // Consume the rest of r2 via mr1. + // This should not panic even though mr2 cleared r1. + n, err = ReadFull(mr1, buf) + if got := string(buf[:n]); got != "5678" || err != nil { + t.Errorf(`ReadFull(mr1) = (%q, %v), want ("5678", nil)`, got, err) + } +} diff --git a/src/io/pipe.go b/src/io/pipe.go new file mode 100644 index 0000000..f34cf25 --- /dev/null +++ b/src/io/pipe.go @@ -0,0 +1,202 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Pipe adapter to connect code expecting an io.Reader +// with code expecting an io.Writer. + +package io + +import ( + "errors" + "sync" +) + +// onceError is an object that will only store an error once. +type onceError struct { + sync.Mutex // guards following + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +// ErrClosedPipe is the error used for read or write operations on a closed pipe. +var ErrClosedPipe = errors.New("io: read/write on closed pipe") + +// A pipe is the shared pipe structure underlying PipeReader and PipeWriter. +type pipe struct { + wrMu sync.Mutex // Serializes Write operations + wrCh chan []byte + rdCh chan int + + once sync.Once // Protects closing done + done chan struct{} + rerr onceError + werr onceError +} + +func (p *pipe) read(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.readCloseError() + default: + } + + select { + case bw := <-p.wrCh: + nr := copy(b, bw) + p.rdCh <- nr + return nr, nil + case <-p.done: + return 0, p.readCloseError() + } +} + +func (p *pipe) closeRead(err error) error { + if err == nil { + err = ErrClosedPipe + } + p.rerr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +func (p *pipe) write(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.writeCloseError() + default: + p.wrMu.Lock() + defer p.wrMu.Unlock() + } + + for once := true; once || len(b) > 0; once = false { + select { + case p.wrCh <- b: + nw := <-p.rdCh + b = b[nw:] + n += nw + case <-p.done: + return n, p.writeCloseError() + } + } + return n, nil +} + +func (p *pipe) closeWrite(err error) error { + if err == nil { + err = EOF + } + p.werr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +// readCloseError is considered internal to the pipe type. +func (p *pipe) readCloseError() error { + rerr := p.rerr.Load() + if werr := p.werr.Load(); rerr == nil && werr != nil { + return werr + } + return ErrClosedPipe +} + +// writeCloseError is considered internal to the pipe type. +func (p *pipe) writeCloseError() error { + werr := p.werr.Load() + if rerr := p.rerr.Load(); werr == nil && rerr != nil { + return rerr + } + return ErrClosedPipe +} + +// A PipeReader is the read half of a pipe. +type PipeReader struct{ pipe } + +// Read implements the standard Read interface: +// it reads data from the pipe, blocking until a writer +// arrives or the write end is closed. +// If the write end is closed with an error, that error is +// returned as err; otherwise err is EOF. +func (r *PipeReader) Read(data []byte) (n int, err error) { + return r.pipe.read(data) +} + +// Close closes the reader; subsequent writes to the +// write half of the pipe will return the error [ErrClosedPipe]. +func (r *PipeReader) Close() error { + return r.CloseWithError(nil) +} + +// CloseWithError closes the reader; subsequent writes +// to the write half of the pipe will return the error err. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (r *PipeReader) CloseWithError(err error) error { + return r.pipe.closeRead(err) +} + +// A PipeWriter is the write half of a pipe. +type PipeWriter struct{ r PipeReader } + +// Write implements the standard Write interface: +// it writes data to the pipe, blocking until one or more readers +// have consumed all the data or the read end is closed. +// If the read end is closed with an error, that err is +// returned as err; otherwise err is [ErrClosedPipe]. +func (w *PipeWriter) Write(data []byte) (n int, err error) { + return w.r.pipe.write(data) +} + +// Close closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and EOF. +func (w *PipeWriter) Close() error { + return w.CloseWithError(nil) +} + +// CloseWithError closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and the error err, +// or EOF if err is nil. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (w *PipeWriter) CloseWithError(err error) error { + return w.r.pipe.closeWrite(err) +} + +// Pipe creates a synchronous in-memory pipe. +// It can be used to connect code expecting an [io.Reader] +// with code expecting an [io.Writer]. +// +// Reads and Writes on the pipe are matched one to one +// except when multiple Reads are needed to consume a single Write. +// That is, each Write to the [PipeWriter] blocks until it has satisfied +// one or more Reads from the [PipeReader] that fully consume +// the written data. +// The data is copied directly from the Write to the corresponding +// Read (or Reads); there is no internal buffering. +// +// It is safe to call Read and Write in parallel with each other or with Close. +// Parallel calls to Read and parallel calls to Write are also safe: +// the individual calls will be gated sequentially. +func Pipe() (*PipeReader, *PipeWriter) { + pw := &PipeWriter{r: PipeReader{pipe: pipe{ + wrCh: make(chan []byte), + rdCh: make(chan int), + done: make(chan struct{}), + }}} + return &pw.r, pw +} diff --git a/src/io/pipe_test.go b/src/io/pipe_test.go new file mode 100644 index 0000000..8973360 --- /dev/null +++ b/src/io/pipe_test.go @@ -0,0 +1,423 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package io_test + +import ( + "bytes" + "fmt" + . "io" + "sort" + "strings" + "testing" + "time" +) + +func checkWrite(t *testing.T, w Writer, data []byte, c chan int) { + n, err := w.Write(data) + if err != nil { + t.Errorf("write: %v", err) + } + if n != len(data) { + t.Errorf("short write: %d != %d", n, len(data)) + } + c <- 0 +} + +// Test a single read/write pair. +func TestPipe1(t *testing.T) { + c := make(chan int) + r, w := Pipe() + var buf = make([]byte, 64) + go checkWrite(t, w, []byte("hello, world"), c) + n, err := r.Read(buf) + if err != nil { + t.Errorf("read: %v", err) + } else if n != 12 || string(buf[0:12]) != "hello, world" { + t.Errorf("bad read: got %q", buf[0:n]) + } + <-c + r.Close() + w.Close() +} + +func reader(t *testing.T, r Reader, c chan int) { + var buf = make([]byte, 64) + for { + n, err := r.Read(buf) + if err == EOF { + c <- 0 + break + } + if err != nil { + t.Errorf("read: %v", err) + } + c <- n + } +} + +// Test a sequence of read/write pairs. +func TestPipe2(t *testing.T) { + c := make(chan int) + r, w := Pipe() + go reader(t, r, c) + var buf = make([]byte, 64) + for i := 0; i < 5; i++ { + p := buf[0 : 5+i*10] + n, err := w.Write(p) + if n != len(p) { + t.Errorf("wrote %d, got %d", len(p), n) + } + if err != nil { + t.Errorf("write: %v", err) + } + nn := <-c + if nn != n { + t.Errorf("wrote %d, read got %d", n, nn) + } + } + w.Close() + nn := <-c + if nn != 0 { + t.Errorf("final read got %d", nn) + } +} + +type pipeReturn struct { + n int + err error +} + +// Test a large write that requires multiple reads to satisfy. +func writer(w WriteCloser, buf []byte, c chan pipeReturn) { + n, err := w.Write(buf) + w.Close() + c <- pipeReturn{n, err} +} + +func TestPipe3(t *testing.T) { + c := make(chan pipeReturn) + r, w := Pipe() + var wdat = make([]byte, 128) + for i := 0; i < len(wdat); i++ { + wdat[i] = byte(i) + } + go writer(w, wdat, c) + var rdat = make([]byte, 1024) + tot := 0 + for n := 1; n <= 256; n *= 2 { + nn, err := r.Read(rdat[tot : tot+n]) + if err != nil && err != EOF { + t.Fatalf("read: %v", err) + } + + // only final two reads should be short - 1 byte, then 0 + expect := n + if n == 128 { + expect = 1 + } else if n == 256 { + expect = 0 + if err != EOF { + t.Fatalf("read at end: %v", err) + } + } + if nn != expect { + t.Fatalf("read %d, expected %d, got %d", n, expect, nn) + } + tot += nn + } + pr := <-c + if pr.n != 128 || pr.err != nil { + t.Fatalf("write 128: %d, %v", pr.n, pr.err) + } + if tot != 128 { + t.Fatalf("total read %d != 128", tot) + } + for i := 0; i < 128; i++ { + if rdat[i] != byte(i) { + t.Fatalf("rdat[%d] = %d", i, rdat[i]) + } + } +} + +// Test read after/before writer close. + +type closer interface { + CloseWithError(error) error + Close() error +} + +type pipeTest struct { + async bool + err error + closeWithError bool +} + +func (p pipeTest) String() string { + return fmt.Sprintf("async=%v err=%v closeWithError=%v", p.async, p.err, p.closeWithError) +} + +var pipeTests = []pipeTest{ + {true, nil, false}, + {true, nil, true}, + {true, ErrShortWrite, true}, + {false, nil, false}, + {false, nil, true}, + {false, ErrShortWrite, true}, +} + +func delayClose(t *testing.T, cl closer, ch chan int, tt pipeTest) { + time.Sleep(1 * time.Millisecond) + var err error + if tt.closeWithError { + err = cl.CloseWithError(tt.err) + } else { + err = cl.Close() + } + if err != nil { + t.Errorf("delayClose: %v", err) + } + ch <- 0 +} + +func TestPipeReadClose(t *testing.T) { + for _, tt := range pipeTests { + c := make(chan int, 1) + r, w := Pipe() + if tt.async { + go delayClose(t, w, c, tt) + } else { + delayClose(t, w, c, tt) + } + var buf = make([]byte, 64) + n, err := r.Read(buf) + <-c + want := tt.err + if want == nil { + want = EOF + } + if err != want { + t.Errorf("read from closed pipe: %v want %v", err, want) + } + if n != 0 { + t.Errorf("read on closed pipe returned %d", n) + } + if err = r.Close(); err != nil { + t.Errorf("r.Close: %v", err) + } + } +} + +// Test close on Read side during Read. +func TestPipeReadClose2(t *testing.T) { + c := make(chan int, 1) + r, _ := Pipe() + go delayClose(t, r, c, pipeTest{}) + n, err := r.Read(make([]byte, 64)) + <-c + if n != 0 || err != ErrClosedPipe { + t.Errorf("read from closed pipe: %v, %v want %v, %v", n, err, 0, ErrClosedPipe) + } +} + +// Test write after/before reader close. + +func TestPipeWriteClose(t *testing.T) { + for _, tt := range pipeTests { + c := make(chan int, 1) + r, w := Pipe() + if tt.async { + go delayClose(t, r, c, tt) + } else { + delayClose(t, r, c, tt) + } + n, err := WriteString(w, "hello, world") + <-c + expect := tt.err + if expect == nil { + expect = ErrClosedPipe + } + if err != expect { + t.Errorf("write on closed pipe: %v want %v", err, expect) + } + if n != 0 { + t.Errorf("write on closed pipe returned %d", n) + } + if err = w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + } +} + +// Test close on Write side during Write. +func TestPipeWriteClose2(t *testing.T) { + c := make(chan int, 1) + _, w := Pipe() + go delayClose(t, w, c, pipeTest{}) + n, err := w.Write(make([]byte, 64)) + <-c + if n != 0 || err != ErrClosedPipe { + t.Errorf("write to closed pipe: %v, %v want %v, %v", n, err, 0, ErrClosedPipe) + } +} + +func TestWriteEmpty(t *testing.T) { + r, w := Pipe() + go func() { + w.Write([]byte{}) + w.Close() + }() + var b [2]byte + ReadFull(r, b[0:2]) + r.Close() +} + +func TestWriteNil(t *testing.T) { + r, w := Pipe() + go func() { + w.Write(nil) + w.Close() + }() + var b [2]byte + ReadFull(r, b[0:2]) + r.Close() +} + +func TestWriteAfterWriterClose(t *testing.T) { + r, w := Pipe() + + done := make(chan bool) + var writeErr error + go func() { + _, err := w.Write([]byte("hello")) + if err != nil { + t.Errorf("got error: %q; expected none", err) + } + w.Close() + _, writeErr = w.Write([]byte("world")) + done <- true + }() + + buf := make([]byte, 100) + var result string + n, err := ReadFull(r, buf) + if err != nil && err != ErrUnexpectedEOF { + t.Fatalf("got: %q; want: %q", err, ErrUnexpectedEOF) + } + result = string(buf[0:n]) + <-done + + if result != "hello" { + t.Errorf("got: %q; want: %q", result, "hello") + } + if writeErr != ErrClosedPipe { + t.Errorf("got: %q; want: %q", writeErr, ErrClosedPipe) + } +} + +func TestPipeCloseError(t *testing.T) { + type testError1 struct{ error } + type testError2 struct{ error } + + r, w := Pipe() + r.CloseWithError(testError1{}) + if _, err := w.Write(nil); err != (testError1{}) { + t.Errorf("Write error: got %T, want testError1", err) + } + r.CloseWithError(testError2{}) + if _, err := w.Write(nil); err != (testError1{}) { + t.Errorf("Write error: got %T, want testError1", err) + } + + r, w = Pipe() + w.CloseWithError(testError1{}) + if _, err := r.Read(nil); err != (testError1{}) { + t.Errorf("Read error: got %T, want testError1", err) + } + w.CloseWithError(testError2{}) + if _, err := r.Read(nil); err != (testError1{}) { + t.Errorf("Read error: got %T, want testError1", err) + } +} + +func TestPipeConcurrent(t *testing.T) { + const ( + input = "0123456789abcdef" + count = 8 + readSize = 2 + ) + + t.Run("Write", func(t *testing.T) { + r, w := Pipe() + + for i := 0; i < count; i++ { + go func() { + time.Sleep(time.Millisecond) // Increase probability of race + if n, err := w.Write([]byte(input)); n != len(input) || err != nil { + t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input)) + } + }() + } + + buf := make([]byte, count*len(input)) + for i := 0; i < len(buf); i += readSize { + if n, err := r.Read(buf[i : i+readSize]); n != readSize || err != nil { + t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize) + } + } + + // Since each Write is fully gated, if multiple Read calls were needed, + // the contents of Write should still appear together in the output. + got := string(buf) + want := strings.Repeat(input, count) + if got != want { + t.Errorf("got: %q; want: %q", got, want) + } + }) + + t.Run("Read", func(t *testing.T) { + r, w := Pipe() + + c := make(chan []byte, count*len(input)/readSize) + for i := 0; i < cap(c); i++ { + go func() { + time.Sleep(time.Millisecond) // Increase probability of race + buf := make([]byte, readSize) + if n, err := r.Read(buf); n != readSize || err != nil { + t.Errorf("Read() = (%d, %v); want (%d, nil)", n, err, readSize) + } + c <- buf + }() + } + + for i := 0; i < count; i++ { + if n, err := w.Write([]byte(input)); n != len(input) || err != nil { + t.Errorf("Write() = (%d, %v); want (%d, nil)", n, err, len(input)) + } + } + + // Since each read is independent, the only guarantee about the output + // is that it is a permutation of the input in readSized groups. + got := make([]byte, 0, count*len(input)) + for i := 0; i < cap(c); i++ { + got = append(got, (<-c)...) + } + got = sortBytesInGroups(got, readSize) + want := bytes.Repeat([]byte(input), count) + want = sortBytesInGroups(want, readSize) + if string(got) != string(want) { + t.Errorf("got: %q; want: %q", got, want) + } + }) +} + +func sortBytesInGroups(b []byte, n int) []byte { + var groups [][]byte + for len(b) > 0 { + groups = append(groups, b[:n]) + b = b[n:] + } + sort.Slice(groups, func(i, j int) bool { return bytes.Compare(groups[i], groups[j]) < 0 }) + return bytes.Join(groups, nil) +} |