summaryrefslogtreecommitdiffstats
path: root/src/testing/iotest
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:19:13 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:19:13 +0000
commitccd992355df7192993c666236047820244914598 (patch)
treef00fea65147227b7743083c6148396f74cd66935 /src/testing/iotest
parentInitial commit. (diff)
downloadgolang-1.21-ccd992355df7192993c666236047820244914598.tar.xz
golang-1.21-ccd992355df7192993c666236047820244914598.zip
Adding upstream version 1.21.8.upstream/1.21.8
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/testing/iotest')
-rw-r--r--src/testing/iotest/example_test.go22
-rw-r--r--src/testing/iotest/logger.go54
-rw-r--r--src/testing/iotest/logger_test.go153
-rw-r--r--src/testing/iotest/reader.go268
-rw-r--r--src/testing/iotest/reader_test.go261
-rw-r--r--src/testing/iotest/writer.go35
-rw-r--r--src/testing/iotest/writer_test.go39
7 files changed, 832 insertions, 0 deletions
diff --git a/src/testing/iotest/example_test.go b/src/testing/iotest/example_test.go
new file mode 100644
index 0000000..10f6bd3
--- /dev/null
+++ b/src/testing/iotest/example_test.go
@@ -0,0 +1,22 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest_test
+
+import (
+ "errors"
+ "fmt"
+ "testing/iotest"
+)
+
+func ExampleErrReader() {
+ // A reader that always returns a custom error.
+ r := iotest.ErrReader(errors.New("custom error"))
+ n, err := r.Read(nil)
+ fmt.Printf("n: %d\nerr: %q\n", n, err)
+
+ // Output:
+ // n: 0
+ // err: "custom error"
+}
diff --git a/src/testing/iotest/logger.go b/src/testing/iotest/logger.go
new file mode 100644
index 0000000..99548dc
--- /dev/null
+++ b/src/testing/iotest/logger.go
@@ -0,0 +1,54 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "io"
+ "log"
+)
+
+type writeLogger struct {
+ prefix string
+ w io.Writer
+}
+
+func (l *writeLogger) Write(p []byte) (n int, err error) {
+ n, err = l.w.Write(p)
+ if err != nil {
+ log.Printf("%s %x: %v", l.prefix, p[0:n], err)
+ } else {
+ log.Printf("%s %x", l.prefix, p[0:n])
+ }
+ return
+}
+
+// NewWriteLogger returns a writer that behaves like w except
+// that it logs (using log.Printf) each write to standard error,
+// printing the prefix and the hexadecimal data written.
+func NewWriteLogger(prefix string, w io.Writer) io.Writer {
+ return &writeLogger{prefix, w}
+}
+
+type readLogger struct {
+ prefix string
+ r io.Reader
+}
+
+func (l *readLogger) Read(p []byte) (n int, err error) {
+ n, err = l.r.Read(p)
+ if err != nil {
+ log.Printf("%s %x: %v", l.prefix, p[0:n], err)
+ } else {
+ log.Printf("%s %x", l.prefix, p[0:n])
+ }
+ return
+}
+
+// NewReadLogger returns a reader that behaves like r except
+// that it logs (using log.Printf) each read to standard error,
+// printing the prefix and the hexadecimal data read.
+func NewReadLogger(prefix string, r io.Reader) io.Reader {
+ return &readLogger{prefix, r}
+}
diff --git a/src/testing/iotest/logger_test.go b/src/testing/iotest/logger_test.go
new file mode 100644
index 0000000..7a7d0aa
--- /dev/null
+++ b/src/testing/iotest/logger_test.go
@@ -0,0 +1,153 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log"
+ "strings"
+ "testing"
+)
+
+type errWriter struct {
+ err error
+}
+
+func (w errWriter) Write([]byte) (int, error) {
+ return 0, w.err
+}
+
+func TestWriteLogger(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(strings.Builder)
+ log.SetPrefix("lw: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ lw := new(strings.Builder)
+ wl := NewWriteLogger("write:", lw)
+ if _, err := wl.Write([]byte("Hello, World!")); err != nil {
+ t.Fatalf("Unexpectedly failed to write: %v", err)
+ }
+
+ if g, w := lw.String(), "Hello, World!"; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+ wantLogWithHex := fmt.Sprintf("lw: write: %x\n", "Hello, World!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+func TestWriteLogger_errorOnWrite(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(strings.Builder)
+ log.SetPrefix("lw: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ lw := errWriter{err: errors.New("Write Error!")}
+ wl := NewWriteLogger("write:", lw)
+ if _, err := wl.Write([]byte("Hello, World!")); err == nil {
+ t.Fatalf("Unexpectedly succeeded to write: %v", err)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lw: write: %x: %v\n", "", "Write Error!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+func TestReadLogger(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(strings.Builder)
+ log.SetPrefix("lr: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ data := []byte("Hello, World!")
+ p := make([]byte, len(data))
+ lr := bytes.NewReader(data)
+ rl := NewReadLogger("read:", lr)
+
+ n, err := rl.Read(p)
+ if err != nil {
+ t.Fatalf("Unexpectedly failed to read: %v", err)
+ }
+
+ if g, w := p[:n], data; !bytes.Equal(g, w) {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lr: read: %x\n", "Hello, World!")
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
+
+func TestReadLogger_errorOnRead(t *testing.T) {
+ olw := log.Writer()
+ olf := log.Flags()
+ olp := log.Prefix()
+
+ // Revert the original log settings before we exit.
+ defer func() {
+ log.SetFlags(olf)
+ log.SetPrefix(olp)
+ log.SetOutput(olw)
+ }()
+
+ lOut := new(strings.Builder)
+ log.SetPrefix("lr: ")
+ log.SetOutput(lOut)
+ log.SetFlags(0)
+
+ data := []byte("Hello, World!")
+ p := make([]byte, len(data))
+
+ lr := ErrReader(errors.New("io failure"))
+ rl := NewReadLogger("read", lr)
+ n, err := rl.Read(p)
+ if err == nil {
+ t.Fatalf("Unexpectedly succeeded to read: %v", err)
+ }
+
+ wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n])
+ if g, w := lOut.String(), wantLogWithHex; g != w {
+ t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w)
+ }
+}
diff --git a/src/testing/iotest/reader.go b/src/testing/iotest/reader.go
new file mode 100644
index 0000000..770d87f
--- /dev/null
+++ b/src/testing/iotest/reader.go
@@ -0,0 +1,268 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package iotest implements Readers and Writers useful mainly for testing.
+package iotest
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// OneByteReader returns a Reader that implements
+// each non-empty Read by reading one byte from r.
+func OneByteReader(r io.Reader) io.Reader { return &oneByteReader{r} }
+
+type oneByteReader struct {
+ r io.Reader
+}
+
+func (r *oneByteReader) Read(p []byte) (int, error) {
+ if len(p) == 0 {
+ return 0, nil
+ }
+ return r.r.Read(p[0:1])
+}
+
+// HalfReader returns a Reader that implements Read
+// by reading half as many requested bytes from r.
+func HalfReader(r io.Reader) io.Reader { return &halfReader{r} }
+
+type halfReader struct {
+ r io.Reader
+}
+
+func (r *halfReader) Read(p []byte) (int, error) {
+ return r.r.Read(p[0 : (len(p)+1)/2])
+}
+
+// DataErrReader changes the way errors are handled by a Reader. Normally, a
+// Reader returns an error (typically EOF) from the first Read call after the
+// last piece of data is read. DataErrReader wraps a Reader and changes its
+// behavior so the final error is returned along with the final data, instead
+// of in the first call after the final data.
+func DataErrReader(r io.Reader) io.Reader { return &dataErrReader{r, nil, make([]byte, 1024)} }
+
+type dataErrReader struct {
+ r io.Reader
+ unread []byte
+ data []byte
+}
+
+func (r *dataErrReader) Read(p []byte) (n int, err error) {
+ // loop because first call needs two reads:
+ // one to get data and a second to look for an error.
+ for {
+ if len(r.unread) == 0 {
+ n1, err1 := r.r.Read(r.data)
+ r.unread = r.data[0:n1]
+ err = err1
+ }
+ if n > 0 || err != nil {
+ break
+ }
+ n = copy(p, r.unread)
+ r.unread = r.unread[n:]
+ }
+ return
+}
+
+// ErrTimeout is a fake timeout error.
+var ErrTimeout = errors.New("timeout")
+
+// TimeoutReader returns ErrTimeout on the second read
+// with no data. Subsequent calls to read succeed.
+func TimeoutReader(r io.Reader) io.Reader { return &timeoutReader{r, 0} }
+
+type timeoutReader struct {
+ r io.Reader
+ count int
+}
+
+func (r *timeoutReader) Read(p []byte) (int, error) {
+ r.count++
+ if r.count == 2 {
+ return 0, ErrTimeout
+ }
+ return r.r.Read(p)
+}
+
+// ErrReader returns an io.Reader that returns 0, err from all Read calls.
+func ErrReader(err error) io.Reader {
+ return &errReader{err: err}
+}
+
+type errReader struct {
+ err error
+}
+
+func (r *errReader) Read(p []byte) (int, error) {
+ return 0, r.err
+}
+
+type smallByteReader struct {
+ r io.Reader
+ off int
+ n int
+}
+
+func (r *smallByteReader) Read(p []byte) (int, error) {
+ if len(p) == 0 {
+ return 0, nil
+ }
+ r.n = r.n%3 + 1
+ n := r.n
+ if n > len(p) {
+ n = len(p)
+ }
+ n, err := r.r.Read(p[0:n])
+ if err != nil && err != io.EOF {
+ err = fmt.Errorf("Read(%d bytes at offset %d): %v", n, r.off, err)
+ }
+ r.off += n
+ return n, err
+}
+
+// TestReader tests that reading from r returns the expected file content.
+// It does reads of different sizes, until EOF.
+// If r implements io.ReaderAt or io.Seeker, TestReader also checks
+// that those operations behave as they should.
+//
+// If TestReader finds any misbehaviors, it returns an error reporting them.
+// The error text may span multiple lines.
+func TestReader(r io.Reader, content []byte) error {
+ if len(content) > 0 {
+ n, err := r.Read(nil)
+ if n != 0 || err != nil {
+ return fmt.Errorf("Read(0) = %d, %v, want 0, nil", n, err)
+ }
+ }
+
+ data, err := io.ReadAll(&smallByteReader{r: r})
+ if err != nil {
+ return err
+ }
+ if !bytes.Equal(data, content) {
+ return fmt.Errorf("ReadAll(small amounts) = %q\n\twant %q", data, content)
+ }
+ n, err := r.Read(make([]byte, 10))
+ if n != 0 || err != io.EOF {
+ return fmt.Errorf("Read(10) at EOF = %v, %v, want 0, EOF", n, err)
+ }
+
+ if r, ok := r.(io.ReadSeeker); ok {
+ // Seek(0, 1) should report the current file position (EOF).
+ if off, err := r.Seek(0, 1); off != int64(len(content)) || err != nil {
+ return fmt.Errorf("Seek(0, 1) from EOF = %d, %v, want %d, nil", off, err, len(content))
+ }
+
+ // Seek backward partway through file, in two steps.
+ // If middle == 0, len(content) == 0, can't use the -1 and +1 seeks.
+ middle := len(content) - len(content)/3
+ if middle > 0 {
+ if off, err := r.Seek(-1, 1); off != int64(len(content)-1) || err != nil {
+ return fmt.Errorf("Seek(-1, 1) from EOF = %d, %v, want %d, nil", -off, err, len(content)-1)
+ }
+ if off, err := r.Seek(int64(-len(content)/3), 1); off != int64(middle-1) || err != nil {
+ return fmt.Errorf("Seek(%d, 1) from %d = %d, %v, want %d, nil", -len(content)/3, len(content)-1, off, err, middle-1)
+ }
+ if off, err := r.Seek(+1, 1); off != int64(middle) || err != nil {
+ return fmt.Errorf("Seek(+1, 1) from %d = %d, %v, want %d, nil", middle-1, off, err, middle)
+ }
+ }
+
+ // Seek(0, 1) should report the current file position (middle).
+ if off, err := r.Seek(0, 1); off != int64(middle) || err != nil {
+ return fmt.Errorf("Seek(0, 1) from %d = %d, %v, want %d, nil", middle, off, err, middle)
+ }
+
+ // Reading forward should return the last part of the file.
+ data, err := io.ReadAll(&smallByteReader{r: r})
+ if err != nil {
+ return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
+ }
+ if !bytes.Equal(data, content[middle:]) {
+ return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
+ }
+
+ // Seek relative to end of file, but start elsewhere.
+ if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
+ return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
+ }
+ if off, err := r.Seek(int64(-len(content)/3), 2); off != int64(middle) || err != nil {
+ return fmt.Errorf("Seek(%d, 2) from %d = %d, %v, want %d, nil", -len(content)/3, middle/2, off, err, middle)
+ }
+
+ // Reading forward should return the last part of the file (again).
+ data, err = io.ReadAll(&smallByteReader{r: r})
+ if err != nil {
+ return fmt.Errorf("ReadAll from offset %d: %v", middle, err)
+ }
+ if !bytes.Equal(data, content[middle:]) {
+ return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:])
+ }
+
+ // Absolute seek & read forward.
+ if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil {
+ return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2)
+ }
+ data, err = io.ReadAll(r)
+ if err != nil {
+ return fmt.Errorf("ReadAll from offset %d: %v", middle/2, err)
+ }
+ if !bytes.Equal(data, content[middle/2:]) {
+ return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle/2, data, content[middle/2:])
+ }
+ }
+
+ if r, ok := r.(io.ReaderAt); ok {
+ data := make([]byte, len(content), len(content)+1)
+ for i := range data {
+ data[i] = 0xfe
+ }
+ n, err := r.ReadAt(data, 0)
+ if n != len(data) || err != nil && err != io.EOF {
+ return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, nil or EOF", len(data), n, err, len(data))
+ }
+ if !bytes.Equal(data, content) {
+ return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
+ }
+
+ n, err = r.ReadAt(data[:1], int64(len(data)))
+ if n != 0 || err != io.EOF {
+ return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 0, EOF", len(data), n, err)
+ }
+
+ for i := range data {
+ data[i] = 0xfe
+ }
+ n, err = r.ReadAt(data[:cap(data)], 0)
+ if n != len(data) || err != io.EOF {
+ return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, EOF", cap(data), n, err, len(data))
+ }
+ if !bytes.Equal(data, content) {
+ return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content)
+ }
+
+ for i := range data {
+ data[i] = 0xfe
+ }
+ for i := range data {
+ n, err = r.ReadAt(data[i:i+1], int64(i))
+ if n != 1 || err != nil && (i != len(data)-1 || err != io.EOF) {
+ want := "nil"
+ if i == len(data)-1 {
+ want = "nil or EOF"
+ }
+ return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 1, %s", i, n, err, want)
+ }
+ if data[i] != content[i] {
+ return fmt.Errorf("ReadAt(1, %d) = %q want %q", i, data[i:i+1], content[i:i+1])
+ }
+ }
+ }
+ return nil
+}
diff --git a/src/testing/iotest/reader_test.go b/src/testing/iotest/reader_test.go
new file mode 100644
index 0000000..1d22237
--- /dev/null
+++ b/src/testing/iotest/reader_test.go
@@ -0,0 +1,261 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "strings"
+ "testing"
+)
+
+func TestOneByteReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+
+ obr := OneByteReader(buf)
+ var b []byte
+ n, err := obr.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 3)
+ // Read from obr until EOF.
+ got := new(strings.Builder)
+ for i := 0; ; i++ {
+ n, err = obr.Read(b)
+ if err != nil {
+ break
+ }
+ if g, w := n, 1; g != w {
+ t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w)
+ }
+ got.Write(b[:n])
+ }
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestOneByteReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ obr := OneByteReader(r)
+ var b []byte
+ if n, err := obr.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := obr.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestHalfReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+ // empty read buffer
+ hr := HalfReader(buf)
+ var b []byte
+ n, err := hr.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // non empty read buffer
+ b = make([]byte, 2)
+ got := new(strings.Builder)
+ for i := 0; ; i++ {
+ n, err = hr.Read(b)
+ if err != nil {
+ break
+ }
+ if g, w := n, 1; g != w {
+ t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w)
+ }
+ got.Write(b[:n])
+ }
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestHalfReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ hr := HalfReader(r)
+ var b []byte
+ if n, err := hr.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := hr.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestTimeOutReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+ // empty read buffer
+ tor := TimeoutReader(buf)
+ var b []byte
+ n, err := tor.Read(b)
+ if err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+ // non empty read buffer
+ tor2 := TimeoutReader(buf)
+ b = make([]byte, 3)
+ if n, err := tor2.Read(b); err != nil || n == 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor2.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestTimeOutReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+ // empty read buffer
+ tor := TimeoutReader(r)
+ var b []byte
+ if n, err := tor.Read(b); err != nil || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err := tor.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+ // non empty read buffer
+ tor2 := TimeoutReader(r)
+ b = make([]byte, 5)
+ if n, err := tor2.Read(b); err != io.EOF || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+ // Second call should timeout
+ n, err = tor2.Read(b)
+ if g, w := err, ErrTimeout; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestDataErrReader_nonEmptyReader(t *testing.T) {
+ msg := "Hello, World!"
+ buf := new(bytes.Buffer)
+ buf.WriteString(msg)
+
+ der := DataErrReader(buf)
+
+ b := make([]byte, 3)
+ got := new(strings.Builder)
+ var n int
+ var err error
+ for {
+ n, err = der.Read(b)
+ got.Write(b[:n])
+ if err != nil {
+ break
+ }
+ }
+ if err != io.EOF || n == 0 {
+ t.Errorf("Last Read returned n=%d err=%v", n, err)
+ }
+ if g, w := got.String(), "Hello, World!"; g != w {
+ t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w)
+ }
+}
+
+func TestDataErrReader_emptyReader(t *testing.T) {
+ r := new(bytes.Buffer)
+
+ der := DataErrReader(r)
+ var b []byte
+ if n, err := der.Read(b); err != io.EOF || n != 0 {
+ t.Errorf("Empty buffer read returned n=%d err=%v", n, err)
+ }
+
+ b = make([]byte, 5)
+ n, err := der.Read(b)
+ if g, w := err, io.EOF; g != w {
+ t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w)
+ }
+ if g, w := n, 0; g != w {
+ t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w)
+ }
+}
+
+func TestErrReader(t *testing.T) {
+ cases := []struct {
+ name string
+ err error
+ }{
+ {"nil error", nil},
+ {"non-nil error", errors.New("io failure")},
+ {"io.EOF", io.EOF},
+ }
+
+ for _, tt := range cases {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ n, err := ErrReader(tt.err).Read(nil)
+ if err != tt.err {
+ t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, tt.err)
+ }
+ if n != 0 {
+ t.Fatalf("Byte count mismatch: got %d want 0", n)
+ }
+ })
+ }
+}
+
+func TestStringsReader(t *testing.T) {
+ const msg = "Now is the time for all good gophers."
+
+ r := strings.NewReader(msg)
+ if err := TestReader(r, []byte(msg)); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/src/testing/iotest/writer.go b/src/testing/iotest/writer.go
new file mode 100644
index 0000000..af61ab8
--- /dev/null
+++ b/src/testing/iotest/writer.go
@@ -0,0 +1,35 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import "io"
+
+// TruncateWriter returns a Writer that writes to w
+// but stops silently after n bytes.
+func TruncateWriter(w io.Writer, n int64) io.Writer {
+ return &truncateWriter{w, n}
+}
+
+type truncateWriter struct {
+ w io.Writer
+ n int64
+}
+
+func (t *truncateWriter) Write(p []byte) (n int, err error) {
+ if t.n <= 0 {
+ return len(p), nil
+ }
+ // real write
+ n = len(p)
+ if int64(n) > t.n {
+ n = int(t.n)
+ }
+ n, err = t.w.Write(p[0:n])
+ t.n -= int64(n)
+ if err == nil {
+ n = len(p)
+ }
+ return
+}
diff --git a/src/testing/iotest/writer_test.go b/src/testing/iotest/writer_test.go
new file mode 100644
index 0000000..2762513
--- /dev/null
+++ b/src/testing/iotest/writer_test.go
@@ -0,0 +1,39 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package iotest
+
+import (
+ "strings"
+ "testing"
+)
+
+var truncateWriterTests = []struct {
+ in string
+ want string
+ trunc int64
+ n int
+}{
+ {"hello", "", -1, 5},
+ {"world", "", 0, 5},
+ {"abcde", "abc", 3, 5},
+ {"edcba", "edcba", 7, 5},
+}
+
+func TestTruncateWriter(t *testing.T) {
+ for _, tt := range truncateWriterTests {
+ buf := new(strings.Builder)
+ tw := TruncateWriter(buf, tt.trunc)
+ n, err := tw.Write([]byte(tt.in))
+ if err != nil {
+ t.Errorf("Unexpected error %v for\n\t%+v", err, tt)
+ }
+ if g, w := buf.String(), tt.want; g != w {
+ t.Errorf("got %q, expected %q", g, w)
+ }
+ if g, w := n, tt.n; g != w {
+ t.Errorf("read %d bytes, but expected to have read %d bytes for\n\t%+v", g, w, tt)
+ }
+ }
+}