diff options
Diffstat (limited to 'src/os/readfrom_linux_test.go')
-rw-r--r-- | src/os/readfrom_linux_test.go | 822 |
1 files changed, 822 insertions, 0 deletions
diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go new file mode 100644 index 0000000..4f98be4 --- /dev/null +++ b/src/os/readfrom_linux_test.go @@ -0,0 +1,822 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package os_test + +import ( + "bytes" + "errors" + "internal/poll" + "internal/testpty" + "io" + "math/rand" + "net" + . "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/net/nettest" +) + +func TestCopyFileRange(t *testing.T) { + sizes := []int{ + 1, + 42, + 1025, + syscall.Getpagesize() + 1, + 32769, + } + t.Run("Basic", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testCopyFileRange(t, int64(size), -1) + }) + } + }) + t.Run("Limited", func(t *testing.T) { + t.Run("OneLess", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testCopyFileRange(t, int64(size), int64(size)-1) + }) + } + }) + t.Run("Half", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testCopyFileRange(t, int64(size), int64(size)/2) + }) + } + }) + t.Run("More", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testCopyFileRange(t, int64(size), int64(size)+7) + }) + } + }) + }) + t.Run("DoesntTryInAppendMode", func(t *testing.T) { + dst, src, data, hook := newCopyFileRangeTest(t, 42) + + dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755) + if err != nil { + t.Fatal(err) + } + defer dst2.Close() + + if _, err := io.Copy(dst2, src); err != nil { + t.Fatal(err) + } + if hook.called { + t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode") + } + mustSeekStart(t, dst2) + mustContainData(t, dst2, data) // through traditional means + }) + t.Run("CopyFileItself", func(t *testing.T) { + hook := hookCopyFileRange(t) + + f, err := CreateTemp("", "file-readfrom-itself-test") + if err != nil { + t.Fatalf("failed to create tmp file: %v", err) + } + t.Cleanup(func() { + f.Close() + Remove(f.Name()) + }) + + data := []byte("hello world!") + if _, err := f.Write(data); err != nil { + t.Fatalf("failed to create and feed the file: %v", err) + } + + if err := f.Sync(); err != nil { + t.Fatalf("failed to save the file: %v", err) + } + + // Rewind it. + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("failed to rewind the file: %v", err) + } + + // Read data from the file itself. + if _, err := io.Copy(f, f); err != nil { + t.Fatalf("failed to read from the file: %v", err) + } + + if !hook.called || hook.written != 0 || hook.handled || hook.err != nil { + t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err) + } + + // Rewind it. + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("failed to rewind the file: %v", err) + } + + data2, err := io.ReadAll(f) + if err != nil { + t.Fatalf("failed to read from the file: %v", err) + } + + // It should wind up a double of the original data. + if strings.Repeat(string(data), 2) != string(data2) { + t.Fatalf("data mismatch: %s != %s", string(data), string(data2)) + } + }) + t.Run("NotRegular", func(t *testing.T) { + t.Run("BothPipes", func(t *testing.T) { + hook := hookCopyFileRange(t) + + pr1, pw1, err := Pipe() + if err != nil { + t.Fatal(err) + } + defer pr1.Close() + defer pw1.Close() + + pr2, pw2, err := Pipe() + if err != nil { + t.Fatal(err) + } + defer pr2.Close() + defer pw2.Close() + + // The pipe is empty, and PIPE_BUF is large enough + // for this, by (POSIX) definition, so there is no + // need for an additional goroutine. + data := []byte("hello") + if _, err := pw1.Write(data); err != nil { + t.Fatal(err) + } + pw1.Close() + + n, err := io.Copy(pw2, pr1) + if err != nil { + t.Fatal(err) + } + if n != int64(len(data)) { + t.Fatalf("transferred %d, want %d", n, len(data)) + } + if !hook.called { + t.Fatalf("should have called poll.CopyFileRange") + } + pw2.Close() + mustContainData(t, pr2, data) + }) + t.Run("DstPipe", func(t *testing.T) { + dst, src, data, hook := newCopyFileRangeTest(t, 255) + dst.Close() + + pr, pw, err := Pipe() + if err != nil { + t.Fatal(err) + } + defer pr.Close() + defer pw.Close() + + n, err := io.Copy(pw, src) + if err != nil { + t.Fatal(err) + } + if n != int64(len(data)) { + t.Fatalf("transferred %d, want %d", n, len(data)) + } + if !hook.called { + t.Fatalf("should have called poll.CopyFileRange") + } + pw.Close() + mustContainData(t, pr, data) + }) + t.Run("SrcPipe", func(t *testing.T) { + dst, src, data, hook := newCopyFileRangeTest(t, 255) + src.Close() + + pr, pw, err := Pipe() + if err != nil { + t.Fatal(err) + } + defer pr.Close() + defer pw.Close() + + // The pipe is empty, and PIPE_BUF is large enough + // for this, by (POSIX) definition, so there is no + // need for an additional goroutine. + if _, err := pw.Write(data); err != nil { + t.Fatal(err) + } + pw.Close() + + n, err := io.Copy(dst, pr) + if err != nil { + t.Fatal(err) + } + if n != int64(len(data)) { + t.Fatalf("transferred %d, want %d", n, len(data)) + } + if !hook.called { + t.Fatalf("should have called poll.CopyFileRange") + } + mustSeekStart(t, dst) + mustContainData(t, dst, data) + }) + }) + t.Run("Nil", func(t *testing.T) { + var nilFile *File + anyFile, err := CreateTemp("", "") + if err != nil { + t.Fatal(err) + } + defer Remove(anyFile.Name()) + defer anyFile.Close() + + if _, err := io.Copy(nilFile, nilFile); err != ErrInvalid { + t.Errorf("io.Copy(nilFile, nilFile) = %v, want %v", err, ErrInvalid) + } + if _, err := io.Copy(anyFile, nilFile); err != ErrInvalid { + t.Errorf("io.Copy(anyFile, nilFile) = %v, want %v", err, ErrInvalid) + } + if _, err := io.Copy(nilFile, anyFile); err != ErrInvalid { + t.Errorf("io.Copy(nilFile, anyFile) = %v, want %v", err, ErrInvalid) + } + + if _, err := nilFile.ReadFrom(nilFile); err != ErrInvalid { + t.Errorf("nilFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) + } + if _, err := anyFile.ReadFrom(nilFile); err != ErrInvalid { + t.Errorf("anyFile.ReadFrom(nilFile) = %v, want %v", err, ErrInvalid) + } + if _, err := nilFile.ReadFrom(anyFile); err != ErrInvalid { + t.Errorf("nilFile.ReadFrom(anyFile) = %v, want %v", err, ErrInvalid) + } + }) +} + +func TestSpliceFile(t *testing.T) { + sizes := []int{ + 1, + 42, + 1025, + syscall.Getpagesize() + 1, + 32769, + } + t.Run("Basic-TCP", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "tcp", int64(size), -1) + }) + } + }) + t.Run("Basic-Unix", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "unix", int64(size), -1) + }) + } + }) + t.Run("TCP-To-TTY", func(t *testing.T) { + testSpliceToTTY(t, "tcp", 32768) + }) + t.Run("Unix-To-TTY", func(t *testing.T) { + testSpliceToTTY(t, "unix", 32768) + }) + t.Run("Limited", func(t *testing.T) { + t.Run("OneLess-TCP", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "tcp", int64(size), int64(size)-1) + }) + } + }) + t.Run("OneLess-Unix", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "unix", int64(size), int64(size)-1) + }) + } + }) + t.Run("Half-TCP", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "tcp", int64(size), int64(size)/2) + }) + } + }) + t.Run("Half-Unix", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "unix", int64(size), int64(size)/2) + }) + } + }) + t.Run("More-TCP", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "tcp", int64(size), int64(size)+1) + }) + } + }) + t.Run("More-Unix", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSpliceFile(t, "unix", int64(size), int64(size)+1) + }) + } + }) + }) +} + +func testSpliceFile(t *testing.T, proto string, size, limit int64) { + dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size) + defer cleanup() + + // If we have a limit, wrap the reader. + var ( + r io.Reader + lr *io.LimitedReader + ) + if limit >= 0 { + lr = &io.LimitedReader{N: limit, R: src} + r = lr + if limit < int64(len(data)) { + data = data[:limit] + } + } else { + r = src + } + // Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice + n, err := io.Copy(dst, r) + if err != nil { + t.Fatal(err) + } + + // We should have called poll.Splice with the right file descriptor arguments. + if n > 0 && !hook.called { + t.Fatal("expected to called poll.Splice") + } + if hook.called && hook.dstfd != int(dst.Fd()) { + t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd()) + } + sc, ok := src.(syscall.Conn) + if !ok { + t.Fatalf("server Conn is not a syscall.Conn") + } + rc, err := sc.SyscallConn() + if err != nil { + t.Fatalf("server Conn SyscallConn error: %v", err) + } + if err = rc.Control(func(fd uintptr) { + if hook.called && hook.srcfd != int(fd) { + t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd)) + } + }); err != nil { + t.Fatalf("server Conn Control error: %v", err) + } + + // Check that the offsets after the transfer make sense, that the size + // of the transfer was reported correctly, and that the destination + // file contains exactly the bytes we expect it to contain. + dstoff, err := dst.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + if dstoff != int64(len(data)) { + t.Errorf("dstoff = %d, want %d", dstoff, len(data)) + } + if n != int64(len(data)) { + t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data)) + } + mustSeekStart(t, dst) + mustContainData(t, dst, data) + + // If we had a limit, check that it was updated. + if lr != nil { + if want := limit - n; lr.N != want { + t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want) + } + } +} + +// Issue #59041. +func testSpliceToTTY(t *testing.T, proto string, size int64) { + var wg sync.WaitGroup + + // Call wg.Wait as the final deferred function, + // because the goroutines may block until some of + // the deferred Close calls. + defer wg.Wait() + + pty, ttyName, err := testpty.Open() + if err != nil { + t.Skipf("skipping test because pty open failed: %v", err) + } + defer pty.Close() + + // Open the tty directly, rather than via OpenFile. + // This bypasses the non-blocking support and is required + // to recreate the problem in the issue (#59041). + ttyFD, err := syscall.Open(ttyName, syscall.O_RDWR, 0) + if err != nil { + t.Skipf("skipping test becaused failed to open tty: %v", err) + } + defer syscall.Close(ttyFD) + + tty := NewFile(uintptr(ttyFD), "tty") + defer tty.Close() + + client, server := createSocketPair(t, proto) + + data := bytes.Repeat([]byte{'a'}, int(size)) + + wg.Add(1) + go func() { + defer wg.Done() + // The problem (issue #59041) occurs when writing + // a series of blocks of data. It does not occur + // when all the data is written at once. + for i := 0; i < len(data); i += 1024 { + if _, err := client.Write(data[i : i+1024]); err != nil { + // If we get here because the client was + // closed, skip the error. + if !errors.Is(err, net.ErrClosed) { + t.Errorf("error writing to socket: %v", err) + } + return + } + } + client.Close() + }() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 32) + for { + if _, err := pty.Read(buf); err != nil { + if err != io.EOF && !errors.Is(err, ErrClosed) { + // An error here doesn't matter for + // our test. + t.Logf("error reading from pty: %v", err) + } + return + } + } + }() + + // Close Client to wake up the writing goroutine if necessary. + defer client.Close() + + _, err = io.Copy(tty, server) + if err != nil { + t.Fatal(err) + } +} + +func testCopyFileRange(t *testing.T, size int64, limit int64) { + dst, src, data, hook := newCopyFileRangeTest(t, size) + + // If we have a limit, wrap the reader. + var ( + realsrc io.Reader + lr *io.LimitedReader + ) + if limit >= 0 { + lr = &io.LimitedReader{N: limit, R: src} + realsrc = lr + if limit < int64(len(data)) { + data = data[:limit] + } + } else { + realsrc = src + } + + // Now call ReadFrom (through io.Copy), which will hopefully call + // poll.CopyFileRange. + n, err := io.Copy(dst, realsrc) + if err != nil { + t.Fatal(err) + } + + // If we didn't have a limit, we should have called poll.CopyFileRange + // with the right file descriptor arguments. + if limit > 0 && !hook.called { + t.Fatal("never called poll.CopyFileRange") + } + if hook.called && hook.dstfd != int(dst.Fd()) { + t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd()) + } + if hook.called && hook.srcfd != int(src.Fd()) { + t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) + } + + // Check that the offsets after the transfer make sense, that the size + // of the transfer was reported correctly, and that the destination + // file contains exactly the bytes we expect it to contain. + dstoff, err := dst.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + srcoff, err := src.Seek(0, io.SeekCurrent) + if err != nil { + t.Fatal(err) + } + if dstoff != srcoff { + t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff) + } + if dstoff != int64(len(data)) { + t.Errorf("dstoff = %d, want %d", dstoff, len(data)) + } + if n != int64(len(data)) { + t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data)) + } + mustSeekStart(t, dst) + mustContainData(t, dst, data) + + // If we had a limit, check that it was updated. + if lr != nil { + if want := limit - n; lr.N != want { + t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want) + } + } +} + +// newCopyFileRangeTest initializes a new test for copy_file_range. +// +// It creates source and destination files, and populates the source file +// with random data of the specified size. It also hooks package os' call +// to poll.CopyFileRange and returns the hook so it can be inspected. +func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) { + t.Helper() + + hook = hookCopyFileRange(t) + tmp := t.TempDir() + + src, err := Create(filepath.Join(tmp, "src")) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { src.Close() }) + + dst, err = Create(filepath.Join(tmp, "dst")) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { dst.Close() }) + + // Populate the source file with data, then rewind it, so it can be + // consumed by copy_file_range(2). + prng := rand.New(rand.NewSource(time.Now().Unix())) + data = make([]byte, size) + prng.Read(data) + if _, err := src.Write(data); err != nil { + t.Fatal(err) + } + if _, err := src.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + + return dst, src, data, hook +} + +// newSpliceFileTest initializes a new test for splice. +// +// It creates source sockets and destination file, and populates the source sockets +// with random data of the specified size. It also hooks package os' call +// to poll.Splice and returns the hook so it can be inspected. +func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) { + t.Helper() + + hook := hookSpliceFile(t) + + client, server := createSocketPair(t, proto) + + dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { dst.Close() }) + + randSeed := time.Now().Unix() + t.Logf("random data seed: %d\n", randSeed) + prng := rand.New(rand.NewSource(randSeed)) + data := make([]byte, size) + prng.Read(data) + + done := make(chan struct{}) + go func() { + client.Write(data) + client.Close() + close(done) + }() + + return dst, server, data, hook, func() { <-done } +} + +// mustContainData ensures that the specified file contains exactly the +// specified data. +func mustContainData(t *testing.T, f *File, data []byte) { + t.Helper() + + got := make([]byte, len(data)) + if _, err := io.ReadFull(f, got); err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, data) { + t.Fatalf("didn't get the same data back from %s", f.Name()) + } + if _, err := f.Read(make([]byte, 1)); err != io.EOF { + t.Fatalf("not at EOF") + } +} + +func mustSeekStart(t *testing.T, f *File) { + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } +} + +func hookCopyFileRange(t *testing.T) *copyFileRangeHook { + h := new(copyFileRangeHook) + h.install() + t.Cleanup(h.uninstall) + return h +} + +type copyFileRangeHook struct { + called bool + dstfd int + srcfd int + remain int64 + + written int64 + handled bool + err error + + original func(dst, src *poll.FD, remain int64) (int64, bool, error) +} + +func (h *copyFileRangeHook) install() { + h.original = *PollCopyFileRangeP + *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) { + h.called = true + h.dstfd = dst.Sysfd + h.srcfd = src.Sysfd + h.remain = remain + h.written, h.handled, h.err = h.original(dst, src, remain) + return h.written, h.handled, h.err + } +} + +func (h *copyFileRangeHook) uninstall() { + *PollCopyFileRangeP = h.original +} + +func hookSpliceFile(t *testing.T) *spliceFileHook { + h := new(spliceFileHook) + h.install() + t.Cleanup(h.uninstall) + return h +} + +type spliceFileHook struct { + called bool + dstfd int + srcfd int + remain int64 + + written int64 + handled bool + sc string + err error + + original func(dst, src *poll.FD, remain int64) (int64, bool, string, error) +} + +func (h *spliceFileHook) install() { + h.original = *PollSpliceFile + *PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) { + h.called = true + h.dstfd = dst.Sysfd + h.srcfd = src.Sysfd + h.remain = remain + h.written, h.handled, h.sc, h.err = h.original(dst, src, remain) + return h.written, h.handled, h.sc, h.err + } +} + +func (h *spliceFileHook) uninstall() { + *PollSpliceFile = h.original +} + +// On some kernels copy_file_range fails on files in /proc. +func TestProcCopy(t *testing.T) { + t.Parallel() + + const cmdlineFile = "/proc/self/cmdline" + cmdline, err := ReadFile(cmdlineFile) + if err != nil { + t.Skipf("can't read /proc file: %v", err) + } + in, err := Open(cmdlineFile) + if err != nil { + t.Fatal(err) + } + defer in.Close() + outFile := filepath.Join(t.TempDir(), "cmdline") + out, err := Create(outFile) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(out, in); err != nil { + t.Fatal(err) + } + if err := out.Close(); err != nil { + t.Fatal(err) + } + copy, err := ReadFile(outFile) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(cmdline, copy) { + t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline) + } +} + +func TestGetPollFDFromReader(t *testing.T) { + t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") }) + t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") }) +} + +func testGetPollFromReader(t *testing.T, proto string) { + _, server := createSocketPair(t, proto) + sc, ok := server.(syscall.Conn) + if !ok { + t.Fatalf("server Conn is not a syscall.Conn") + } + rc, err := sc.SyscallConn() + if err != nil { + t.Fatalf("server SyscallConn error: %v", err) + } + if err = rc.Control(func(fd uintptr) { + pfd := GetPollFDForTest(server) + if pfd == nil { + t.Fatalf("GetPollFDForTest didn't return poll.FD") + } + if pfd.Sysfd != int(fd) { + t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd)) + } + if !pfd.IsStream { + t.Fatalf("expected IsStream to be true") + } + if err = pfd.Init(proto, true); err == nil { + t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error") + } + }); err != nil { + t.Fatalf("server Control error: %v", err) + } +} + +func createSocketPair(t *testing.T, proto string) (client, server net.Conn) { + t.Helper() + if !nettest.TestableNetwork(proto) { + t.Skipf("%s does not support %q", runtime.GOOS, proto) + } + + ln, err := nettest.NewLocalListener(proto) + if err != nil { + t.Fatalf("NewLocalListener error: %v", err) + } + t.Cleanup(func() { + if ln != nil { + ln.Close() + } + if client != nil { + client.Close() + } + if server != nil { + server.Close() + } + }) + ch := make(chan struct{}) + go func() { + var err error + server, err = ln.Accept() + if err != nil { + t.Errorf("Accept new connection error: %v", err) + } + ch <- struct{}{} + }() + client, err = net.Dial(proto, ln.Addr().String()) + <-ch + if err != nil { + t.Fatalf("Dial new connection error: %v", err) + } + return client, server +} |