summaryrefslogtreecommitdiffstats
path: root/src/compress/lzw
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:23:18 +0000
commit43a123c1ae6613b3efeed291fa552ecd909d3acf (patch)
treefd92518b7024bc74031f78a1cf9e454b65e73665 /src/compress/lzw
parentInitial commit. (diff)
downloadgolang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.tar.xz
golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.zip
Adding upstream version 1.20.14.upstream/1.20.14upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/compress/lzw')
-rw-r--r--src/compress/lzw/reader.go290
-rw-r--r--src/compress/lzw/reader_test.go315
-rw-r--r--src/compress/lzw/writer.go293
-rw-r--r--src/compress/lzw/writer_test.go238
4 files changed, 1136 insertions, 0 deletions
diff --git a/src/compress/lzw/reader.go b/src/compress/lzw/reader.go
new file mode 100644
index 0000000..18df970
--- /dev/null
+++ b/src/compress/lzw/reader.go
@@ -0,0 +1,290 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package lzw implements the Lempel-Ziv-Welch compressed data format,
+// described in T. A. Welch, “A Technique for High-Performance Data
+// Compression”, Computer, 17(6) (June 1984), pp 8-19.
+//
+// In particular, it implements LZW as used by the GIF and PDF file
+// formats, which means variable-width codes up to 12 bits and the first
+// two non-literal codes are a clear code and an EOF code.
+//
+// The TIFF file format uses a similar but incompatible version of the LZW
+// algorithm. See the golang.org/x/image/tiff/lzw package for an
+// implementation.
+package lzw
+
+// TODO(nigeltao): check that PDF uses LZW in the same way as GIF,
+// modulo LSB/MSB packing order.
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// Order specifies the bit ordering in an LZW data stream.
+type Order int
+
+const (
+ // LSB means Least Significant Bits first, as used in the GIF file format.
+ LSB Order = iota
+ // MSB means Most Significant Bits first, as used in the TIFF and PDF
+ // file formats.
+ MSB
+)
+
+const (
+ maxWidth = 12
+ decoderInvalidCode = 0xffff
+ flushBuffer = 1 << maxWidth
+)
+
+// Reader is an io.Reader which can be used to read compressed data in the
+// LZW format.
+type Reader struct {
+ r io.ByteReader
+ bits uint32
+ nBits uint
+ width uint
+ read func(*Reader) (uint16, error) // readLSB or readMSB
+ litWidth int // width in bits of literal codes
+ err error
+
+ // The first 1<<litWidth codes are literal codes.
+ // The next two codes mean clear and EOF.
+ // Other valid codes are in the range [lo, hi] where lo := clear + 2,
+ // with the upper bound incrementing on each code seen.
+ //
+ // overflow is the code at which hi overflows the code width. It always
+ // equals 1 << width.
+ //
+ // last is the most recently seen code, or decoderInvalidCode.
+ //
+ // An invariant is that hi < overflow.
+ clear, eof, hi, overflow, last uint16
+
+ // Each code c in [lo, hi] expands to two or more bytes. For c != hi:
+ // suffix[c] is the last of these bytes.
+ // prefix[c] is the code for all but the last byte.
+ // This code can either be a literal code or another code in [lo, c).
+ // The c == hi case is a special case.
+ suffix [1 << maxWidth]uint8
+ prefix [1 << maxWidth]uint16
+
+ // output is the temporary output buffer.
+ // Literal codes are accumulated from the start of the buffer.
+ // Non-literal codes decode to a sequence of suffixes that are first
+ // written right-to-left from the end of the buffer before being copied
+ // to the start of the buffer.
+ // It is flushed when it contains >= 1<<maxWidth bytes,
+ // so that there is always room to decode an entire code.
+ output [2 * 1 << maxWidth]byte
+ o int // write index into output
+ toRead []byte // bytes to return from Read
+}
+
+// readLSB returns the next code for "Least Significant Bits first" data.
+func (r *Reader) readLSB() (uint16, error) {
+ for r.nBits < r.width {
+ x, err := r.r.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ r.bits |= uint32(x) << r.nBits
+ r.nBits += 8
+ }
+ code := uint16(r.bits & (1<<r.width - 1))
+ r.bits >>= r.width
+ r.nBits -= r.width
+ return code, nil
+}
+
+// readMSB returns the next code for "Most Significant Bits first" data.
+func (r *Reader) readMSB() (uint16, error) {
+ for r.nBits < r.width {
+ x, err := r.r.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ r.bits |= uint32(x) << (24 - r.nBits)
+ r.nBits += 8
+ }
+ code := uint16(r.bits >> (32 - r.width))
+ r.bits <<= r.width
+ r.nBits -= r.width
+ return code, nil
+}
+
+// Read implements io.Reader, reading uncompressed bytes from its underlying Reader.
+func (r *Reader) Read(b []byte) (int, error) {
+ for {
+ if len(r.toRead) > 0 {
+ n := copy(b, r.toRead)
+ r.toRead = r.toRead[n:]
+ return n, nil
+ }
+ if r.err != nil {
+ return 0, r.err
+ }
+ r.decode()
+ }
+}
+
+// decode decompresses bytes from r and leaves them in d.toRead.
+// read specifies how to decode bytes into codes.
+// litWidth is the width in bits of literal codes.
+func (r *Reader) decode() {
+ // Loop over the code stream, converting codes into decompressed bytes.
+loop:
+ for {
+ code, err := r.read(r)
+ if err != nil {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ r.err = err
+ break
+ }
+ switch {
+ case code < r.clear:
+ // We have a literal code.
+ r.output[r.o] = uint8(code)
+ r.o++
+ if r.last != decoderInvalidCode {
+ // Save what the hi code expands to.
+ r.suffix[r.hi] = uint8(code)
+ r.prefix[r.hi] = r.last
+ }
+ case code == r.clear:
+ r.width = 1 + uint(r.litWidth)
+ r.hi = r.eof
+ r.overflow = 1 << r.width
+ r.last = decoderInvalidCode
+ continue
+ case code == r.eof:
+ r.err = io.EOF
+ break loop
+ case code <= r.hi:
+ c, i := code, len(r.output)-1
+ if code == r.hi && r.last != decoderInvalidCode {
+ // code == hi is a special case which expands to the last expansion
+ // followed by the head of the last expansion. To find the head, we walk
+ // the prefix chain until we find a literal code.
+ c = r.last
+ for c >= r.clear {
+ c = r.prefix[c]
+ }
+ r.output[i] = uint8(c)
+ i--
+ c = r.last
+ }
+ // Copy the suffix chain into output and then write that to w.
+ for c >= r.clear {
+ r.output[i] = r.suffix[c]
+ i--
+ c = r.prefix[c]
+ }
+ r.output[i] = uint8(c)
+ r.o += copy(r.output[r.o:], r.output[i:])
+ if r.last != decoderInvalidCode {
+ // Save what the hi code expands to.
+ r.suffix[r.hi] = uint8(c)
+ r.prefix[r.hi] = r.last
+ }
+ default:
+ r.err = errors.New("lzw: invalid code")
+ break loop
+ }
+ r.last, r.hi = code, r.hi+1
+ if r.hi >= r.overflow {
+ if r.hi > r.overflow {
+ panic("unreachable")
+ }
+ if r.width == maxWidth {
+ r.last = decoderInvalidCode
+ // Undo the d.hi++ a few lines above, so that (1) we maintain
+ // the invariant that d.hi < d.overflow, and (2) d.hi does not
+ // eventually overflow a uint16.
+ r.hi--
+ } else {
+ r.width++
+ r.overflow = 1 << r.width
+ }
+ }
+ if r.o >= flushBuffer {
+ break
+ }
+ }
+ // Flush pending output.
+ r.toRead = r.output[:r.o]
+ r.o = 0
+}
+
+var errClosed = errors.New("lzw: reader/writer is closed")
+
+// Close closes the Reader and returns an error for any future read operation.
+// It does not close the underlying io.Reader.
+func (r *Reader) Close() error {
+ r.err = errClosed // in case any Reads come along
+ return nil
+}
+
+// Reset clears the Reader's state and allows it to be reused again
+// as a new Reader.
+func (r *Reader) Reset(src io.Reader, order Order, litWidth int) {
+ *r = Reader{}
+ r.init(src, order, litWidth)
+}
+
+// NewReader creates a new io.ReadCloser.
+// Reads from the returned io.ReadCloser read and decompress data from r.
+// If r does not also implement io.ByteReader,
+// the decompressor may read more data than necessary from r.
+// It is the caller's responsibility to call Close on the ReadCloser when
+// finished reading.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8. It must equal the litWidth
+// used during compression.
+//
+// It is guaranteed that the underlying type of the returned io.ReadCloser
+// is a *Reader.
+func NewReader(r io.Reader, order Order, litWidth int) io.ReadCloser {
+ return newReader(r, order, litWidth)
+}
+
+func newReader(src io.Reader, order Order, litWidth int) *Reader {
+ r := new(Reader)
+ r.init(src, order, litWidth)
+ return r
+}
+
+func (r *Reader) init(src io.Reader, order Order, litWidth int) {
+ switch order {
+ case LSB:
+ r.read = (*Reader).readLSB
+ case MSB:
+ r.read = (*Reader).readMSB
+ default:
+ r.err = errors.New("lzw: unknown order")
+ return
+ }
+ if litWidth < 2 || 8 < litWidth {
+ r.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+ return
+ }
+
+ br, ok := src.(io.ByteReader)
+ if !ok && src != nil {
+ br = bufio.NewReader(src)
+ }
+ r.r = br
+ r.litWidth = litWidth
+ r.width = 1 + uint(litWidth)
+ r.clear = uint16(1) << uint(litWidth)
+ r.eof, r.hi = r.clear+1, r.clear+1
+ r.overflow = uint16(1) << r.width
+ r.last = decoderInvalidCode
+}
diff --git a/src/compress/lzw/reader_test.go b/src/compress/lzw/reader_test.go
new file mode 100644
index 0000000..9a2a477
--- /dev/null
+++ b/src/compress/lzw/reader_test.go
@@ -0,0 +1,315 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math"
+ "os"
+ "runtime"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+type lzwTest struct {
+ desc string
+ raw string
+ compressed string
+ err error
+}
+
+var lzwTests = []lzwTest{
+ {
+ "empty;LSB;8",
+ "",
+ "\x01\x01",
+ nil,
+ },
+ {
+ "empty;MSB;8",
+ "",
+ "\x80\x80",
+ nil,
+ },
+ {
+ "tobe;LSB;7",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
+ nil,
+ },
+ {
+ "tobe;LSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04\x12\x34\xb8\xb0\xe0\xc1\x84\x01\x01",
+ nil,
+ },
+ {
+ "tobe;MSB;7",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x4f\x42\x45\x4f\x52\x4e\x4f\x54\x82\x84\x86\x8b\x85\x87\x89\x81",
+ nil,
+ },
+ {
+ "tobe;MSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x2a\x13\xc8\x44\x52\x79\x48\x9c\x4f\x2a\x40\xa0\x90\x68\x5c\x16\x0f\x09\x80\x80",
+ nil,
+ },
+ {
+ "tobe-truncated;LSB;8",
+ "TOBEORNOTTOBEORTOBEORNOT",
+ "\x54\x9e\x08\x29\xf2\x44\x8a\x93\x27\x54\x04",
+ io.ErrUnexpectedEOF,
+ },
+ // This example comes from https://en.wikipedia.org/wiki/Graphics_Interchange_Format.
+ {
+ "gif;LSB;8",
+ "\x28\xff\xff\xff\x28\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff",
+ "\x00\x51\xfc\x1b\x28\x70\xa0\xc1\x83\x01\x01",
+ nil,
+ },
+ // This example comes from http://compgroups.net/comp.lang.ruby/Decompressing-LZW-compression-from-PDF-file
+ {
+ "pdf;MSB;8",
+ "-----A---B",
+ "\x80\x0b\x60\x50\x22\x0c\x0c\x85\x01",
+ nil,
+ },
+}
+
+func TestReader(t *testing.T) {
+ var b bytes.Buffer
+ for _, tt := range lzwTests {
+ d := strings.Split(tt.desc, ";")
+ var order Order
+ switch d[1] {
+ case "LSB":
+ order = LSB
+ case "MSB":
+ order = MSB
+ default:
+ t.Errorf("%s: bad order %q", tt.desc, d[1])
+ }
+ litWidth, _ := strconv.Atoi(d[2])
+ rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
+ defer rc.Close()
+ b.Reset()
+ n, err := io.Copy(&b, rc)
+ s := b.String()
+ if err != nil {
+ if err != tt.err {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
+ }
+ if err == io.ErrUnexpectedEOF {
+ // Even if the input is truncated, we should still return the
+ // partial decoded result.
+ if n == 0 || !strings.HasPrefix(tt.raw, s) {
+ t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, s, tt.raw)
+ }
+ }
+ continue
+ }
+ if s != tt.raw {
+ t.Errorf("%s: got %d-byte %q want %d-byte %q", tt.desc, n, s, len(tt.raw), tt.raw)
+ }
+ }
+}
+
+func TestReaderReset(t *testing.T) {
+ var b bytes.Buffer
+ for _, tt := range lzwTests {
+ d := strings.Split(tt.desc, ";")
+ var order Order
+ switch d[1] {
+ case "LSB":
+ order = LSB
+ case "MSB":
+ order = MSB
+ default:
+ t.Errorf("%s: bad order %q", tt.desc, d[1])
+ }
+ litWidth, _ := strconv.Atoi(d[2])
+ rc := NewReader(strings.NewReader(tt.compressed), order, litWidth)
+ defer rc.Close()
+ b.Reset()
+ n, err := io.Copy(&b, rc)
+ b1 := b.Bytes()
+ if err != nil {
+ if err != tt.err {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, tt.err)
+ }
+ if err == io.ErrUnexpectedEOF {
+ // Even if the input is truncated, we should still return the
+ // partial decoded result.
+ if n == 0 || !strings.HasPrefix(tt.raw, b.String()) {
+ t.Errorf("got %d bytes (%q), want a non-empty prefix of %q", n, b.String(), tt.raw)
+ }
+ }
+ continue
+ }
+
+ b.Reset()
+ rc.(*Reader).Reset(strings.NewReader(tt.compressed), order, litWidth)
+ n, err = io.Copy(&b, rc)
+ b2 := b.Bytes()
+ if err != nil {
+ t.Errorf("%s: io.Copy: %v want %v", tt.desc, err, nil)
+ continue
+ }
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("bytes read were not the same")
+ }
+ }
+}
+
+type devZero struct{}
+
+func (devZero) Read(p []byte) (int, error) {
+ for i := range p {
+ p[i] = 0
+ }
+ return len(p), nil
+}
+
+func TestHiCodeDoesNotOverflow(t *testing.T) {
+ r := NewReader(devZero{}, LSB, 8)
+ d := r.(*Reader)
+ buf := make([]byte, 1024)
+ oldHi := uint16(0)
+ for i := 0; i < 100; i++ {
+ if _, err := io.ReadFull(r, buf); err != nil {
+ t.Fatalf("i=%d: %v", i, err)
+ }
+ // The hi code should never decrease.
+ if d.hi < oldHi {
+ t.Fatalf("i=%d: hi=%d decreased from previous value %d", i, d.hi, oldHi)
+ }
+ oldHi = d.hi
+ }
+}
+
+// TestNoLongerSavingPriorExpansions tests the decoder state when codes other
+// than clear codes continue to be seen after decoder.hi and decoder.width
+// reach their maximum values (4095 and 12), i.e. after we no longer save prior
+// expansions. In particular, it tests seeing the highest possible code, 4095.
+func TestNoLongerSavingPriorExpansions(t *testing.T) {
+ // Iterations is used to calculate how many input bits are needed to get
+ // the decoder.hi and decoder.width values up to their maximum.
+ iterations := []struct {
+ width, n int
+ }{
+ // The final term is 257, not 256, as NewReader initializes d.hi to
+ // d.clear+1 and the clear code is 256.
+ {9, 512 - 257},
+ {10, 1024 - 512},
+ {11, 2048 - 1024},
+ {12, 4096 - 2048},
+ }
+ nCodes, nBits := 0, 0
+ for _, e := range iterations {
+ nCodes += e.n
+ nBits += e.n * e.width
+ }
+ if nCodes != 3839 {
+ t.Fatalf("nCodes: got %v, want %v", nCodes, 3839)
+ }
+ if nBits != 43255 {
+ t.Fatalf("nBits: got %v, want %v", nBits, 43255)
+ }
+
+ // Construct our input of 43255 zero bits (which gets d.hi and d.width up
+ // to 4095 and 12), followed by 0xfff (4095) as 12 bits, followed by 0x101
+ // (EOF) as 12 bits.
+ //
+ // 43255 = 5406*8 + 7, and codes are read in LSB order. The final bytes are
+ // therefore:
+ //
+ // xwwwwwww xxxxxxxx yyyyyxxx zyyyyyyy
+ // 10000000 11111111 00001111 00001000
+ //
+ // or split out:
+ //
+ // .0000000 ........ ........ ........ w = 0x000
+ // 1....... 11111111 .....111 ........ x = 0xfff
+ // ........ ........ 00001... .0001000 y = 0x101
+ //
+ // The 12 'w' bits (not all are shown) form the 3839'th code, with value
+ // 0x000. Just after decoder.read returns that code, d.hi == 4095 and
+ // d.last == 0.
+ //
+ // The 12 'x' bits form the 3840'th code, with value 0xfff or 4095. Just
+ // after decoder.read returns that code, d.hi == 4095 and d.last ==
+ // decoderInvalidCode.
+ //
+ // The 12 'y' bits form the 3841'st code, with value 0x101, the EOF code.
+ //
+ // The 'z' bit is unused.
+ in := make([]byte, 5406)
+ in = append(in, 0x80, 0xff, 0x0f, 0x08)
+
+ r := NewReader(bytes.NewReader(in), LSB, 8)
+ nDecoded, err := io.Copy(io.Discard, r)
+ if err != nil {
+ t.Fatalf("Copy: %v", err)
+ }
+ // nDecoded should be 3841: 3839 literal codes and then 2 decoded bytes
+ // from 1 non-literal code. The EOF code contributes 0 decoded bytes.
+ if nDecoded != int64(nCodes+2) {
+ t.Fatalf("nDecoded: got %v, want %v", nDecoded, nCodes+2)
+ }
+}
+
+func BenchmarkDecoder(b *testing.B) {
+ buf, err := os.ReadFile("../testdata/e.txt")
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(buf) == 0 {
+ b.Fatalf("test file has no data")
+ }
+
+ getInputBuf := func(buf []byte, n int) []byte {
+ compressed := new(bytes.Buffer)
+ w := NewWriter(compressed, LSB, 8)
+ for i := 0; i < n; i += len(buf) {
+ if len(buf) > n-i {
+ buf = buf[:n-i]
+ }
+ w.Write(buf)
+ }
+ w.Close()
+ return compressed.Bytes()
+ }
+
+ for e := 4; e <= 6; e++ {
+ n := int(math.Pow10(e))
+ b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
+ b.StopTimer()
+ b.SetBytes(int64(n))
+ buf1 := getInputBuf(buf, n)
+ runtime.GC()
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ io.Copy(io.Discard, NewReader(bytes.NewReader(buf1), LSB, 8))
+ }
+ })
+ b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
+ b.StopTimer()
+ b.SetBytes(int64(n))
+ buf1 := getInputBuf(buf, n)
+ runtime.GC()
+ b.StartTimer()
+ r := NewReader(bytes.NewReader(buf1), LSB, 8)
+ for i := 0; i < b.N; i++ {
+ io.Copy(io.Discard, r)
+ r.Close()
+ r.(*Reader).Reset(bytes.NewReader(buf1), LSB, 8)
+ }
+ })
+ }
+}
diff --git a/src/compress/lzw/writer.go b/src/compress/lzw/writer.go
new file mode 100644
index 0000000..cf06ea8
--- /dev/null
+++ b/src/compress/lzw/writer.go
@@ -0,0 +1,293 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// A writer is a buffered, flushable writer.
+type writer interface {
+ io.ByteWriter
+ Flush() error
+}
+
+const (
+ // A code is a 12 bit value, stored as a uint32 when encoding to avoid
+ // type conversions when shifting bits.
+ maxCode = 1<<12 - 1
+ invalidCode = 1<<32 - 1
+ // There are 1<<12 possible codes, which is an upper bound on the number of
+ // valid hash table entries at any given point in time. tableSize is 4x that.
+ tableSize = 4 * 1 << 12
+ tableMask = tableSize - 1
+ // A hash table entry is a uint32. Zero is an invalid entry since the
+ // lower 12 bits of a valid entry must be a non-literal code.
+ invalidEntry = 0
+)
+
+// Writer is an LZW compressor. It writes the compressed form of the data
+// to an underlying writer (see NewWriter).
+type Writer struct {
+ // w is the writer that compressed bytes are written to.
+ w writer
+ // order, write, bits, nBits and width are the state for
+ // converting a code stream into a byte stream.
+ order Order
+ write func(*Writer, uint32) error
+ bits uint32
+ nBits uint
+ width uint
+ // litWidth is the width in bits of literal codes.
+ litWidth uint
+ // hi is the code implied by the next code emission.
+ // overflow is the code at which hi overflows the code width.
+ hi, overflow uint32
+ // savedCode is the accumulated code at the end of the most recent Write
+ // call. It is equal to invalidCode if there was no such call.
+ savedCode uint32
+ // err is the first error encountered during writing. Closing the writer
+ // will make any future Write calls return errClosed
+ err error
+ // table is the hash table from 20-bit keys to 12-bit values. Each table
+ // entry contains key<<12|val and collisions resolve by linear probing.
+ // The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
+ // The values are a 12-bit code.
+ table [tableSize]uint32
+}
+
+// writeLSB writes the code c for "Least Significant Bits first" data.
+func (w *Writer) writeLSB(c uint32) error {
+ w.bits |= c << w.nBits
+ w.nBits += w.width
+ for w.nBits >= 8 {
+ if err := w.w.WriteByte(uint8(w.bits)); err != nil {
+ return err
+ }
+ w.bits >>= 8
+ w.nBits -= 8
+ }
+ return nil
+}
+
+// writeMSB writes the code c for "Most Significant Bits first" data.
+func (w *Writer) writeMSB(c uint32) error {
+ w.bits |= c << (32 - w.width - w.nBits)
+ w.nBits += w.width
+ for w.nBits >= 8 {
+ if err := w.w.WriteByte(uint8(w.bits >> 24)); err != nil {
+ return err
+ }
+ w.bits <<= 8
+ w.nBits -= 8
+ }
+ return nil
+}
+
+// errOutOfCodes is an internal error that means that the writer has run out
+// of unused codes and a clear code needs to be sent next.
+var errOutOfCodes = errors.New("lzw: out of codes")
+
+// incHi increments e.hi and checks for both overflow and running out of
+// unused codes. In the latter case, incHi sends a clear code, resets the
+// writer state and returns errOutOfCodes.
+func (w *Writer) incHi() error {
+ w.hi++
+ if w.hi == w.overflow {
+ w.width++
+ w.overflow <<= 1
+ }
+ if w.hi == maxCode {
+ clear := uint32(1) << w.litWidth
+ if err := w.write(w, clear); err != nil {
+ return err
+ }
+ w.width = w.litWidth + 1
+ w.hi = clear + 1
+ w.overflow = clear << 1
+ for i := range w.table {
+ w.table[i] = invalidEntry
+ }
+ return errOutOfCodes
+ }
+ return nil
+}
+
+// Write writes a compressed representation of p to w's underlying writer.
+func (w *Writer) Write(p []byte) (n int, err error) {
+ if w.err != nil {
+ return 0, w.err
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+ if maxLit := uint8(1<<w.litWidth - 1); maxLit != 0xff {
+ for _, x := range p {
+ if x > maxLit {
+ w.err = errors.New("lzw: input byte too large for the litWidth")
+ return 0, w.err
+ }
+ }
+ }
+ n = len(p)
+ code := w.savedCode
+ if code == invalidCode {
+ // This is the first write; send a clear code.
+ // https://www.w3.org/Graphics/GIF/spec-gif89a.txt Appendix F
+ // "Variable-Length-Code LZW Compression" says that "Encoders should
+ // output a Clear code as the first code of each image data stream".
+ //
+ // LZW compression isn't only used by GIF, but it's cheap to follow
+ // that directive unconditionally.
+ clear := uint32(1) << w.litWidth
+ if err := w.write(w, clear); err != nil {
+ return 0, err
+ }
+ // After the starting clear code, the next code sent (for non-empty
+ // input) is always a literal code.
+ code, p = uint32(p[0]), p[1:]
+ }
+loop:
+ for _, x := range p {
+ literal := uint32(x)
+ key := code<<8 | literal
+ // If there is a hash table hit for this key then we continue the loop
+ // and do not emit a code yet.
+ hash := (key>>12 ^ key) & tableMask
+ for h, t := hash, w.table[hash]; t != invalidEntry; {
+ if key == t>>12 {
+ code = t & maxCode
+ continue loop
+ }
+ h = (h + 1) & tableMask
+ t = w.table[h]
+ }
+ // Otherwise, write the current code, and literal becomes the start of
+ // the next emitted code.
+ if w.err = w.write(w, code); w.err != nil {
+ return 0, w.err
+ }
+ code = literal
+ // Increment e.hi, the next implied code. If we run out of codes, reset
+ // the writer state (including clearing the hash table) and continue.
+ if err1 := w.incHi(); err1 != nil {
+ if err1 == errOutOfCodes {
+ continue
+ }
+ w.err = err1
+ return 0, w.err
+ }
+ // Otherwise, insert key -> e.hi into the map that e.table represents.
+ for {
+ if w.table[hash] == invalidEntry {
+ w.table[hash] = (key << 12) | w.hi
+ break
+ }
+ hash = (hash + 1) & tableMask
+ }
+ }
+ w.savedCode = code
+ return n, nil
+}
+
+// Close closes the Writer, flushing any pending output. It does not close
+// w's underlying writer.
+func (w *Writer) Close() error {
+ if w.err != nil {
+ if w.err == errClosed {
+ return nil
+ }
+ return w.err
+ }
+ // Make any future calls to Write return errClosed.
+ w.err = errClosed
+ // Write the savedCode if valid.
+ if w.savedCode != invalidCode {
+ if err := w.write(w, w.savedCode); err != nil {
+ return err
+ }
+ if err := w.incHi(); err != nil && err != errOutOfCodes {
+ return err
+ }
+ } else {
+ // Write the starting clear code, as w.Write did not.
+ clear := uint32(1) << w.litWidth
+ if err := w.write(w, clear); err != nil {
+ return err
+ }
+ }
+ // Write the eof code.
+ eof := uint32(1)<<w.litWidth + 1
+ if err := w.write(w, eof); err != nil {
+ return err
+ }
+ // Write the final bits.
+ if w.nBits > 0 {
+ if w.order == MSB {
+ w.bits >>= 24
+ }
+ if err := w.w.WriteByte(uint8(w.bits)); err != nil {
+ return err
+ }
+ }
+ return w.w.Flush()
+}
+
+// Reset clears the Writer's state and allows it to be reused again
+// as a new Writer.
+func (w *Writer) Reset(dst io.Writer, order Order, litWidth int) {
+ *w = Writer{}
+ w.init(dst, order, litWidth)
+}
+
+// NewWriter creates a new io.WriteCloser.
+// Writes to the returned io.WriteCloser are compressed and written to w.
+// It is the caller's responsibility to call Close on the WriteCloser when
+// finished writing.
+// The number of bits to use for literal codes, litWidth, must be in the
+// range [2,8] and is typically 8. Input bytes must be less than 1<<litWidth.
+//
+// It is guaranteed that the underlying type of the returned io.WriteCloser
+// is a *Writer.
+func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
+ return newWriter(w, order, litWidth)
+}
+
+func newWriter(dst io.Writer, order Order, litWidth int) *Writer {
+ w := new(Writer)
+ w.init(dst, order, litWidth)
+ return w
+}
+
+func (w *Writer) init(dst io.Writer, order Order, litWidth int) {
+ switch order {
+ case LSB:
+ w.write = (*Writer).writeLSB
+ case MSB:
+ w.write = (*Writer).writeMSB
+ default:
+ w.err = errors.New("lzw: unknown order")
+ return
+ }
+ if litWidth < 2 || 8 < litWidth {
+ w.err = fmt.Errorf("lzw: litWidth %d out of range", litWidth)
+ return
+ }
+ bw, ok := dst.(writer)
+ if !ok && dst != nil {
+ bw = bufio.NewWriter(dst)
+ }
+ w.w = bw
+ lw := uint(litWidth)
+ w.order = order
+ w.width = 1 + lw
+ w.litWidth = lw
+ w.hi = 1<<lw + 1
+ w.overflow = 1 << (lw + 1)
+ w.savedCode = invalidCode
+}
diff --git a/src/compress/lzw/writer_test.go b/src/compress/lzw/writer_test.go
new file mode 100644
index 0000000..edf683a
--- /dev/null
+++ b/src/compress/lzw/writer_test.go
@@ -0,0 +1,238 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package lzw
+
+import (
+ "bytes"
+ "fmt"
+ "internal/testenv"
+ "io"
+ "math"
+ "os"
+ "runtime"
+ "testing"
+)
+
+var filenames = []string{
+ "../testdata/gettysburg.txt",
+ "../testdata/e.txt",
+ "../testdata/pi.txt",
+}
+
+// testFile tests that compressing and then decompressing the given file with
+// the given options yields equivalent bytes to the original file.
+func testFile(t *testing.T, fn string, order Order, litWidth int) {
+ // Read the file, as golden output.
+ golden, err := os.Open(fn)
+ if err != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+ return
+ }
+ defer golden.Close()
+
+ // Read the file again, and push it through a pipe that compresses at the write end, and decompresses at the read end.
+ raw, err := os.Open(fn)
+ if err != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err)
+ return
+ }
+
+ piper, pipew := io.Pipe()
+ defer piper.Close()
+ go func() {
+ defer raw.Close()
+ defer pipew.Close()
+ lzww := NewWriter(pipew, order, litWidth)
+ defer lzww.Close()
+ var b [4096]byte
+ for {
+ n, err0 := raw.Read(b[:])
+ if err0 != nil && err0 != io.EOF {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+ return
+ }
+ _, err1 := lzww.Write(b[:n])
+ if err1 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+ return
+ }
+ if err0 == io.EOF {
+ break
+ }
+ }
+ }()
+ lzwr := NewReader(piper, order, litWidth)
+ defer lzwr.Close()
+
+ // Compare the two.
+ b0, err0 := io.ReadAll(golden)
+ b1, err1 := io.ReadAll(lzwr)
+ if err0 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err0)
+ return
+ }
+ if err1 != nil {
+ t.Errorf("%s (order=%d litWidth=%d): %v", fn, order, litWidth, err1)
+ return
+ }
+ if len(b1) != len(b0) {
+ t.Errorf("%s (order=%d litWidth=%d): length mismatch %d != %d", fn, order, litWidth, len(b1), len(b0))
+ return
+ }
+ for i := 0; i < len(b0); i++ {
+ if b1[i] != b0[i] {
+ t.Errorf("%s (order=%d litWidth=%d): mismatch at %d, 0x%02x != 0x%02x\n", fn, order, litWidth, i, b1[i], b0[i])
+ return
+ }
+ }
+}
+
+func TestWriter(t *testing.T) {
+ for _, filename := range filenames {
+ for _, order := range [...]Order{LSB, MSB} {
+ // The test data "2.71828 etcetera" is ASCII text requiring at least 6 bits.
+ for litWidth := 6; litWidth <= 8; litWidth++ {
+ if filename == "../testdata/gettysburg.txt" && litWidth == 6 {
+ continue
+ }
+ testFile(t, filename, order, litWidth)
+ }
+ }
+ if testing.Short() && testenv.Builder() == "" {
+ break
+ }
+ }
+}
+
+func TestWriterReset(t *testing.T) {
+ for _, order := range [...]Order{LSB, MSB} {
+ t.Run(fmt.Sprintf("Order %d", order), func(t *testing.T) {
+ for litWidth := 6; litWidth <= 8; litWidth++ {
+ t.Run(fmt.Sprintf("LitWidth %d", litWidth), func(t *testing.T) {
+ var data []byte
+ if litWidth == 6 {
+ data = []byte{1, 2, 3}
+ } else {
+ data = []byte(`lorem ipsum dolor sit amet`)
+ }
+ var buf bytes.Buffer
+ w := NewWriter(&buf, order, litWidth)
+ if _, err := w.Write(data); err != nil {
+ t.Errorf("write: %v: %v", string(data), err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Errorf("close: %v", err)
+ }
+
+ b1 := buf.Bytes()
+ buf.Reset()
+
+ w.(*Writer).Reset(&buf, order, litWidth)
+
+ if _, err := w.Write(data); err != nil {
+ t.Errorf("write: %v: %v", string(data), err)
+ }
+
+ if err := w.Close(); err != nil {
+ t.Errorf("close: %v", err)
+ }
+ b2 := buf.Bytes()
+
+ if !bytes.Equal(b1, b2) {
+ t.Errorf("bytes written were not same")
+ }
+ })
+ }
+ })
+ }
+}
+
+func TestWriterReturnValues(t *testing.T) {
+ w := NewWriter(io.Discard, LSB, 8)
+ n, err := w.Write([]byte("asdf"))
+ if n != 4 || err != nil {
+ t.Errorf("got %d, %v, want 4, nil", n, err)
+ }
+}
+
+func TestSmallLitWidth(t *testing.T) {
+ w := NewWriter(io.Discard, LSB, 2)
+ if _, err := w.Write([]byte{0x03}); err != nil {
+ t.Fatalf("write a byte < 1<<2: %v", err)
+ }
+ if _, err := w.Write([]byte{0x04}); err == nil {
+ t.Fatal("write a byte >= 1<<2: got nil error, want non-nil")
+ }
+}
+
+func TestStartsWithClearCode(t *testing.T) {
+ // A literal width of 7 bits means that the code width starts at 8 bits,
+ // which makes it easier to visually inspect the output (provided that the
+ // output is short so codes don't get longer). Each byte is a code:
+ // - ASCII bytes are literal codes,
+ // - 0x80 is the clear code,
+ // - 0x81 is the end code.
+ // - 0x82 and above are copy codes (unused in this test case).
+ for _, empty := range []bool{false, true} {
+ var buf bytes.Buffer
+ w := NewWriter(&buf, LSB, 7)
+ if !empty {
+ w.Write([]byte("Hi"))
+ }
+ w.Close()
+ got := buf.String()
+
+ want := "\x80\x81"
+ if !empty {
+ want = "\x80Hi\x81"
+ }
+
+ if got != want {
+ t.Errorf("empty=%t: got %q, want %q", empty, got, want)
+ }
+ }
+}
+
+func BenchmarkEncoder(b *testing.B) {
+ buf, err := os.ReadFile("../testdata/e.txt")
+ if err != nil {
+ b.Fatal(err)
+ }
+ if len(buf) == 0 {
+ b.Fatalf("test file has no data")
+ }
+
+ for e := 4; e <= 6; e++ {
+ n := int(math.Pow10(e))
+ buf0 := buf
+ buf1 := make([]byte, n)
+ for i := 0; i < n; i += len(buf0) {
+ if len(buf0) > n-i {
+ buf0 = buf0[:n-i]
+ }
+ copy(buf1[i:], buf0)
+ }
+ buf0 = nil
+ runtime.GC()
+ b.Run(fmt.Sprint("1e", e), func(b *testing.B) {
+ b.SetBytes(int64(n))
+ for i := 0; i < b.N; i++ {
+ w := NewWriter(io.Discard, LSB, 8)
+ w.Write(buf1)
+ w.Close()
+ }
+ })
+ b.Run(fmt.Sprint("1e-Reuse", e), func(b *testing.B) {
+ b.SetBytes(int64(n))
+ w := NewWriter(io.Discard, LSB, 8)
+ for i := 0; i < b.N; i++ {
+ w.Write(buf1)
+ w.Close()
+ w.(*Writer).Reset(io.Discard, LSB, 8)
+ }
+ })
+ }
+}