Diffstat (limited to 'src/compress/lzw')
-rw-r--r--   src/compress/lzw/reader.go        269
-rw-r--r--   src/compress/lzw/reader_test.go   253
-rw-r--r--   src/compress/lzw/writer.go        269
-rw-r--r--   src/compress/lzw/writer_test.go   156
4 files changed, 947 insertions, 0 deletions
diff --git a/src/compress/lzw/reader.go b/src/compress/lzw/reader.go
new file mode 100644
index 0000000..f080211
--- /dev/null
+++ b/src/compress/lzw/reader.go
@@ -0,0 +1,269 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package lzw implements the Lempel-Ziv-Welch compressed data format,
+// described in T. A. Welch, ``A Technique for High-Performance Data
+// Compression'', Computer, 17(6) (June 1984), pp 8-19.
+//
+// In particular, it implements LZW as used by the GIF and PDF file
+// formats, which means variable-width codes up to 12 bits and the first
+// two non-literal codes are a clear code and an EOF code.
+//
+// The TIFF file format uses a similar but incompatible version of the LZW
+// algorithm. See the golang.org/x/image/tiff/lzw package for an
+// implementation.
+package lzw
+
+// TODO(nigeltao): check that PDF uses LZW in the same way as GIF,
+// modulo LSB/MSB packing order.
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// Order specifies the bit ordering in an LZW data stream.
+type Order int
+
+const (
+ // LSB means Least Significant Bits first, as used in the GIF file format.
+ LSB Order = iota
+ // MSB means Most Significant Bits first, as used in the TIFF and PDF
+ // file formats.
+ MSB
+)
+
+const (
+ maxWidth = 12
+ decoderInvalidCode = 0xffff
+ flushBuffer = 1 << maxWidth
+)
+
+// decoder is the state from which the readXxx method converts a byte
+// stream into a code stream.
+type decoder struct {
+ r io.ByteReader
+ bits uint32
+ nBits uint
+ width uint
+ read func(*decoder) (uint16, error) // readLSB or readMSB
+ litWidth int // width in bits of literal codes
+ err error
+
+ // The first 1<<litWidth codes are literal codes.
+ // The next two codes mean clear and EOF.
+ // Other valid codes are in the range [lo, hi] where lo := clear + 2,
+ // with the upper bound incrementing on each code seen.
+ //
+ // overflow is the code at which hi overflows the code width. It always
+ // equals 1 << width.
+ //
+ // last is the most recently seen code, or decoderInvalidCode.
+ //
+ // An invariant is that hi < overflow.
+ clear, eof, hi, overflow, last uint16
+
+ // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
+ // suffix[c] is the last of these bytes.
+ // prefix[c] is the code for all but the last byte.
+ // This code can either be a literal code or another code in [lo, c).
+ // The c == hi case is a special case.
+ suffix [1 << maxWidth]uint8
+ prefix [1 << maxWidth]uint16
+
+ // output is the temporary output buffer.
+ // Literal codes are accumulated from the start of the buffer.
+ // Non-literal codes decode to a sequence of suffixes that are first
+ // written right-to-left from the end of the buffer before being copied
+ // to the start of the buffer.
+ // It is flushed when it contains >= 1<<maxWidth bytes,
+ // so that there is always room to decode an entire code.
+ output [2 * 1 << maxWidth]byte
+ o int // write index into output
+ toRead []byte // bytes to return from Read
+}
+
+// readLSB returns the next code for "Least Significant Bits first" data.
+func (d *decoder) readLSB() (uint16, error) {
+ for d.nBits < d.width {
+ x, err := d.r.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ d.bits |= uint32(x) << d.nBits
+ d.nBits += 8
+ }
+ code := uint16(d.bits & (1<<d.width - 1))
+ d.bits >>= d.width
+ d.nBits -= d.width
+ return code, nil
+}
+
+// readMSB returns the next code for "Most Significant Bits first" data.
+func (d *decoder) readMSB() (uint16, error) {
+ for d.nBits < d.width {
+ x, err := d.r.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ d.bits |= uint32(x) << (24 - d.nBits)
+ d.nBits += 8
+ }
+ code := uint16(d.bits >> (32 - d.width))
+ d.bits <<= d.width
+ d.nBits -= d.width
+ return code, nil
+}
+
+func (d *decoder) Read(b []byte) (int, error) {
+ for {
+ if len(d.toRead) > 0 {
+ n := copy(b, d.toRead)
+ d.toRead = d.toRead[n:]
+ return n, nil
+ }
+ if d.err != nil {
+ return 0, d.err
+ }
+ d.decode()
+ }
+}
+
+// decode decompresses bytes from d.r and leaves them in d.toRead.
+// d.read specifies how to decode bytes into codes.
+// d.litWidth is the width in bits of literal codes.
+func (d *decoder) decode() {
+ // Loop over the code stream, converting codes into decompressed bytes.
+loop:
+ for {
+ code, err := d.read(d)
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ d.err = err
+ break
+ }
+ switch {
+ case code < d.clear:
+ // We have a literal code.
+ d.output[d.o] = uint8(code)
+ d.o++
+ if d.last != decoderInvalidCode {
+ // Save what the hi code expands to.
+ d.suffix[d.hi] = uint8(code)
+ d.prefix[d.hi] = d.last
+ }
+ case code == d.clear:
+ d.width = 1 + uint(d.litWidth)
+ d.hi = d.eof
+ d.overflow = 1 << d.width
+ d.last = decoderInvalidCode
+ continue
+ case code == d.eof:
+ d.err = io.EOF
+ break loop
+ case code <= d.hi:
+ c, i := code, len(d.output)-1
+ if code == d.hi && d.last != decoderInvalidCode {
+ // code == hi is a special case which expands to the last expansion
+ // followed by the head of the last expansion. To find the head, we walk
+ // the prefix chain until we find a literal code.
+ c = d.last
+ for c >= d.clear {
+ c = d.prefix[c]
+ }
+ d.output[i] = uint8(c)
+ i--
+ c = d.last
+ }
+ // Copy the suffix chain into output and then write that to w.
+ for c >= d.clear {
+ d.output[i] = d.suffix[c]
+ i--
+ c = d.prefix[c]
+ }
+ d.output[i] = uint8(c)
+ d.o += copy(d.output[d.o:], d.output[i:])
+ if d.last != decoderInvalidCode {
+ // Save what the hi code expands to.
+ d.suffix[d.hi] = uint8(c)
+ d.prefix[d.hi] = d.last
+ }
+ default:
+ d.err = errors.New("lzw: invalid code")
+ break loop
+ }
+ d.last, d.hi = code, d.hi+1
+ if d.hi >= d.overflow {
+ if d.hi > d.overflow {
+ panic("unreachable")
+ }
+ if d.width == maxWidth {
+ d.last = decoderInvalidCode
+ // Undo the d.hi++ a few lines above, so that (1) we maintain
+ // the invariant that d.hi < d.overflow, and (2) d.hi does not
+ // eventually overflow a uint16.
+ d.hi--
+ } else {
+ d.width++
+ d.overflow = 1 << d.width
+ }
+ }
+ if d.o >= flushBuffer {
+ break
+ }
+ }
+ // Flush pending output.
+ d.toRead = d.output[:d.o]
+ d.o = 0
+}
+
+var errClosed = errors.New("lzw: reader/writer is closed")
+
+func (d *decoder) Close() error {
+ d.err = errClosed // in case any Reads come along
+ return nil
+}
+
+// NewReader creates a new io.ReadCloser.
+// Reads from the returned io.ReadCloser read and decompress data from r.
+// If r does not also implement io.ByteReader,
+// the decompressor may read more data than necessary from r.
+// It is the caller's responsibility to call Close on the ReadCloser when
+// finished reading.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8. It must equal the litWidth
+// used during compression.
+func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
+ d := new(decoder)
+ switch order {
+ case LSB:
+ d.read = (*decoder).readLSB
+ case MSB:
+ d.read = (*decoder).readMSB
+ default:
+ d.err = errors.New("lzw: unknown order")
+ return d
+ }
+ if litWidth < 2 || 8 < litWidth {
+ d.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+ return d
+ }
+ if br, ok := r.(io.ByteReader); ok {
+ d.r = br
+ } else {
+ d.r = bufio.NewReader(r)
+ }
+ d.litWidth = litWidth
+ d.width = 1 + uint(litWidth)
+ d.clear = uint16(1) << uint(litWidth)
+ d.eof, d.hi = d.clear+1, d.clear+1
+ d.overflow = uint16(1) << d.width
+ d.last = decoderInvalidCode
+
+ return d
+}
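
As a quick illustration of the new reader, here is a standalone usage sketch (my own, not part of this change) that decodes the "tobe;LSB;8" vector from reader_test.go below, assuming the package is importable as compress/lzw:

package main

import (
	"compress/lzw"
	"fmt"
	"io"
	"strings"
)

func main() {
	// The "tobe;LSB;8" test vector: GIF-style LSB bit order, 8-bit literals.
	compressed := "\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04\x12\x34\xb8\xb0\xe0\xc1\x84\x01\x01"
	r := lzw.NewReader(strings.NewReader(compressed), lzw.LSB, 8)
	defer r.Close()
	decoded, err := io.ReadAll(r)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", decoded) // prints TOBEORNOTTOBEORTOBEORNOT
}
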
diff --git a/src/compress/lzw/reader_test.go b/src/compress/lzw/reader_test.go
new file mode 100644
index 0000000..d1eb76d
--- /dev/null
+++ b/src/compress/lzw/reader_test.go
@@ -0,0 +1,253 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math"
+ "os"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+type lzwTest struct {
+ desc string
+ raw string
+ compressed string
+ err error
+}
+
+var lzwTests = []lzwTest{
+ {
+ "empty;LSB;8",
+ "",
+ "\x01\x01",
+ nil,
+ },
+ {
+ "empty;MSB;8",
+ "",
+ "\x80\x80",
+ nil,
+ },
+ {
+ "tobe;LSB;7",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
+ nil,
+ },
+ {
+ "tobe;LSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04\x12\x34\xb8\xb0\xe0\xc1\x84\x01\x01",
+ nil,
+ },
+ {
+ "tobe;MSB;7",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
+ nil,
+ },
+ {
+ "tobe;MSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x2a\x13\xc8\x44\x52\x79\x48\x9c\x4f\x2a\x40\xa0\x90\x68\x5c\x16\x0f\x09\x80\x80",
+ nil,
+ },
+ {
+ "tobe-truncated;LSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04",
+ io.ErrUnexpectedEOF,
+ },
+ // This example comes from https://en.wikipedia.org/wiki/Graphics_Interchange_Format.
+ {
+ "gif;LSB;8",
+ "\x28\xff\xff\xff\x28\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff",
+ "\x00\x51\xfc\x1b\x28\x70\xa0\xc1\x83\x01\x01",
+ nil,
+ },
+ // This example comes from http://compgroups.net/comp.lang.ruby/Decompressing-LZW-compression-from-PDF-file
+ {
+ "pdf;MSB;8",
+ "-----A---B",
+ "\x80\x0b\x60\x50\x22\x0c\x0c\x85\x01",
+ nil,
+ },
+}
+
+func TestReader(t *testing.T) {
+ var b bytes.Buffer
+ for _, tt := range lzwTests {
+ d := strings.Split(tt.desc, ";")
+ var order Order
+ switch d[1] {
+ case "LSB":
+ order = LSB
+ case "MSB":
+ order = MSB
+ default:
+ t.Errorf("%s: bad order %q", tt.desc, d[1])
+ }
+ litWidth, _ := strconv.Atoi(d[2])
+ rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
+ defer rc.Close()
+ b.Reset()
+ n, err := io.Copy(&b, rc)
+ s := b.String()
+ if err != nil {
+ if err != tt.err {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
+ }
+ if err == io.ErrUnexpectedEOF {
+ // Even if the input is truncated, we should still return the
+ // partial decoded result.
+ if n == 0 || !strings.HasPrefix(tt.raw, s) {
+ t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, s, tt.raw)
+ }
+ }
+ continue
+ }
+ if s != tt.raw {
+ t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw)
+ }
+ }
+}
+
+type devZero struct{}
+
+func (devZero) Read(p []byte) (int, error) {
+ for i := range p {
+ p[i] = 0
+ }
+ return len(p), nil
+}
+
+func TestHiCodeDoesNotOverflow(t *testing.T) {
+ r := NewReader(devZero{}, LSB, 8)
+ d := r.(*decoder)
+ buf := make([]byte, 1024)
+ oldHi := uint16(0)
+ for i := 0; i < 100; i++ {
+ if _, err := io.ReadFull(r, buf); err != nil {
+ t.Fatalf("i=%d: %v", i, err)
+ }
+ // The hi code should never decrease.
+ if d.hi < oldHi {
+ t.Fatalf("i=%d: hi=%d decreased from previous value %d", i, d.hi, oldHi)
+ }
+ oldHi = d.hi
+ }
+}
+
+// TestNoLongerSavingPriorExpansions tests the decoder state when codes other
+// than clear codes continue to be seen after decoder.hi and decoder.width
+// reach their maximum values (4095 and 12), i.e. after we no longer save prior
+// expansions. In particular, it tests seeing the highest possible code, 4095.
+func TestNoLongerSavingPriorExpansions(t *testing.T) {
+ // Iterations is used to calculate how many input bits are needed to get
+ // the decoder.hi and decoder.width values up to their maximum.
+ iterations := []struct {
+ width, n int
+ }{
+ // The final term is 257, not 256, as NewReader initializes d.hi to
+ // d.clear+1 and the clear code is 256.
+ {9, 512 - 257},
+ {10, 1024 - 512},
+ {11, 2048 - 1024},
+ {12, 4096 - 2048},
+ }
+ nCodes, nBits := 0, 0
+ for _, e := range iterations {
+ nCodes += e.n
+ nBits += e.n * e.width
+ }
+ if nCodes != 3839 {
+ t.Fatalf("nCodes: got %v, want %v", nCodes, 3839)
+ }
+ if nBits != 43255 {
+ t.Fatalf("nBits: got %v, want %v", nBits, 43255)
+ }
+
+ // Construct our input of 43255 zero bits (which gets d.hi and d.width up
+ // to 4095 and 12), followed by 0xfff (4095) as 12 bits, followed by 0x101
+ // (EOF) as 12 bits.
+ //
+ // 43255 = 5406*8 + 7, and codes are read in LSB order. The final bytes are
+ // therefore:
+ //
+ // xwwwwwww xxxxxxxx yyyyyxxx zyyyyyyy
+ // 10000000 11111111 00001111 00001000
+ //
+ // or split out:
+ //
+ // .0000000 ........ ........ ........ w = 0x000
+ // 1....... 11111111 .....111 ........ x = 0xfff
+ // ........ ........ 00001... .0001000 y = 0x101
+ //
+ // The 12 'w' bits (not all are shown) form the 3839'th code, with value
+ // 0x000. Just after decoder.read returns that code, d.hi == 4095 and
+ // d.last == 0.
+ //
+ // The 12 'x' bits form the 3840'th code, with value 0xfff or 4095. Just
+ // after decoder.read returns that code, d.hi == 4095 and d.last ==
+ // decoderInvalidCode.
+ //
+ // The 12 'y' bits form the 3841'st code, with value 0x101, the EOF code.
+ //
+ // The 'z' bit is unused.
+ in := make([]byte, 5406)
+ in = append(in, 0x80, 0xff, 0x0f, 0x08)
+
+ r := NewReader(bytes.NewReader(in), LSB, 8)
+ nDecoded, err := io.Copy(io.Discard, r)
+ if err != nil {
+ t.Fatalf("Copy: %v", err)
+ }
+ // nDecoded should be 3841: 3839 literal codes and then 2 decoded bytes
+ // from 1 non-literal code. The EOF code contributes 0 decoded bytes.
+ if nDecoded != int64(nCodes+2) {
+ t.Fatalf("nDecoded: got %v, want %v", nDecoded, nCodes+2)
+ }
+}
+
+func BenchmarkDecoder(b *testing.B) {
+ buf, err := os.ReadFile("../testdata/e.txt")
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(buf) == 0 {
+ b.Fatalf("test file has no data")
+ }
+
+ for e := 4; e <= 6; e++ {
+ n := int(math.Pow10(e))
+ b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
+ b.StopTimer()
+ b.SetBytes(int64(n))
+ buf0 := buf
+ compressed := new(bytes.Buffer)
+ w := NewWriter(compressed, LSB, 8)
+ for i := 0; i < n; i += len(buf0) {
+ if len(buf0) > n-i {
+ buf0 = buf0[:n-i]
+ }
+ w.Write(buf0)
+ }
+ w.Close()
+ buf1 := compressed.Bytes()
+ buf0, compressed, w = nil, nil, nil
+ runtime.GC()
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8))
+ }
+ })
+ }
+}
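
The bit-layout arithmetic in the TestNoLongerSavingPriorExpansions comment above can be checked independently. The sketch below (mine, not part of this change) packs the same 3841 codes least-significant-bit first, using the variable widths from the iterations table, and confirms the stream is 5406 zero bytes followed by 0x80, 0xff, 0x0f, 0x08:

package main

import (
	"bytes"
	"fmt"
)

func main() {
	var out []byte
	var bits uint32
	var nBits uint
	emit := func(code uint32, width uint) {
		// Same LSB packing as encoder.writeLSB in writer.go.
		bits |= code << nBits
		nBits += width
		for nBits >= 8 {
			out = append(out, uint8(bits))
			bits >>= 8
			nBits -= 8
		}
	}
	// 3839 zero codes at widths 9 through 12, per the iterations table.
	for _, it := range []struct{ width, n int }{
		{9, 512 - 257}, {10, 1024 - 512}, {11, 2048 - 1024}, {12, 4096 - 2048},
	} {
		for i := 0; i < it.n; i++ {
			emit(0, uint(it.width))
		}
	}
	emit(0xfff, 12) // the highest possible code
	emit(0x101, 12) // the EOF code for litWidth 8
	if nBits > 0 {
		out = append(out, uint8(bits)) // final partial byte; high bits unused
	}
	want := append(make([]byte, 5406), 0x80, 0xff, 0x0f, 0x08)
	fmt.Println(bytes.Equal(out, want)) // prints true
}
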
diff --git a/src/compress/lzw/writer.go b/src/compress/lzw/writer.go
new file mode 100644
index 0000000..6ddb335
--- /dev/null
+++ b/src/compress/lzw/writer.go
@@ -0,0 +1,269 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// A writer is a buffered, flushable writer.
+type writer interface {
+ io.ByteWriter
+ Flush() error
+}
+
+// An errWriteCloser is an io.WriteCloser that always returns a given error.
+type errWriteCloser struct {
+ err error
+}
+
+func (e *errWriteCloser) Write([]byte) (int, error) {
+ return 0, e.err
+}
+
+func (e *errWriteCloser) Close() error {
+ return e.err
+}
+
+const (
+ // A code is a 12 bit value, stored as a uint32 when encoding to avoid
+ // type conversions when shifting bits.
+ maxCode = 1<<12 - 1
+ invalidCode = 1<<32 - 1
+ // There are 1<<12 possible codes, which is an upper bound on the number of
+ // valid hash table entries at any given point in time. tableSize is 4x that.
+ tableSize = 4 * 1 << 12
+ tableMask = tableSize - 1
+ // A hash table entry is a uint32. Zero is an invalid entry since the
+ // lower 12 bits of a valid entry must be a non-literal code.
+ invalidEntry = 0
+)
+
+// encoder is an LZW compressor.
+type encoder struct {
+ // w is the writer that compressed bytes are written to.
+ w writer
+ // order, write, bits, nBits and width are the state for
+ // converting a code stream into a byte stream.
+ order Order
+ write func(*encoder, uint32) error
+ bits uint32
+ nBits uint
+ width uint
+ // litWidth is the width in bits of literal codes.
+ litWidth uint
+ // hi is the code implied by the next code emission.
+ // overflow is the code at which hi overflows the code width.
+ hi, overflow uint32
+ // savedCode is the accumulated code at the end of the most recent Write
+ // call. It is equal to invalidCode if there was no such call.
+ savedCode uint32
+ // err is the first error encountered during writing. Closing the encoder
+ // will make any future Write calls return errClosed.
+ err error
+ // table is the hash table from 20-bit keys to 12-bit values. Each table
+ // entry contains key<<12|val and collisions resolve by linear probing.
+ // The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
+ // The values are a 12-bit code.
+ table [tableSize]uint32
+}
+
+// writeLSB writes the code c for "Least Significant Bits first" data.
+func (e *encoder) writeLSB(c uint32) error {
+ e.bits |= c << e.nBits
+ e.nBits += e.width
+ for e.nBits >= 8 {
+ if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+ return err
+ }
+ e.bits >>= 8
+ e.nBits -= 8
+ }
+ return nil
+}
+
+// writeMSB writes the code c for "Most Significant Bits first" data.
+func (e *encoder) writeMSB(c uint32) error {
+ e.bits |= c << (32 - e.width - e.nBits)
+ e.nBits += e.width
+ for e.nBits >= 8 {
+ if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
+ return err
+ }
+ e.bits <<= 8
+ e.nBits -= 8
+ }
+ return nil
+}
+
+// errOutOfCodes is an internal error that means that the encoder has run out
+// of unused codes and a clear code needs to be sent next.
+var errOutOfCodes = errors.New("lzw: out of codes")
+
+// incHi increments e.hi and checks for both overflow and running out of
+// unused codes. In the latter case, incHi sends a clear code, resets the
+// encoder state and returns errOutOfCodes.
+func (e *encoder) incHi() error {
+ e.hi++
+ if e.hi == e.overflow {
+ e.width++
+ e.overflow <<= 1
+ }
+ if e.hi == maxCode {
+ clear := uint32(1) << e.litWidth
+ if err := e.write(e, clear); err != nil {
+ return err
+ }
+ e.width = e.litWidth + 1
+ e.hi = clear + 1
+ e.overflow = clear << 1
+ for i := range e.table {
+ e.table[i] = invalidEntry
+ }
+ return errOutOfCodes
+ }
+ return nil
+}
+
+// Write writes a compressed representation of p to e's underlying writer.
+func (e *encoder) Write(p []byte) (n int, err error) {
+ if e.err != nil {
+ return 0, e.err
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+ if maxLit := uint8(1<<e.litWidth - 1); maxLit != 0xff {
+ for _, x := range p {
+ if x > maxLit {
+ e.err = errors.New("lzw: input byte too large for the litWidth")
+ return 0, e.err
+ }
+ }
+ }
+ n = len(p)
+ code := e.savedCode
+ if code == invalidCode {
+ // The first code sent is always a literal code.
+ code, p = uint32(p[0]), p[1:]
+ }
+loop:
+ for _, x := range p {
+ literal := uint32(x)
+ key := code<<8 | literal
+ // If there is a hash table hit for this key then we continue the loop
+ // and do not emit a code yet.
+ hash := (key>>12 ^ key) & tableMask
+ for h, t := hash, e.table[hash]; t != invalidEntry; {
+ if key == t>>12 {
+ code = t & maxCode
+ continue loop
+ }
+ h = (h + 1) & tableMask
+ t = e.table[h]
+ }
+ // Otherwise, write the current code, and literal becomes the start of
+ // the next emitted code.
+ if e.err = e.write(e, code); e.err != nil {
+ return 0, e.err
+ }
+ code = literal
+ // Increment e.hi, the next implied code. If we run out of codes, reset
+ // the encoder state (including clearing the hash table) and continue.
+ if err1 := e.incHi(); err1 != nil {
+ if err1 == errOutOfCodes {
+ continue
+ }
+ e.err = err1
+ return 0, e.err
+ }
+ // Otherwise, insert key -> e.hi into the map that e.table represents.
+ for {
+ if e.table[hash] == invalidEntry {
+ e.table[hash] = (key << 12) | e.hi
+ break
+ }
+ hash = (hash + 1) & tableMask
+ }
+ }
+ e.savedCode = code
+ return n, nil
+}
+
+// Close closes the encoder, flushing any pending output. It does not close or
+// flush e's underlying writer.
+func (e *encoder) Close() error {
+ if e.err != nil {
+ if e.err == errClosed {
+ return nil
+ }
+ return e.err
+ }
+ // Make any future calls to Write return errClosed.
+ e.err = errClosed
+ // Write the savedCode if valid.
+ if e.savedCode != invalidCode {
+ if err := e.write(e, e.savedCode); err != nil {
+ return err
+ }
+ if err := e.incHi(); err != nil && err != errOutOfCodes {
+ return err
+ }
+ }
+ // Write the eof code.
+ eof := uint32(1)<<e.litWidth + 1
+ if err := e.write(e, eof); err != nil {
+ return err
+ }
+ // Write the final bits.
+ if e.nBits > 0 {
+ if e.order == MSB {
+ e.bits >>= 24
+ }
+ if err := e.w.WriteByte(uint8(e.bits)); err != nil {
+ return err
+ }
+ }
+ return e.w.Flush()
+}
+
+// NewWriter creates a new io.WriteCloser.
+// Writes to the returned io.WriteCloser are compressed and written to w.
+// It is the caller's responsibility to call Close on the WriteCloser when
+// finished writing.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
+func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
+ var write func(*encoder, uint32) error
+ switch order {
+ case LSB:
+ write = (*encoder).writeLSB
+ case MSB:
+ write = (*encoder).writeMSB
+ default:
+ return &errWriteCloser{errors.New("lzw: unknown order")}
+ }
+ if litWidth < 2 || 8 < litWidth {
+ return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
+ }
+ bw, ok := w.(writer)
+ if !ok {
+ bw = bufio.NewWriter(w)
+ }
+ lw := uint(litWidth)
+ return &encoder{
+ w: bw,
+ order: order,
+ write: write,
+ width: 1 + lw,
+ litWidth: lw,
+ hi: 1<<lw + 1,
+ overflow: 1 << (lw + 1),
+ savedCode: invalidCode,
+ }
+}
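
For completeness, a minimal in-memory round trip through the new writer and reader (my own sketch, not part of this change); it mirrors what testFile in writer_test.go below does via an io.Pipe:

package main

import (
	"bytes"
	"compress/lzw"
	"fmt"
	"io"
)

func main() {
	original := []byte("TOBEORNOTTOBEORTOBEORNOT")

	// Compress. Close flushes the saved code, the EOF code and any final bits.
	var compressed bytes.Buffer
	w := lzw.NewWriter(&compressed, lzw.LSB, 8)
	if _, err := w.Write(original); err != nil {
		panic(err)
	}
	if err := w.Close(); err != nil {
		panic(err)
	}

	// Decompress and compare against the original bytes.
	r := lzw.NewReader(&compressed, lzw.LSB, 8)
	defer r.Close()
	decoded, err := io.ReadAll(r)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(original, decoded)) // prints true
}
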
diff --git a/src/compress/lzw/writer_test.go b/src/compress/lzw/writer_test.go
new file mode 100644
index 0000000..1a5dbca
--- /dev/null
+++ b/src/compress/lzw/writer_test.go
@@ -0,0 +1,156 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "fmt"
+ "internal/testenv"
+ "io"
+ "math"
+ "os"
+ "runtime"
+ "testing"
+)
+
+var filenames = []string{
+ "../testdata/gettysburg.txt",
+ "../testdata/e.txt",
+ "../testdata/pi.txt",
+}
+
+// testFile tests that compressing and then decompressing the given file with
+// the given options yields equivalent bytes to the original file.
+func testFile(t *testing.T, fn string, order Order, litWidth int) {
+ // Read the file, as golden output.
+ golden, err := os.Open(fn)
+ if err != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+ return
+ }
+ defer golden.Close()
+
+ // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
+ raw, err := os.Open(fn)
+ if err != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+ return
+ }
+
+ piper, pipew := io.Pipe()
+ defer piper.Close()
+ go func() {
+ defer raw.Close()
+ defer pipew.Close()
+ lzww := NewWriter(pipew, order, litWidth)
+ defer lzww.Close()
+ var b [4096]byte
+ for {
+ n, err0 := raw.Read(b[:])
+ if err0 != nil && err0 != io.EOF {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+ return
+ }
+ _, err1 := lzww.Write(b[:n])
+ if err1 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+ return
+ }
+ if err0 == io.EOF {
+ break
+ }
+ }
+ }()
+ lzwr := NewReader(piper, order, litWidth)
+ defer lzwr.Close()
+
+ // Compare the two.
+ b0, err0 := io.ReadAll(golden)
+ b1, err1 := io.ReadAll(lzwr)
+ if err0 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+ return
+ }
+ if err1 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+ return
+ }
+ if len(b1) != len(b0) {
+ t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
+ return
+ }
+ for i := 0; i < len(b0); i++ {
+ if b1[i] != b0[i] {
+ t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
+ return
+ }
+ }
+}
+
+func TestWriter(t *testing.T) {
+ for _, filename := range filenames {
+ for _, order := range [...]Order{LSB, MSB} {
+ // The test data "2.71828 etcetera" is ASCII text requiring at least 6 bits.
+ for litWidth := 6; litWidth <= 8; litWidth++ {
+ if filename == "../testdata/gettysburg.txt" && litWidth == 6 {
+ continue
+ }
+ testFile(t, filename, order, litWidth)
+ }
+ }
+ if testing.Short() && testenv.Builder() == "" {
+ break
+ }
+ }
+}
+
+func TestWriterReturnValues(t *testing.T) {
+ w := NewWriter(io.Discard, LSB, 8)
+ n, err := w.Write([]byte("asdf"))
+ if n != 4 || err != nil {
+ t.Errorf("got %d, %v, want 4, nil", n, err)
+ }
+}
+
+func TestSmallLitWidth(t *testing.T) {
+ w := NewWriter(io.Discard, LSB, 2)
+ if _, err := w.Write([]byte{0x03}); err != nil {
+ t.Fatalf("write a byte < 1<<2: %v", err)
+ }
+ if _, err := w.Write([]byte{0x04}); err == nil {
+ t.Fatal("write a byte >= 1<<2: got nil error, want non-nil")
+ }
+}
+
+func BenchmarkEncoder(b *testing.B) {
+ buf, err := os.ReadFile("../testdata/e.txt")
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(buf) == 0 {
+ b.Fatalf("test file has no data")
+ }
+
+ for e := 4; e <= 6; e++ {
+ n := int(math.Pow10(e))
+ buf0 := buf
+ buf1 := make([]byte, n)
+ for i := 0; i < n; i += len(buf0) {
+ if len(buf0) > n-i {
+ buf0 = buf0[:n-i]
+ }
+ copy(buf1[i:], buf0)
+ }
+ buf0 = nil
+ runtime.GC()
+ b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
+ b.SetBytes(int64(n))
+ for i := 0; i < b.N; i++ {
+ w := NewWriter(io.Discard, LSB, 8)
+ w.Write(buf1)
+ w.Close()
+ }
+ })
+ }
+}