diff options
Diffstat (limited to 'src/testing/iotest')
-rw-r--r-- | src/testing/iotest/example_test.go | 22 | ||||
-rw-r--r-- | src/testing/iotest/logger.go | 54 | ||||
-rw-r--r-- | src/testing/iotest/logger_test.go | 153 | ||||
-rw-r--r-- | src/testing/iotest/reader.go | 268 | ||||
-rw-r--r-- | src/testing/iotest/reader_test.go | 261 | ||||
-rw-r--r-- | src/testing/iotest/writer.go | 35 | ||||
-rw-r--r-- | src/testing/iotest/writer_test.go | 39 |
7 files changed, 832 insertions, 0 deletions
diff --git a/src/testing/iotest/example_test.go b/src/testing/iotest/example_test.go new file mode 100644 index 0000000..10f6bd3 --- /dev/null +++ b/src/testing/iotest/example_test.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest_test + +import ( + "errors" + "fmt" + "testing/iotest" +) + +func ExampleErrReader() { + // A reader that always returns a custom error. + r := iotest.ErrReader(errors.New("custom error")) + n, err := r.Read(nil) + fmt.Printf("n: %d\nerr: %q\n", n, err) + + // Output: + // n: 0 + // err: "custom error" +} diff --git a/src/testing/iotest/logger.go b/src/testing/iotest/logger.go new file mode 100644 index 0000000..99548dc --- /dev/null +++ b/src/testing/iotest/logger.go @@ -0,0 +1,54 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest + +import ( + "io" + "log" +) + +type writeLogger struct { + prefix string + w io.Writer +} + +func (l *writeLogger) Write(p []byte) (n int, err error) { + n, err = l.w.Write(p) + if err != nil { + log.Printf("%s %x: %v", l.prefix, p[0:n], err) + } else { + log.Printf("%s %x", l.prefix, p[0:n]) + } + return +} + +// NewWriteLogger returns a writer that behaves like w except +// that it logs (using log.Printf) each write to standard error, +// printing the prefix and the hexadecimal data written. +func NewWriteLogger(prefix string, w io.Writer) io.Writer { + return &writeLogger{prefix, w} +} + +type readLogger struct { + prefix string + r io.Reader +} + +func (l *readLogger) Read(p []byte) (n int, err error) { + n, err = l.r.Read(p) + if err != nil { + log.Printf("%s %x: %v", l.prefix, p[0:n], err) + } else { + log.Printf("%s %x", l.prefix, p[0:n]) + } + return +} + +// NewReadLogger returns a reader that behaves like r except +// that it logs (using log.Printf) each read to standard error, +// printing the prefix and the hexadecimal data read. +func NewReadLogger(prefix string, r io.Reader) io.Reader { + return &readLogger{prefix, r} +} diff --git a/src/testing/iotest/logger_test.go b/src/testing/iotest/logger_test.go new file mode 100644 index 0000000..7a7d0aa --- /dev/null +++ b/src/testing/iotest/logger_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest + +import ( + "bytes" + "errors" + "fmt" + "log" + "strings" + "testing" +) + +type errWriter struct { + err error +} + +func (w errWriter) Write([]byte) (int, error) { + return 0, w.err +} + +func TestWriteLogger(t *testing.T) { + olw := log.Writer() + olf := log.Flags() + olp := log.Prefix() + + // Revert the original log settings before we exit. + defer func() { + log.SetFlags(olf) + log.SetPrefix(olp) + log.SetOutput(olw) + }() + + lOut := new(strings.Builder) + log.SetPrefix("lw: ") + log.SetOutput(lOut) + log.SetFlags(0) + + lw := new(strings.Builder) + wl := NewWriteLogger("write:", lw) + if _, err := wl.Write([]byte("Hello, World!")); err != nil { + t.Fatalf("Unexpectedly failed to write: %v", err) + } + + if g, w := lw.String(), "Hello, World!"; g != w { + t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } + wantLogWithHex := fmt.Sprintf("lw: write: %x\n", "Hello, World!") + if g, w := lOut.String(), wantLogWithHex; g != w { + t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } +} + +func TestWriteLogger_errorOnWrite(t *testing.T) { + olw := log.Writer() + olf := log.Flags() + olp := log.Prefix() + + // Revert the original log settings before we exit. + defer func() { + log.SetFlags(olf) + log.SetPrefix(olp) + log.SetOutput(olw) + }() + + lOut := new(strings.Builder) + log.SetPrefix("lw: ") + log.SetOutput(lOut) + log.SetFlags(0) + + lw := errWriter{err: errors.New("Write Error!")} + wl := NewWriteLogger("write:", lw) + if _, err := wl.Write([]byte("Hello, World!")); err == nil { + t.Fatalf("Unexpectedly succeeded to write: %v", err) + } + + wantLogWithHex := fmt.Sprintf("lw: write: %x: %v\n", "", "Write Error!") + if g, w := lOut.String(), wantLogWithHex; g != w { + t.Errorf("WriteLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } +} + +func TestReadLogger(t *testing.T) { + olw := log.Writer() + olf := log.Flags() + olp := log.Prefix() + + // Revert the original log settings before we exit. + defer func() { + log.SetFlags(olf) + log.SetPrefix(olp) + log.SetOutput(olw) + }() + + lOut := new(strings.Builder) + log.SetPrefix("lr: ") + log.SetOutput(lOut) + log.SetFlags(0) + + data := []byte("Hello, World!") + p := make([]byte, len(data)) + lr := bytes.NewReader(data) + rl := NewReadLogger("read:", lr) + + n, err := rl.Read(p) + if err != nil { + t.Fatalf("Unexpectedly failed to read: %v", err) + } + + if g, w := p[:n], data; !bytes.Equal(g, w) { + t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + wantLogWithHex := fmt.Sprintf("lr: read: %x\n", "Hello, World!") + if g, w := lOut.String(), wantLogWithHex; g != w { + t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } +} + +func TestReadLogger_errorOnRead(t *testing.T) { + olw := log.Writer() + olf := log.Flags() + olp := log.Prefix() + + // Revert the original log settings before we exit. + defer func() { + log.SetFlags(olf) + log.SetPrefix(olp) + log.SetOutput(olw) + }() + + lOut := new(strings.Builder) + log.SetPrefix("lr: ") + log.SetOutput(lOut) + log.SetFlags(0) + + data := []byte("Hello, World!") + p := make([]byte, len(data)) + + lr := ErrReader(errors.New("io failure")) + rl := NewReadLogger("read", lr) + n, err := rl.Read(p) + if err == nil { + t.Fatalf("Unexpectedly succeeded to read: %v", err) + } + + wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n]) + if g, w := lOut.String(), wantLogWithHex; g != w { + t.Errorf("ReadLogger mismatch\n\tgot: %q\n\twant: %q", g, w) + } +} diff --git a/src/testing/iotest/reader.go b/src/testing/iotest/reader.go new file mode 100644 index 0000000..770d87f --- /dev/null +++ b/src/testing/iotest/reader.go @@ -0,0 +1,268 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package iotest implements Readers and Writers useful mainly for testing. +package iotest + +import ( + "bytes" + "errors" + "fmt" + "io" +) + +// OneByteReader returns a Reader that implements +// each non-empty Read by reading one byte from r. +func OneByteReader(r io.Reader) io.Reader { return &oneByteReader{r} } + +type oneByteReader struct { + r io.Reader +} + +func (r *oneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + return r.r.Read(p[0:1]) +} + +// HalfReader returns a Reader that implements Read +// by reading half as many requested bytes from r. +func HalfReader(r io.Reader) io.Reader { return &halfReader{r} } + +type halfReader struct { + r io.Reader +} + +func (r *halfReader) Read(p []byte) (int, error) { + return r.r.Read(p[0 : (len(p)+1)/2]) +} + +// DataErrReader changes the way errors are handled by a Reader. Normally, a +// Reader returns an error (typically EOF) from the first Read call after the +// last piece of data is read. DataErrReader wraps a Reader and changes its +// behavior so the final error is returned along with the final data, instead +// of in the first call after the final data. +func DataErrReader(r io.Reader) io.Reader { return &dataErrReader{r, nil, make([]byte, 1024)} } + +type dataErrReader struct { + r io.Reader + unread []byte + data []byte +} + +func (r *dataErrReader) Read(p []byte) (n int, err error) { + // loop because first call needs two reads: + // one to get data and a second to look for an error. + for { + if len(r.unread) == 0 { + n1, err1 := r.r.Read(r.data) + r.unread = r.data[0:n1] + err = err1 + } + if n > 0 || err != nil { + break + } + n = copy(p, r.unread) + r.unread = r.unread[n:] + } + return +} + +// ErrTimeout is a fake timeout error. +var ErrTimeout = errors.New("timeout") + +// TimeoutReader returns ErrTimeout on the second read +// with no data. Subsequent calls to read succeed. +func TimeoutReader(r io.Reader) io.Reader { return &timeoutReader{r, 0} } + +type timeoutReader struct { + r io.Reader + count int +} + +func (r *timeoutReader) Read(p []byte) (int, error) { + r.count++ + if r.count == 2 { + return 0, ErrTimeout + } + return r.r.Read(p) +} + +// ErrReader returns an io.Reader that returns 0, err from all Read calls. +func ErrReader(err error) io.Reader { + return &errReader{err: err} +} + +type errReader struct { + err error +} + +func (r *errReader) Read(p []byte) (int, error) { + return 0, r.err +} + +type smallByteReader struct { + r io.Reader + off int + n int +} + +func (r *smallByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + r.n = r.n%3 + 1 + n := r.n + if n > len(p) { + n = len(p) + } + n, err := r.r.Read(p[0:n]) + if err != nil && err != io.EOF { + err = fmt.Errorf("Read(%d bytes at offset %d): %v", n, r.off, err) + } + r.off += n + return n, err +} + +// TestReader tests that reading from r returns the expected file content. +// It does reads of different sizes, until EOF. +// If r implements io.ReaderAt or io.Seeker, TestReader also checks +// that those operations behave as they should. +// +// If TestReader finds any misbehaviors, it returns an error reporting them. +// The error text may span multiple lines. +func TestReader(r io.Reader, content []byte) error { + if len(content) > 0 { + n, err := r.Read(nil) + if n != 0 || err != nil { + return fmt.Errorf("Read(0) = %d, %v, want 0, nil", n, err) + } + } + + data, err := io.ReadAll(&smallByteReader{r: r}) + if err != nil { + return err + } + if !bytes.Equal(data, content) { + return fmt.Errorf("ReadAll(small amounts) = %q\n\twant %q", data, content) + } + n, err := r.Read(make([]byte, 10)) + if n != 0 || err != io.EOF { + return fmt.Errorf("Read(10) at EOF = %v, %v, want 0, EOF", n, err) + } + + if r, ok := r.(io.ReadSeeker); ok { + // Seek(0, 1) should report the current file position (EOF). + if off, err := r.Seek(0, 1); off != int64(len(content)) || err != nil { + return fmt.Errorf("Seek(0, 1) from EOF = %d, %v, want %d, nil", off, err, len(content)) + } + + // Seek backward partway through file, in two steps. + // If middle == 0, len(content) == 0, can't use the -1 and +1 seeks. + middle := len(content) - len(content)/3 + if middle > 0 { + if off, err := r.Seek(-1, 1); off != int64(len(content)-1) || err != nil { + return fmt.Errorf("Seek(-1, 1) from EOF = %d, %v, want %d, nil", -off, err, len(content)-1) + } + if off, err := r.Seek(int64(-len(content)/3), 1); off != int64(middle-1) || err != nil { + return fmt.Errorf("Seek(%d, 1) from %d = %d, %v, want %d, nil", -len(content)/3, len(content)-1, off, err, middle-1) + } + if off, err := r.Seek(+1, 1); off != int64(middle) || err != nil { + return fmt.Errorf("Seek(+1, 1) from %d = %d, %v, want %d, nil", middle-1, off, err, middle) + } + } + + // Seek(0, 1) should report the current file position (middle). + if off, err := r.Seek(0, 1); off != int64(middle) || err != nil { + return fmt.Errorf("Seek(0, 1) from %d = %d, %v, want %d, nil", middle, off, err, middle) + } + + // Reading forward should return the last part of the file. + data, err := io.ReadAll(&smallByteReader{r: r}) + if err != nil { + return fmt.Errorf("ReadAll from offset %d: %v", middle, err) + } + if !bytes.Equal(data, content[middle:]) { + return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:]) + } + + // Seek relative to end of file, but start elsewhere. + if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil { + return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2) + } + if off, err := r.Seek(int64(-len(content)/3), 2); off != int64(middle) || err != nil { + return fmt.Errorf("Seek(%d, 2) from %d = %d, %v, want %d, nil", -len(content)/3, middle/2, off, err, middle) + } + + // Reading forward should return the last part of the file (again). + data, err = io.ReadAll(&smallByteReader{r: r}) + if err != nil { + return fmt.Errorf("ReadAll from offset %d: %v", middle, err) + } + if !bytes.Equal(data, content[middle:]) { + return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle, data, content[middle:]) + } + + // Absolute seek & read forward. + if off, err := r.Seek(int64(middle/2), 0); off != int64(middle/2) || err != nil { + return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", middle/2, off, err, middle/2) + } + data, err = io.ReadAll(r) + if err != nil { + return fmt.Errorf("ReadAll from offset %d: %v", middle/2, err) + } + if !bytes.Equal(data, content[middle/2:]) { + return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", middle/2, data, content[middle/2:]) + } + } + + if r, ok := r.(io.ReaderAt); ok { + data := make([]byte, len(content), len(content)+1) + for i := range data { + data[i] = 0xfe + } + n, err := r.ReadAt(data, 0) + if n != len(data) || err != nil && err != io.EOF { + return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, nil or EOF", len(data), n, err, len(data)) + } + if !bytes.Equal(data, content) { + return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content) + } + + n, err = r.ReadAt(data[:1], int64(len(data))) + if n != 0 || err != io.EOF { + return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 0, EOF", len(data), n, err) + } + + for i := range data { + data[i] = 0xfe + } + n, err = r.ReadAt(data[:cap(data)], 0) + if n != len(data) || err != io.EOF { + return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, EOF", cap(data), n, err, len(data)) + } + if !bytes.Equal(data, content) { + return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(data), data, content) + } + + for i := range data { + data[i] = 0xfe + } + for i := range data { + n, err = r.ReadAt(data[i:i+1], int64(i)) + if n != 1 || err != nil && (i != len(data)-1 || err != io.EOF) { + want := "nil" + if i == len(data)-1 { + want = "nil or EOF" + } + return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 1, %s", i, n, err, want) + } + if data[i] != content[i] { + return fmt.Errorf("ReadAt(1, %d) = %q want %q", i, data[i:i+1], content[i:i+1]) + } + } + } + return nil +} diff --git a/src/testing/iotest/reader_test.go b/src/testing/iotest/reader_test.go new file mode 100644 index 0000000..1d22237 --- /dev/null +++ b/src/testing/iotest/reader_test.go @@ -0,0 +1,261 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest + +import ( + "bytes" + "errors" + "io" + "strings" + "testing" +) + +func TestOneByteReader_nonEmptyReader(t *testing.T) { + msg := "Hello, World!" + buf := new(bytes.Buffer) + buf.WriteString(msg) + + obr := OneByteReader(buf) + var b []byte + n, err := obr.Read(b) + if err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + + b = make([]byte, 3) + // Read from obr until EOF. + got := new(strings.Builder) + for i := 0; ; i++ { + n, err = obr.Read(b) + if err != nil { + break + } + if g, w := n, 1; g != w { + t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w) + } + got.Write(b[:n]) + } + if g, w := err, io.EOF; g != w { + t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := got.String(), "Hello, World!"; g != w { + t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w) + } +} + +func TestOneByteReader_emptyReader(t *testing.T) { + r := new(bytes.Buffer) + + obr := OneByteReader(r) + var b []byte + if n, err := obr.Read(b); err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + + b = make([]byte, 5) + n, err := obr.Read(b) + if g, w := err, io.EOF; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } +} + +func TestHalfReader_nonEmptyReader(t *testing.T) { + msg := "Hello, World!" + buf := new(bytes.Buffer) + buf.WriteString(msg) + // empty read buffer + hr := HalfReader(buf) + var b []byte + n, err := hr.Read(b) + if err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + // non empty read buffer + b = make([]byte, 2) + got := new(strings.Builder) + for i := 0; ; i++ { + n, err = hr.Read(b) + if err != nil { + break + } + if g, w := n, 1; g != w { + t.Errorf("Iteration #%d read %d bytes, want %d", i, g, w) + } + got.Write(b[:n]) + } + if g, w := err, io.EOF; g != w { + t.Errorf("Unexpected error after reading all bytes\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := got.String(), "Hello, World!"; g != w { + t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w) + } +} + +func TestHalfReader_emptyReader(t *testing.T) { + r := new(bytes.Buffer) + + hr := HalfReader(r) + var b []byte + if n, err := hr.Read(b); err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + + b = make([]byte, 5) + n, err := hr.Read(b) + if g, w := err, io.EOF; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } +} + +func TestTimeOutReader_nonEmptyReader(t *testing.T) { + msg := "Hello, World!" + buf := new(bytes.Buffer) + buf.WriteString(msg) + // empty read buffer + tor := TimeoutReader(buf) + var b []byte + n, err := tor.Read(b) + if err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + // Second call should timeout + n, err = tor.Read(b) + if g, w := err, ErrTimeout; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } + // non empty read buffer + tor2 := TimeoutReader(buf) + b = make([]byte, 3) + if n, err := tor2.Read(b); err != nil || n == 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + // Second call should timeout + n, err = tor2.Read(b) + if g, w := err, ErrTimeout; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } +} + +func TestTimeOutReader_emptyReader(t *testing.T) { + r := new(bytes.Buffer) + // empty read buffer + tor := TimeoutReader(r) + var b []byte + if n, err := tor.Read(b); err != nil || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + // Second call should timeout + n, err := tor.Read(b) + if g, w := err, ErrTimeout; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } + // non empty read buffer + tor2 := TimeoutReader(r) + b = make([]byte, 5) + if n, err := tor2.Read(b); err != io.EOF || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + // Second call should timeout + n, err = tor2.Read(b) + if g, w := err, ErrTimeout; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } +} + +func TestDataErrReader_nonEmptyReader(t *testing.T) { + msg := "Hello, World!" + buf := new(bytes.Buffer) + buf.WriteString(msg) + + der := DataErrReader(buf) + + b := make([]byte, 3) + got := new(strings.Builder) + var n int + var err error + for { + n, err = der.Read(b) + got.Write(b[:n]) + if err != nil { + break + } + } + if err != io.EOF || n == 0 { + t.Errorf("Last Read returned n=%d err=%v", n, err) + } + if g, w := got.String(), "Hello, World!"; g != w { + t.Errorf("Read mismatch\n\tGot: %q\n\tWant: %q", g, w) + } +} + +func TestDataErrReader_emptyReader(t *testing.T) { + r := new(bytes.Buffer) + + der := DataErrReader(r) + var b []byte + if n, err := der.Read(b); err != io.EOF || n != 0 { + t.Errorf("Empty buffer read returned n=%d err=%v", n, err) + } + + b = make([]byte, 5) + n, err := der.Read(b) + if g, w := err, io.EOF; g != w { + t.Errorf("Error mismatch\n\tGot: %v\n\tWant: %v", g, w) + } + if g, w := n, 0; g != w { + t.Errorf("Unexpectedly read %d bytes, wanted %d", g, w) + } +} + +func TestErrReader(t *testing.T) { + cases := []struct { + name string + err error + }{ + {"nil error", nil}, + {"non-nil error", errors.New("io failure")}, + {"io.EOF", io.EOF}, + } + + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + n, err := ErrReader(tt.err).Read(nil) + if err != tt.err { + t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, tt.err) + } + if n != 0 { + t.Fatalf("Byte count mismatch: got %d want 0", n) + } + }) + } +} + +func TestStringsReader(t *testing.T) { + const msg = "Now is the time for all good gophers." + + r := strings.NewReader(msg) + if err := TestReader(r, []byte(msg)); err != nil { + t.Fatal(err) + } +} diff --git a/src/testing/iotest/writer.go b/src/testing/iotest/writer.go new file mode 100644 index 0000000..af61ab8 --- /dev/null +++ b/src/testing/iotest/writer.go @@ -0,0 +1,35 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest + +import "io" + +// TruncateWriter returns a Writer that writes to w +// but stops silently after n bytes. +func TruncateWriter(w io.Writer, n int64) io.Writer { + return &truncateWriter{w, n} +} + +type truncateWriter struct { + w io.Writer + n int64 +} + +func (t *truncateWriter) Write(p []byte) (n int, err error) { + if t.n <= 0 { + return len(p), nil + } + // real write + n = len(p) + if int64(n) > t.n { + n = int(t.n) + } + n, err = t.w.Write(p[0:n]) + t.n -= int64(n) + if err == nil { + n = len(p) + } + return +} diff --git a/src/testing/iotest/writer_test.go b/src/testing/iotest/writer_test.go new file mode 100644 index 0000000..2762513 --- /dev/null +++ b/src/testing/iotest/writer_test.go @@ -0,0 +1,39 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package iotest + +import ( + "strings" + "testing" +) + +var truncateWriterTests = []struct { + in string + want string + trunc int64 + n int +}{ + {"hello", "", -1, 5}, + {"world", "", 0, 5}, + {"abcde", "abc", 3, 5}, + {"edcba", "edcba", 7, 5}, +} + +func TestTruncateWriter(t *testing.T) { + for _, tt := range truncateWriterTests { + buf := new(strings.Builder) + tw := TruncateWriter(buf, tt.trunc) + n, err := tw.Write([]byte(tt.in)) + if err != nil { + t.Errorf("Unexpected error %v for\n\t%+v", err, tt) + } + if g, w := buf.String(), tt.want; g != w { + t.Errorf("got %q, expected %q", g, w) + } + if g, w := n, tt.n; g != w { + t.Errorf("read %d bytes, but expected to have read %d bytes for\n\t%+v", g, w, tt) + } + } +} |