summaryrefslogtreecommitdiffstats
path: root/src/internal/zstd
diff options
context:
space:
mode:
Diffstat (limited to 'src/internal/zstd')
-rw-r--r--src/internal/zstd/bits.go130
-rw-r--r--src/internal/zstd/block.go436
-rw-r--r--src/internal/zstd/fse.go437
-rw-r--r--src/internal/zstd/fse_test.go89
-rw-r--r--src/internal/zstd/fuzz_test.go140
-rw-r--r--src/internal/zstd/huff.go204
-rw-r--r--src/internal/zstd/literals.go330
-rw-r--r--src/internal/zstd/xxhash.go148
-rw-r--r--src/internal/zstd/xxhash_test.go105
-rw-r--r--src/internal/zstd/zstd.go508
-rw-r--r--src/internal/zstd/zstd_test.go249
11 files changed, 2776 insertions, 0 deletions
diff --git a/src/internal/zstd/bits.go b/src/internal/zstd/bits.go
new file mode 100644
index 0000000..c9a2f70
--- /dev/null
+++ b/src/internal/zstd/bits.go
@@ -0,0 +1,130 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "math/bits"
+)
+
+// block is the data for a single compressed block.
+// The data starts immediately after the 3 byte block header,
+// and is Block_Size bytes long.
+type block []byte
+
+// bitReader reads a bit stream going forward.
+type bitReader struct {
+ r *Reader // for error reporting
+ data block // the bits to read
+ off uint32 // current offset into data
+ bits uint32 // bits ready to be returned
+ cnt uint32 // number of valid bits in the bits field
+}
+
+// makeBitReader makes a bit reader starting at off.
+func (r *Reader) makeBitReader(data block, off int) bitReader {
+ return bitReader{
+ r: r,
+ data: data,
+ off: uint32(off),
+ }
+}
+
+// moreBits is called to read more bits.
+// This ensures that at least 16 bits are available.
+func (br *bitReader) moreBits() error {
+ for br.cnt < 16 {
+ if br.off >= uint32(len(br.data)) {
+ return br.r.makeEOFError(int(br.off))
+ }
+ c := br.data[br.off]
+ br.off++
+ br.bits |= uint32(c) << br.cnt
+ br.cnt += 8
+ }
+ return nil
+}
+
+// val is called to fetch a value of b bits.
+func (br *bitReader) val(b uint8) uint32 {
+ r := br.bits & ((1 << b) - 1)
+ br.bits >>= b
+ br.cnt -= uint32(b)
+ return r
+}
+
+// backup steps back to the last byte we used.
+func (br *bitReader) backup() {
+ for br.cnt >= 8 {
+ br.off--
+ br.cnt -= 8
+ }
+}
+
+// makeError returns an error at the current offset wrapping a string.
+func (br *bitReader) makeError(msg string) error {
+ return br.r.makeError(int(br.off), msg)
+}
+
+// reverseBitReader reads a bit stream in reverse.
+type reverseBitReader struct {
+ r *Reader // for error reporting
+ data block // the bits to read
+ off uint32 // current offset into data
+ start uint32 // start in data; we read backward to start
+ bits uint32 // bits ready to be returned
+ cnt uint32 // number of valid bits in bits field
+}
+
+// makeReverseBitReader makes a reverseBitReader reading backward
+// from off to start. The bitstream starts with a 1 bit in the last
+// byte, at off.
+func (r *Reader) makeReverseBitReader(data block, off, start int) (reverseBitReader, error) {
+ streamStart := data[off]
+ if streamStart == 0 {
+ return reverseBitReader{}, r.makeError(off, "zero byte at reverse bit stream start")
+ }
+ rbr := reverseBitReader{
+ r: r,
+ data: data,
+ off: uint32(off),
+ start: uint32(start),
+ bits: uint32(streamStart),
+ cnt: uint32(7 - bits.LeadingZeros8(streamStart)),
+ }
+ return rbr, nil
+}
+
+// val is called to fetch a value of b bits.
+func (rbr *reverseBitReader) val(b uint8) (uint32, error) {
+ if !rbr.fetch(b) {
+ return 0, rbr.r.makeEOFError(int(rbr.off))
+ }
+
+ rbr.cnt -= uint32(b)
+ v := (rbr.bits >> rbr.cnt) & ((1 << b) - 1)
+ return v, nil
+}
+
+// fetch is called to ensure that at least b bits are available.
+// It reports false if this can't be done,
+// in which case only rbr.cnt bits are available.
+func (rbr *reverseBitReader) fetch(b uint8) bool {
+ for rbr.cnt < uint32(b) {
+ if rbr.off <= rbr.start {
+ return false
+ }
+ rbr.off--
+ c := rbr.data[rbr.off]
+ rbr.bits <<= 8
+ rbr.bits |= uint32(c)
+ rbr.cnt += 8
+ }
+ return true
+}
+
+// makeError returns an error at the current offset wrapping a string.
+func (rbr *reverseBitReader) makeError(msg string) error {
+ return rbr.r.makeError(int(rbr.off), msg)
+}
diff --git a/src/internal/zstd/block.go b/src/internal/zstd/block.go
new file mode 100644
index 0000000..bd3040c
--- /dev/null
+++ b/src/internal/zstd/block.go
@@ -0,0 +1,436 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "io"
+)
+
+// debug can be set in the source to print debug info using println.
+const debug = false
+
+// compressedBlock decompresses a compressed block, storing the decompressed
+// data in r.buffer. The blockSize argument is the compressed size.
+// RFC 3.1.1.3.
+func (r *Reader) compressedBlock(blockSize int) error {
+ if len(r.compressedBuf) >= blockSize {
+ r.compressedBuf = r.compressedBuf[:blockSize]
+ } else {
+ // We know that blockSize <= 128K,
+ // so this won't allocate an enormous amount.
+ need := blockSize - len(r.compressedBuf)
+ r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
+ }
+
+ if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
+ return r.wrapNonEOFError(0, err)
+ }
+
+ data := block(r.compressedBuf)
+ off := 0
+ r.buffer = r.buffer[:0]
+
+ litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
+ if err != nil {
+ return err
+ }
+ r.literals = litbuf
+
+ off = litoff
+
+ seqCount, off, err := r.initSeqs(data, off)
+ if err != nil {
+ return err
+ }
+
+ if seqCount == 0 {
+ // No sequences, just literals.
+ if off < len(data) {
+ return r.makeError(off, "extraneous data after no sequences")
+ }
+ if len(litbuf) == 0 {
+ return r.makeError(off, "no sequences and no literals")
+ }
+ r.buffer = append(r.buffer, litbuf...)
+ return nil
+ }
+
+ return r.execSeqs(data, off, litbuf, seqCount)
+}
+
+// seqCode is the kind of sequence codes we have to handle.
+type seqCode int
+
+const (
+ seqLiteral seqCode = iota
+ seqOffset
+ seqMatch
+)
+
+// seqCodeInfoData is the information needed to set up seqTables and
+// seqTableBits for a particular kind of sequence code.
+type seqCodeInfoData struct {
+ predefTable []fseBaselineEntry // predefined FSE
+ predefTableBits int // number of bits in predefTable
+ maxSym int // max symbol value in FSE
+ maxBits int // max bits for FSE
+
+ // toBaseline converts from an FSE table to an FSE baseline table.
+ toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
+}
+
+// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
+var seqCodeInfo = [3]seqCodeInfoData{
+ seqLiteral: {
+ predefTable: predefinedLiteralTable[:],
+ predefTableBits: 6,
+ maxSym: 35,
+ maxBits: 9,
+ toBaseline: (*Reader).makeLiteralBaselineFSE,
+ },
+ seqOffset: {
+ predefTable: predefinedOffsetTable[:],
+ predefTableBits: 5,
+ maxSym: 31,
+ maxBits: 8,
+ toBaseline: (*Reader).makeOffsetBaselineFSE,
+ },
+ seqMatch: {
+ predefTable: predefinedMatchTable[:],
+ predefTableBits: 6,
+ maxSym: 52,
+ maxBits: 9,
+ toBaseline: (*Reader).makeMatchBaselineFSE,
+ },
+}
+
+// initSeqs reads the Sequences_Section_Header and sets up the FSE
+// tables used to read the sequence codes. It returns the number of
+// sequences and the new offset. RFC 3.1.1.3.2.1.
+func (r *Reader) initSeqs(data block, off int) (int, int, error) {
+ if off >= len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+
+ seqHdr := data[off]
+ off++
+ if seqHdr == 0 {
+ return 0, off, nil
+ }
+
+ var seqCount int
+ if seqHdr < 128 {
+ seqCount = int(seqHdr)
+ } else if seqHdr < 255 {
+ if off >= len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+ seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
+ off++
+ } else {
+ if off+1 >= len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+ seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
+ off += 2
+ }
+
+ // Read the Symbol_Compression_Modes byte.
+
+ if off >= len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+ symMode := data[off]
+ if symMode&3 != 0 {
+ return 0, 0, r.makeError(off, "invalid symbol compression mode")
+ }
+ off++
+
+ // Set up the FSE tables used to decode the sequence codes.
+
+ var err error
+ off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ return seqCount, off, nil
+}
+
+// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
+// r.seqTableBits for kind. We store these in the Reader because one of
+// the modes simply reuses the value from the last block in the frame.
+func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
+ info := &seqCodeInfo[kind]
+ switch mode {
+ case 0:
+ // Predefined_Mode
+ r.seqTables[kind] = info.predefTable
+ r.seqTableBits[kind] = uint8(info.predefTableBits)
+ return off, nil
+
+ case 1:
+ // RLE_Mode
+ if off >= len(data) {
+ return 0, r.makeEOFError(off)
+ }
+ rle := data[off]
+ off++
+
+ // Build a simple baseline table that always returns rle.
+
+ entry := []fseEntry{
+ {
+ sym: rle,
+ bits: 0,
+ base: 0,
+ },
+ }
+ if cap(r.seqTableBuffers[kind]) == 0 {
+ r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
+ }
+ r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
+ if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
+ return 0, err
+ }
+
+ r.seqTables[kind] = r.seqTableBuffers[kind]
+ r.seqTableBits[kind] = 0
+ return off, nil
+
+ case 2:
+ // FSE_Compressed_Mode
+ if cap(r.fseScratch) < 1<<info.maxBits {
+ r.fseScratch = make([]fseEntry, 1<<info.maxBits)
+ }
+ r.fseScratch = r.fseScratch[:1<<info.maxBits]
+
+ tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
+ if err != nil {
+ return 0, err
+ }
+ r.fseScratch = r.fseScratch[:1<<tableBits]
+
+ if cap(r.seqTableBuffers[kind]) == 0 {
+ r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
+ }
+ r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
+
+ if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
+ return 0, err
+ }
+
+ r.seqTables[kind] = r.seqTableBuffers[kind]
+ r.seqTableBits[kind] = uint8(tableBits)
+ return roff, nil
+
+ case 3:
+ // Repeat_Mode
+ if len(r.seqTables[kind]) == 0 {
+ return 0, r.makeError(off, "missing repeat sequence FSE table")
+ }
+ return off, nil
+ }
+ panic("unreachable")
+}
+
+// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
+func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
+ // Set up the initial states for the sequence code readers.
+
+ rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
+ if err != nil {
+ return err
+ }
+
+ literalState, err := rbr.val(r.seqTableBits[seqLiteral])
+ if err != nil {
+ return err
+ }
+
+ offsetState, err := rbr.val(r.seqTableBits[seqOffset])
+ if err != nil {
+ return err
+ }
+
+ matchState, err := rbr.val(r.seqTableBits[seqMatch])
+ if err != nil {
+ return err
+ }
+
+ // Read and perform all the sequences. RFC 3.1.1.4.
+
+ seq := 0
+ for seq < seqCount {
+ if len(r.buffer)+len(litbuf) > 128<<10 {
+ return rbr.makeError("uncompressed size too big")
+ }
+
+ ptoffset := &r.seqTables[seqOffset][offsetState]
+ ptmatch := &r.seqTables[seqMatch][matchState]
+ ptliteral := &r.seqTables[seqLiteral][literalState]
+
+ add, err := rbr.val(ptoffset.basebits)
+ if err != nil {
+ return err
+ }
+ offset := ptoffset.baseline + add
+
+ add, err = rbr.val(ptmatch.basebits)
+ if err != nil {
+ return err
+ }
+ match := ptmatch.baseline + add
+
+ add, err = rbr.val(ptliteral.basebits)
+ if err != nil {
+ return err
+ }
+ literal := ptliteral.baseline + add
+
+ // Handle repeat offsets. RFC 3.1.1.5.
+ // See the comment in makeOffsetBaselineFSE.
+ if ptoffset.basebits > 1 {
+ r.repeatedOffset3 = r.repeatedOffset2
+ r.repeatedOffset2 = r.repeatedOffset1
+ r.repeatedOffset1 = offset
+ } else {
+ if literal == 0 {
+ offset++
+ }
+ switch offset {
+ case 1:
+ offset = r.repeatedOffset1
+ case 2:
+ offset = r.repeatedOffset2
+ r.repeatedOffset2 = r.repeatedOffset1
+ r.repeatedOffset1 = offset
+ case 3:
+ offset = r.repeatedOffset3
+ r.repeatedOffset3 = r.repeatedOffset2
+ r.repeatedOffset2 = r.repeatedOffset1
+ r.repeatedOffset1 = offset
+ case 4:
+ offset = r.repeatedOffset1 - 1
+ r.repeatedOffset3 = r.repeatedOffset2
+ r.repeatedOffset2 = r.repeatedOffset1
+ r.repeatedOffset1 = offset
+ }
+ }
+
+ seq++
+ if seq < seqCount {
+ // Update the states.
+ add, err = rbr.val(ptliteral.bits)
+ if err != nil {
+ return err
+ }
+ literalState = uint32(ptliteral.base) + add
+
+ add, err = rbr.val(ptmatch.bits)
+ if err != nil {
+ return err
+ }
+ matchState = uint32(ptmatch.base) + add
+
+ add, err = rbr.val(ptoffset.bits)
+ if err != nil {
+ return err
+ }
+ offsetState = uint32(ptoffset.base) + add
+ }
+
+ // The next sequence is now in literal, offset, match.
+
+ if debug {
+ println("literal", literal, "offset", offset, "match", match)
+ }
+
+ // Copy literal bytes from litbuf.
+ if literal > uint32(len(litbuf)) {
+ return rbr.makeError("literal byte overflow")
+ }
+ if literal > 0 {
+ r.buffer = append(r.buffer, litbuf[:literal]...)
+ litbuf = litbuf[literal:]
+ }
+
+ if match > 0 {
+ if err := r.copyFromWindow(&rbr, offset, match); err != nil {
+ return err
+ }
+ }
+ }
+
+ if len(litbuf) > 0 {
+ r.buffer = append(r.buffer, litbuf...)
+ }
+
+ if rbr.cnt != 0 {
+ return r.makeError(off, "extraneous data after sequences")
+ }
+
+ return nil
+}
+
+// Copy match bytes from the decoded output, or the window, at offset.
+func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
+ if offset == 0 {
+ return rbr.makeError("invalid zero offset")
+ }
+
+ lenBlock := uint32(len(r.buffer))
+ if lenBlock < offset {
+ lenWindow := uint32(len(r.window))
+ windowOffset := offset - lenBlock
+ if windowOffset > lenWindow {
+ return rbr.makeError("offset past window")
+ }
+ from := lenWindow - windowOffset
+ if from+match <= lenWindow {
+ r.buffer = append(r.buffer, r.window[from:from+match]...)
+ return nil
+ }
+ r.buffer = append(r.buffer, r.window[from:]...)
+ copied := lenWindow - from
+ offset -= copied
+ match -= copied
+
+ if offset == 0 && match > 0 {
+ return rbr.makeError("invalid offset")
+ }
+ }
+
+ from := lenBlock - offset
+ if offset >= match {
+ r.buffer = append(r.buffer, r.buffer[from:from+match]...)
+ return nil
+ }
+
+ // We are being asked to copy data that we are adding to the
+ // buffer in the same copy.
+ for match > 0 {
+ var copy uint32
+ if offset >= match {
+ copy = match
+ } else {
+ copy = offset
+ }
+ r.buffer = append(r.buffer, r.buffer[from:from+copy]...)
+ match -= copy
+ from += copy
+ }
+ return nil
+}
diff --git a/src/internal/zstd/fse.go b/src/internal/zstd/fse.go
new file mode 100644
index 0000000..ea661d4
--- /dev/null
+++ b/src/internal/zstd/fse.go
@@ -0,0 +1,437 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "math/bits"
+)
+
+// fseEntry is one entry in an FSE table.
+type fseEntry struct {
+ sym uint8 // value that this entry records
+ bits uint8 // number of bits to read to determine next state
+ base uint16 // add those bits to this state to get the next state
+}
+
+// readFSE reads an FSE table from data starting at off.
+// maxSym is the maximum symbol value.
+// maxBits is the maximum number of bits permitted for symbols in the table.
+// The FSE is written into table, which must be at least 1<<maxBits in size.
+// This returns the number of bits in the FSE table and the new offset.
+// RFC 4.1.1.
+func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
+ br := r.makeBitReader(data, off)
+ if err := br.moreBits(); err != nil {
+ return 0, 0, err
+ }
+
+ accuracyLog := int(br.val(4)) + 5
+ if accuracyLog > maxBits {
+ return 0, 0, br.makeError("FSE accuracy log too large")
+ }
+
+ // The number of remaining probabilities, plus 1.
+ // This determines the number of bits to be read for the next value.
+ remaining := (1 << accuracyLog) + 1
+
+ // The current difference between small and large values,
+ // which depends on the number of remaining values.
+ // Small values use 1 less bit.
+ threshold := 1 << accuracyLog
+
+ // The number of bits needed to compute threshold.
+ bitsNeeded := accuracyLog + 1
+
+ // The next character value.
+ sym := 0
+
+ // Whether the last count was 0.
+ prev0 := false
+
+ var norm [256]int16
+
+ for remaining > 1 && sym <= maxSym {
+ if err := br.moreBits(); err != nil {
+ return 0, 0, err
+ }
+
+ if prev0 {
+ // Previous count was 0, so there is a 2-bit
+ // repeat flag. If the 2-bit flag is 0b11,
+ // it adds 3 and then there is another repeat flag.
+ zsym := sym
+ for (br.bits & 0xfff) == 0xfff {
+ zsym += 3 * 6
+ br.bits >>= 12
+ br.cnt -= 12
+ if err := br.moreBits(); err != nil {
+ return 0, 0, err
+ }
+ }
+ for (br.bits & 3) == 3 {
+ zsym += 3
+ br.bits >>= 2
+ br.cnt -= 2
+ if err := br.moreBits(); err != nil {
+ return 0, 0, err
+ }
+ }
+
+ // We have at least 14 bits here,
+ // no need to call moreBits
+
+ zsym += int(br.val(2))
+
+ if zsym > maxSym {
+ return 0, 0, br.makeError("FSE symbol index overflow")
+ }
+
+ for ; sym < zsym; sym++ {
+ norm[uint8(sym)] = 0
+ }
+
+ prev0 = false
+ continue
+ }
+
+ max := (2*threshold - 1) - remaining
+ var count int
+ if int(br.bits&uint32(threshold-1)) < max {
+ // A small value.
+ count = int(br.bits & uint32((threshold - 1)))
+ br.bits >>= bitsNeeded - 1
+ br.cnt -= uint32(bitsNeeded - 1)
+ } else {
+ // A large value.
+ count = int(br.bits & uint32((2*threshold - 1)))
+ if count >= threshold {
+ count -= max
+ }
+ br.bits >>= bitsNeeded
+ br.cnt -= uint32(bitsNeeded)
+ }
+
+ count--
+ if count >= 0 {
+ remaining -= count
+ } else {
+ remaining--
+ }
+ if sym >= 256 {
+ return 0, 0, br.makeError("FSE sym overflow")
+ }
+ norm[uint8(sym)] = int16(count)
+ sym++
+
+ prev0 = count == 0
+
+ for remaining < threshold {
+ bitsNeeded--
+ threshold >>= 1
+ }
+ }
+
+ if remaining != 1 {
+ return 0, 0, br.makeError("too many symbols in FSE table")
+ }
+
+ for ; sym <= maxSym; sym++ {
+ norm[uint8(sym)] = 0
+ }
+
+ br.backup()
+
+ if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
+ return 0, 0, err
+ }
+
+ return accuracyLog, int(br.off), nil
+}
+
+// buildFSE builds an FSE decoding table from a list of probabilities.
+// The probabilities are in norm. next is scratch space. The number of bits
+// in the table is tableBits.
+func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
+ tableSize := 1 << tableBits
+ highThreshold := tableSize - 1
+
+ var next [256]uint16
+
+ for i, n := range norm {
+ if n >= 0 {
+ next[uint8(i)] = uint16(n)
+ } else {
+ table[highThreshold].sym = uint8(i)
+ highThreshold--
+ next[uint8(i)] = 1
+ }
+ }
+
+ pos := 0
+ step := (tableSize >> 1) + (tableSize >> 3) + 3
+ mask := tableSize - 1
+ for i, n := range norm {
+ for j := 0; j < int(n); j++ {
+ table[pos].sym = uint8(i)
+ pos = (pos + step) & mask
+ for pos > highThreshold {
+ pos = (pos + step) & mask
+ }
+ }
+ }
+ if pos != 0 {
+ return r.makeError(off, "FSE count error")
+ }
+
+ for i := 0; i < tableSize; i++ {
+ sym := table[i].sym
+ nextState := next[sym]
+ next[sym]++
+
+ if nextState == 0 {
+ return r.makeError(off, "FSE state error")
+ }
+
+ highBit := 15 - bits.LeadingZeros16(nextState)
+
+ bits := tableBits - highBit
+ table[i].bits = uint8(bits)
+ table[i].base = (nextState << bits) - uint16(tableSize)
+ }
+
+ return nil
+}
+
+// fseBaselineEntry is an entry in an FSE baseline table.
+// We use these for literal/match/length values.
+// Those require mapping the symbol to a baseline value,
+// and then reading zero or more bits and adding the value to the baseline.
+// Rather than looking thees up in separate tables,
+// we convert the FSE table to an FSE baseline table.
+type fseBaselineEntry struct {
+ baseline uint32 // baseline for value that this entry represents
+ basebits uint8 // number of bits to read to add to baseline
+ bits uint8 // number of bits to read to determine next state
+ base uint16 // add the bits to this base to get the next state
+}
+
+// Given a literal length code, we need to read a number of bits and
+// add that to a baseline. For states 0 to 15 the baseline is the
+// state and the number of bits is zero. RFC 3.1.1.3.2.1.1.
+
+const literalLengthOffset = 16
+
+var literalLengthBase = []uint32{
+ 16 | (1 << 24),
+ 18 | (1 << 24),
+ 20 | (1 << 24),
+ 22 | (1 << 24),
+ 24 | (2 << 24),
+ 28 | (2 << 24),
+ 32 | (3 << 24),
+ 40 | (3 << 24),
+ 48 | (4 << 24),
+ 64 | (6 << 24),
+ 128 | (7 << 24),
+ 256 | (8 << 24),
+ 512 | (9 << 24),
+ 1024 | (10 << 24),
+ 2048 | (11 << 24),
+ 4096 | (12 << 24),
+ 8192 | (13 << 24),
+ 16384 | (14 << 24),
+ 32768 | (15 << 24),
+ 65536 | (16 << 24),
+}
+
+// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
+func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+ for i, e := range fseTable {
+ be := fseBaselineEntry{
+ bits: e.bits,
+ base: e.base,
+ }
+ if e.sym < literalLengthOffset {
+ be.baseline = uint32(e.sym)
+ be.basebits = 0
+ } else {
+ if e.sym > 35 {
+ return r.makeError(off, "FSE baseline symbol overflow")
+ }
+ idx := e.sym - literalLengthOffset
+ basebits := literalLengthBase[idx]
+ be.baseline = basebits & 0xffffff
+ be.basebits = uint8(basebits >> 24)
+ }
+ baselineTable[i] = be
+ }
+ return nil
+}
+
+// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
+func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+ for i, e := range fseTable {
+ be := fseBaselineEntry{
+ bits: e.bits,
+ base: e.base,
+ }
+ if e.sym > 31 {
+ return r.makeError(off, "FSE offset symbol overflow")
+ }
+
+ // The simple way to write this is
+ // be.baseline = 1 << e.sym
+ // be.basebits = e.sym
+ // That would give us an offset value that corresponds to
+ // the one described in the RFC. However, for offsets > 3
+ // we have to subtract 3. And for offset values 1, 2, 3
+ // we use a repeated offset.
+ //
+ // The baseline is always a power of 2, and is never 0,
+ // so for those low values we will see one entry that is
+ // baseline 1, basebits 0, and one entry that is baseline 2,
+ // basebits 1. All other entries will have baseline >= 4
+ // basebits >= 2.
+ //
+ // So we can check for RFC offset <= 3 by checking for
+ // basebits <= 1. That means that we can subtract 3 here
+ // and not worry about doing it in the hot loop.
+
+ be.baseline = 1 << e.sym
+ if e.sym >= 2 {
+ be.baseline -= 3
+ }
+ be.basebits = e.sym
+ baselineTable[i] = be
+ }
+ return nil
+}
+
+// Given a match length code, we need to read a number of bits and add
+// that to a baseline. For states 0 to 31 the baseline is state+3 and
+// the number of bits is zero. RFC 3.1.1.3.2.1.1.
+
+const matchLengthOffset = 32
+
+var matchLengthBase = []uint32{
+ 35 | (1 << 24),
+ 37 | (1 << 24),
+ 39 | (1 << 24),
+ 41 | (1 << 24),
+ 43 | (2 << 24),
+ 47 | (2 << 24),
+ 51 | (3 << 24),
+ 59 | (3 << 24),
+ 67 | (4 << 24),
+ 83 | (4 << 24),
+ 99 | (5 << 24),
+ 131 | (7 << 24),
+ 259 | (8 << 24),
+ 515 | (9 << 24),
+ 1027 | (10 << 24),
+ 2051 | (11 << 24),
+ 4099 | (12 << 24),
+ 8195 | (13 << 24),
+ 16387 | (14 << 24),
+ 32771 | (15 << 24),
+ 65539 | (16 << 24),
+}
+
+// makeMatchBaselineFSE converts the match length fseTable to baselineTable.
+func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
+ for i, e := range fseTable {
+ be := fseBaselineEntry{
+ bits: e.bits,
+ base: e.base,
+ }
+ if e.sym < matchLengthOffset {
+ be.baseline = uint32(e.sym) + 3
+ be.basebits = 0
+ } else {
+ if e.sym > 52 {
+ return r.makeError(off, "FSE baseline symbol overflow")
+ }
+ idx := e.sym - matchLengthOffset
+ basebits := matchLengthBase[idx]
+ be.baseline = basebits & 0xffffff
+ be.basebits = uint8(basebits >> 24)
+ }
+ baselineTable[i] = be
+ }
+ return nil
+}
+
+// predefinedLiteralTable is the predefined table to use for literal lengths.
+// Generated from table in RFC 3.1.1.3.2.2.1.
+// Checked by TestPredefinedTables.
+var predefinedLiteralTable = [...]fseBaselineEntry{
+ {0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
+ {3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
+ {7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
+ {12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
+ {20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
+ {32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
+ {128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
+ {4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
+ {2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
+ {7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
+ {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
+ {18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
+ {32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
+ {64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
+ {2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
+ {2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
+ {6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
+ {11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
+ {18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
+ {28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
+ {65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
+ {8192, 13, 6, 0},
+}
+
+// predefinedOffsetTable is the predefined table to use for offsets.
+// Generated from table in RFC 3.1.1.3.2.2.3.
+// Checked by TestPredefinedTables.
+var predefinedOffsetTable = [...]fseBaselineEntry{
+ {1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
+ {32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
+ {125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
+ {8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
+ {16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
+ {125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
+ {4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
+ {8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
+ {61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
+ {268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
+ {33554429, 25, 5, 0}, {16777213, 24, 5, 0},
+}
+
+// predefinedMatchTable is the predefined table to use for match lengths.
+// Generated from table in RFC 3.1.1.3.2.2.2.
+// Checked by TestPredefinedTables.
+var predefinedMatchTable = [...]fseBaselineEntry{
+ {3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
+ {6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
+ {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
+ {19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
+ {28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
+ {37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
+ {59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
+ {515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
+ {6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
+ {10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
+ {18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
+ {27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
+ {35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
+ {51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
+ {259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
+ {5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
+ {10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
+ {17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
+ {26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
+ {65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
+ {8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
+ {1027, 10, 6, 0},
+}
diff --git a/src/internal/zstd/fse_test.go b/src/internal/zstd/fse_test.go
new file mode 100644
index 0000000..6f106b6
--- /dev/null
+++ b/src/internal/zstd/fse_test.go
@@ -0,0 +1,89 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "slices"
+ "testing"
+)
+
+// literalPredefinedDistribution is the predefined distribution table
+// for literal lengths. RFC 3.1.1.3.2.2.1.
+var literalPredefinedDistribution = []int16{
+ 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
+ -1, -1, -1, -1,
+}
+
+// offsetPredefinedDistribution is the predefined distribution table
+// for offsets. RFC 3.1.1.3.2.2.3.
+var offsetPredefinedDistribution = []int16{
+ 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
+}
+
+// matchPredefinedDistribution is the predefined distribution table
+// for match lengths. RFC 3.1.1.3.2.2.2.
+var matchPredefinedDistribution = []int16{
+ 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,
+ -1, -1, -1, -1, -1,
+}
+
+// TestPredefinedTables verifies that we can generate the predefined
+// literal/offset/match tables from the input data in RFC 8878.
+// This serves as a test of the predefined tables, and also of buildFSE
+// and the functions that make baseline FSE tables.
+func TestPredefinedTables(t *testing.T) {
+ tests := []struct {
+ name string
+ distribution []int16
+ tableBits int
+ toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
+ predef []fseBaselineEntry
+ }{
+ {
+ name: "literal",
+ distribution: literalPredefinedDistribution,
+ tableBits: 6,
+ toBaseline: (*Reader).makeLiteralBaselineFSE,
+ predef: predefinedLiteralTable[:],
+ },
+ {
+ name: "offset",
+ distribution: offsetPredefinedDistribution,
+ tableBits: 5,
+ toBaseline: (*Reader).makeOffsetBaselineFSE,
+ predef: predefinedOffsetTable[:],
+ },
+ {
+ name: "match",
+ distribution: matchPredefinedDistribution,
+ tableBits: 6,
+ toBaseline: (*Reader).makeMatchBaselineFSE,
+ predef: predefinedMatchTable[:],
+ },
+ }
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ var r Reader
+ table := make([]fseEntry, 1<<test.tableBits)
+ if err := r.buildFSE(0, test.distribution, table, test.tableBits); err != nil {
+ t.Fatal(err)
+ }
+
+ baselineTable := make([]fseBaselineEntry, len(table))
+ if err := test.toBaseline(&r, 0, table, baselineTable); err != nil {
+ t.Fatal(err)
+ }
+
+ if !slices.Equal(baselineTable, test.predef) {
+ t.Errorf("got %v, want %v", baselineTable, test.predef)
+ }
+ })
+ }
+}
diff --git a/src/internal/zstd/fuzz_test.go b/src/internal/zstd/fuzz_test.go
new file mode 100644
index 0000000..bb6f0a9
--- /dev/null
+++ b/src/internal/zstd/fuzz_test.go
@@ -0,0 +1,140 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "bytes"
+ "io"
+ "os"
+ "os/exec"
+ "testing"
+)
+
+// badStrings is some inputs that FuzzReader failed on earlier.
+var badStrings = []string{
+ "(\xb5/\xfdd00,\x05\x00\xc4\x0400000000000000000000000000000000000000000000000000000000000000000000000000000 \xa07100000000000000000000000000000000000000000000000000000000000000000000000000aM\x8a2y0B\b",
+ "(\xb5/\xfd00$\x05\x0020 00X70000a70000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+ "(\xb5/\xfd00$\x05\x0020 00B00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+ "(\xb5/\xfd00}\x00\x0020\x00\x9000000000000",
+ "(\xb5/\xfd00}\x00\x00&0\x02\x830!000000000",
+ "(\xb5/\xfd\x1002000$\x05\x0010\xcc0\xa8100000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+ "(\xb5/\xfd\x1002000$\x05\x0000\xcc0\xa8100d\x0000001000000000000000000000000000000000000000000000000000000000000000000000000\x000000000000000000000000000000000000000000000000000000000000000000000000000000",
+ "(\xb5/\xfd001\x00\x0000000000000000000",
+}
+
+// This is a simple fuzzer to see if the decompressor panics.
+func FuzzReader(f *testing.F) {
+ for _, test := range tests {
+ f.Add([]byte(test.compressed))
+ }
+ for _, s := range badStrings {
+ f.Add([]byte(s))
+ }
+ f.Fuzz(func(t *testing.T, b []byte) {
+ r := NewReader(bytes.NewReader(b))
+ io.Copy(io.Discard, r)
+ })
+}
+
+// Fuzz test to verify that what we decompress is what we compress.
+// This isn't a great fuzz test because the fuzzer can't efficiently
+// explore the space of decompressor behavior, since it can't see
+// what the compressor is doing. But it's better than nothing.
+func FuzzDecompressor(f *testing.F) {
+ if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+ f.Skip("skipping because /usr/bin/zstd does not exist")
+ }
+
+ for _, test := range tests {
+ f.Add([]byte(test.uncompressed))
+ }
+
+ // Add some larger data, as that has more interesting compression.
+ f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
+ var buf bytes.Buffer
+ for i := 0; i < 256; i++ {
+ buf.WriteByte(byte(i))
+ }
+ f.Add(bytes.Repeat(buf.Bytes(), 64))
+ f.Add(bigData(f))
+
+ f.Fuzz(func(t *testing.T, b []byte) {
+ cmd := exec.Command("/usr/bin/zstd", "-z")
+ cmd.Stdin = bytes.NewReader(b)
+ var compressed bytes.Buffer
+ cmd.Stdout = &compressed
+ cmd.Stderr = os.Stderr
+ if err := cmd.Run(); err != nil {
+ t.Errorf("running zstd failed: %v", err)
+ }
+
+ r := NewReader(bytes.NewReader(compressed.Bytes()))
+ got, err := io.ReadAll(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, b) {
+ showDiffs(t, got, b)
+ }
+ })
+}
+
+// Fuzz test to check that if we can decompress some data,
+// so can zstd, and that we get the same result.
+func FuzzReverse(f *testing.F) {
+ if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+ f.Skip("skipping because /usr/bin/zstd does not exist")
+ }
+
+ for _, test := range tests {
+ f.Add([]byte(test.compressed))
+ }
+
+ // Set a hook to reject some cases where we don't match zstd.
+ fuzzing = true
+ defer func() { fuzzing = false }()
+
+ f.Fuzz(func(t *testing.T, b []byte) {
+ r := NewReader(bytes.NewReader(b))
+ goExp, goErr := io.ReadAll(r)
+
+ cmd := exec.Command("/usr/bin/zstd", "-d")
+ cmd.Stdin = bytes.NewReader(b)
+ var uncompressed bytes.Buffer
+ cmd.Stdout = &uncompressed
+ cmd.Stderr = os.Stderr
+ zstdErr := cmd.Run()
+ zstdExp := uncompressed.Bytes()
+
+ if goErr == nil && zstdErr == nil {
+ if !bytes.Equal(zstdExp, goExp) {
+ showDiffs(t, zstdExp, goExp)
+ }
+ } else {
+ // Ideally we should check that this package and
+ // the zstd program both fail or both succeed,
+ // and that if they both fail one byte sequence
+ // is an exact prefix of the other.
+ // Actually trying this proved to be frustrating,
+ // as the zstd program appears to accept invalid
+ // byte sequences using rules that are difficult
+ // to determine.
+ // So we just check the prefix.
+
+ c := len(goExp)
+ if c > len(zstdExp) {
+ c = len(zstdExp)
+ }
+ goExp = goExp[:c]
+ zstdExp = zstdExp[:c]
+ if !bytes.Equal(goExp, zstdExp) {
+ t.Error("byte mismatch after error")
+ t.Logf("Go error: %v\n", goErr)
+ t.Logf("zstd error: %v\n", zstdErr)
+ showDiffs(t, zstdExp, goExp)
+ }
+ }
+ })
+}
diff --git a/src/internal/zstd/huff.go b/src/internal/zstd/huff.go
new file mode 100644
index 0000000..452e24b
--- /dev/null
+++ b/src/internal/zstd/huff.go
@@ -0,0 +1,204 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "io"
+ "math/bits"
+)
+
+// maxHuffmanBits is the largest possible Huffman table bits.
+const maxHuffmanBits = 11
+
+// readHuff reads Huffman table from data starting at off into table.
+// Each entry in a Huffman table is a pair of bytes.
+// The high byte is the encoded value. The low byte is the number
+// of bits used to encode that value. We index into the table
+// with a value of size tableBits. A value that requires fewer bits
+// appear in the table multiple times.
+// This returns the number of bits in the Huffman table and the new offset.
+// RFC 4.2.1.
+func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
+ if off >= len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+
+ hdr := data[off]
+ off++
+
+ var weights [256]uint8
+ var count int
+ if hdr < 128 {
+ // The table is compressed using an FSE. RFC 4.2.1.2.
+ if len(r.fseScratch) < 1<<6 {
+ r.fseScratch = make([]fseEntry, 1<<6)
+ }
+ fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
+ if err != nil {
+ return 0, 0, err
+ }
+ fseTable := r.fseScratch
+
+ if off+int(hdr) > len(data) {
+ return 0, 0, r.makeEOFError(off)
+ }
+
+ rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
+ if err != nil {
+ return 0, 0, err
+ }
+
+ state1, err := rbr.val(uint8(fseBits))
+ if err != nil {
+ return 0, 0, err
+ }
+
+ state2, err := rbr.val(uint8(fseBits))
+ if err != nil {
+ return 0, 0, err
+ }
+
+ // There are two independent FSE streams, tracked by
+ // state1 and state2. We decode them alternately.
+
+ for {
+ pt := &fseTable[state1]
+ if !rbr.fetch(pt.bits) {
+ if count >= 254 {
+ return 0, 0, rbr.makeError("Huffman count overflow")
+ }
+ weights[count] = pt.sym
+ weights[count+1] = fseTable[state2].sym
+ count += 2
+ break
+ }
+
+ v, err := rbr.val(pt.bits)
+ if err != nil {
+ return 0, 0, err
+ }
+ state1 = uint32(pt.base) + v
+
+ if count >= 255 {
+ return 0, 0, rbr.makeError("Huffman count overflow")
+ }
+
+ weights[count] = pt.sym
+ count++
+
+ pt = &fseTable[state2]
+
+ if !rbr.fetch(pt.bits) {
+ if count >= 254 {
+ return 0, 0, rbr.makeError("Huffman count overflow")
+ }
+ weights[count] = pt.sym
+ weights[count+1] = fseTable[state1].sym
+ count += 2
+ break
+ }
+
+ v, err = rbr.val(pt.bits)
+ if err != nil {
+ return 0, 0, err
+ }
+ state2 = uint32(pt.base) + v
+
+ if count >= 255 {
+ return 0, 0, rbr.makeError("Huffman count overflow")
+ }
+
+ weights[count] = pt.sym
+ count++
+ }
+
+ off += int(hdr)
+ } else {
+ // The table is not compressed. Each weight is 4 bits.
+
+ count = int(hdr) - 127
+ if off+((count+1)/2) >= len(data) {
+ return 0, 0, io.ErrUnexpectedEOF
+ }
+ for i := 0; i < count; i += 2 {
+ b := data[off]
+ off++
+ weights[i] = b >> 4
+ weights[i+1] = b & 0xf
+ }
+ }
+
+ // RFC 4.2.1.3.
+
+ var weightMark [13]uint32
+ weightMask := uint32(0)
+ for _, w := range weights[:count] {
+ if w > 12 {
+ return 0, 0, r.makeError(off, "Huffman weight overflow")
+ }
+ weightMark[w]++
+ if w > 0 {
+ weightMask += 1 << (w - 1)
+ }
+ }
+ if weightMask == 0 {
+ return 0, 0, r.makeError(off, "bad Huffman weights")
+ }
+
+ tableBits = 32 - bits.LeadingZeros32(weightMask)
+ if tableBits > maxHuffmanBits {
+ return 0, 0, r.makeError(off, "bad Huffman weights")
+ }
+
+ if len(table) < 1<<tableBits {
+ return 0, 0, r.makeError(off, "Huffman table too small")
+ }
+
+ // Work out the last weight value, which is omitted because
+ // the weights must sum to a power of two.
+ left := (uint32(1) << tableBits) - weightMask
+ if left == 0 {
+ return 0, 0, r.makeError(off, "bad Huffman weights")
+ }
+ highBit := 31 - bits.LeadingZeros32(left)
+ if uint32(1)<<highBit != left {
+ return 0, 0, r.makeError(off, "bad Huffman weights")
+ }
+ if count >= 256 {
+ return 0, 0, r.makeError(off, "Huffman weight overflow")
+ }
+ weights[count] = uint8(highBit + 1)
+ count++
+ weightMark[highBit+1]++
+
+ if weightMark[1] < 2 || weightMark[1]&1 != 0 {
+ return 0, 0, r.makeError(off, "bad Huffman weights")
+ }
+
+ // Change weightMark from a count of weights to the index of
+ // the first symbol for that weight. We shift the indexes to
+ // also store how many we have seen so far,
+ next := uint32(0)
+ for i := 0; i < tableBits; i++ {
+ cur := next
+ next += weightMark[i+1] << i
+ weightMark[i+1] = cur
+ }
+
+ for i, w := range weights[:count] {
+ if w == 0 {
+ continue
+ }
+ length := uint32(1) << (w - 1)
+ tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
+ start := weightMark[w]
+ for j := uint32(0); j < length; j++ {
+ table[start+j] = tval
+ }
+ weightMark[w] += length
+ }
+
+ return tableBits, off, nil
+}
diff --git a/src/internal/zstd/literals.go b/src/internal/zstd/literals.go
new file mode 100644
index 0000000..b46d668
--- /dev/null
+++ b/src/internal/zstd/literals.go
@@ -0,0 +1,330 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "encoding/binary"
+)
+
+// readLiterals reads and decompresses the literals from data at off.
+// The literals are appended to outbuf, which is returned.
+// Also returns the new input offset. RFC 3.1.1.3.1.
+func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) {
+ if off >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+
+ // Literals section header. RFC 3.1.1.3.1.1.
+ hdr := data[off]
+ off++
+
+ if (hdr&3) == 0 || (hdr&3) == 1 {
+ return r.readRawRLELiterals(data, off, hdr, outbuf)
+ } else {
+ return r.readHuffLiterals(data, off, hdr, outbuf)
+ }
+}
+
+// readRawRLELiterals reads and decompresses a Raw_Literals_Block or
+// a RLE_Literals_Block. RFC 3.1.1.3.1.1.
+func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
+ raw := (hdr & 3) == 0
+
+ var regeneratedSize int
+ switch (hdr >> 2) & 3 {
+ case 0, 2:
+ regeneratedSize = int(hdr >> 3)
+ case 1:
+ if off >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+ regeneratedSize = int(hdr>>4) + (int(data[off]) << 4)
+ off++
+ case 3:
+ if off+1 >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+ regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12)
+ off += 2
+ }
+
+ // We are going to use the entire literal block in the output.
+ // The maximum size of one decompressed block is 128K,
+ // so we can't have more literals than that.
+ if regeneratedSize > 128<<10 {
+ return 0, nil, r.makeError(off, "literal size too large")
+ }
+
+ if raw {
+ // RFC 3.1.1.3.1.2.
+ if off+regeneratedSize > len(data) {
+ return 0, nil, r.makeError(off, "raw literal size too large")
+ }
+ outbuf = append(outbuf, data[off:off+regeneratedSize]...)
+ off += regeneratedSize
+ } else {
+ // RFC 3.1.1.3.1.3.
+ if off >= len(data) {
+ return 0, nil, r.makeError(off, "RLE literal missing")
+ }
+ rle := data[off]
+ off++
+ for i := 0; i < regeneratedSize; i++ {
+ outbuf = append(outbuf, rle)
+ }
+ }
+
+ return off, outbuf, nil
+}
+
+// readHuffLiterals reads and decompresses a Compressed_Literals_Block or
+// a Treeless_Literals_Block. RFC 3.1.1.3.1.4.
+func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
+ var (
+ regeneratedSize int
+ compressedSize int
+ streams int
+ )
+ switch (hdr >> 2) & 3 {
+ case 0, 1:
+ if off+1 >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+ regeneratedSize = (int(hdr) >> 4) | ((int(data[off]) & 0x3f) << 4)
+ compressedSize = (int(data[off]) >> 6) | (int(data[off+1]) << 2)
+ off += 2
+ if ((hdr >> 2) & 3) == 0 {
+ streams = 1
+ } else {
+ streams = 4
+ }
+ case 2:
+ if off+2 >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+ regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 3) << 12)
+ compressedSize = (int(data[off+1]) >> 2) | (int(data[off+2]) << 6)
+ off += 3
+ streams = 4
+ case 3:
+ if off+3 >= len(data) {
+ return 0, nil, r.makeEOFError(off)
+ }
+ regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 0x3f) << 12)
+ compressedSize = (int(data[off+1]) >> 6) | (int(data[off+2]) << 2) | (int(data[off+3]) << 10)
+ off += 4
+ streams = 4
+ }
+
+ // We are going to use the entire literal block in the output.
+ // The maximum size of one decompressed block is 128K,
+ // so we can't have more literals than that.
+ if regeneratedSize > 128<<10 {
+ return 0, nil, r.makeError(off, "literal size too large")
+ }
+
+ roff := off + compressedSize
+ if roff > len(data) || roff < 0 {
+ return 0, nil, r.makeEOFError(off)
+ }
+
+ totalStreamsSize := compressedSize
+ if (hdr & 3) == 2 {
+ // Compressed_Literals_Block.
+ // Read new huffman tree.
+
+ if len(r.huffmanTable) < 1<<maxHuffmanBits {
+ r.huffmanTable = make([]uint16, 1<<maxHuffmanBits)
+ }
+
+ huffmanTableBits, hoff, err := r.readHuff(data, off, r.huffmanTable)
+ if err != nil {
+ return 0, nil, err
+ }
+ r.huffmanTableBits = huffmanTableBits
+
+ if totalStreamsSize < hoff-off {
+ return 0, nil, r.makeError(off, "Huffman table too big")
+ }
+ totalStreamsSize -= hoff - off
+ off = hoff
+ } else {
+ // Treeless_Literals_Block
+ // Reuse previous Huffman tree.
+ if r.huffmanTableBits == 0 {
+ return 0, nil, r.makeError(off, "missing literals Huffman tree")
+ }
+ }
+
+ // Decompress compressedSize bytes of data at off using the
+ // Huffman tree.
+
+ var err error
+ if streams == 1 {
+ outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf)
+ } else {
+ outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf)
+ }
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return roff, outbuf, nil
+}
+
+// readLiteralsOneStream reads a single stream of compressed literals.
+func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
+ // We let the reverse bit reader read earlier bytes,
+ // because the Huffman table ignores bits that it doesn't need.
+ rbr, err := r.makeReverseBitReader(data, off+compressedSize-1, off-2)
+ if err != nil {
+ return nil, err
+ }
+
+ huffTable := r.huffmanTable
+ huffBits := uint32(r.huffmanTableBits)
+ huffMask := (uint32(1) << huffBits) - 1
+
+ for i := 0; i < regeneratedSize; i++ {
+ if !rbr.fetch(uint8(huffBits)) {
+ return nil, rbr.makeError("literals Huffman stream out of bits")
+ }
+
+ var t uint16
+ idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
+ t = huffTable[idx]
+ outbuf = append(outbuf, byte(t>>8))
+ rbr.cnt -= uint32(t & 0xff)
+ }
+
+ return outbuf, nil
+}
+
+// readLiteralsFourStreams reads four interleaved streams of
+// compressed literals.
+func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
+ // Read the jump table to find out where the streams are.
+ // RFC 3.1.1.3.1.6.
+ if off+5 >= len(data) {
+ return nil, r.makeEOFError(off)
+ }
+ if totalStreamsSize < 6 {
+ return nil, r.makeError(off, "total streams size too small for jump table")
+ }
+
+ streamSize1 := binary.LittleEndian.Uint16(data[off:])
+ streamSize2 := binary.LittleEndian.Uint16(data[off+2:])
+ streamSize3 := binary.LittleEndian.Uint16(data[off+4:])
+ off += 6
+
+ tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3)
+ if tot > uint64(totalStreamsSize)-6 {
+ return nil, r.makeEOFError(off)
+ }
+ streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot)
+
+ off--
+ off1 := off + int(streamSize1)
+ start1 := off + 1
+
+ off2 := off1 + int(streamSize2)
+ start2 := off1 + 1
+
+ off3 := off2 + int(streamSize3)
+ start3 := off2 + 1
+
+ off4 := off3 + int(streamSize4)
+ start4 := off3 + 1
+
+ // We let the reverse bit readers read earlier bytes,
+ // because the Huffman tables ignore bits that they don't need.
+
+ rbr1, err := r.makeReverseBitReader(data, off1, start1-2)
+ if err != nil {
+ return nil, err
+ }
+
+ rbr2, err := r.makeReverseBitReader(data, off2, start2-2)
+ if err != nil {
+ return nil, err
+ }
+
+ rbr3, err := r.makeReverseBitReader(data, off3, start3-2)
+ if err != nil {
+ return nil, err
+ }
+
+ rbr4, err := r.makeReverseBitReader(data, off4, start4-2)
+ if err != nil {
+ return nil, err
+ }
+
+ regeneratedStreamSize := (regeneratedSize + 3) / 4
+
+ out1 := len(outbuf)
+ out2 := out1 + regeneratedStreamSize
+ out3 := out2 + regeneratedStreamSize
+ out4 := out3 + regeneratedStreamSize
+
+ regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3
+
+ outbuf = append(outbuf, make([]byte, regeneratedSize)...)
+
+ huffTable := r.huffmanTable
+ huffBits := uint32(r.huffmanTableBits)
+ huffMask := (uint32(1) << huffBits) - 1
+
+ for i := 0; i < regeneratedStreamSize; i++ {
+ use4 := i < regeneratedStreamSize4
+
+ fetchHuff := func(rbr *reverseBitReader) (uint16, error) {
+ if !rbr.fetch(uint8(huffBits)) {
+ return 0, rbr.makeError("literals Huffman stream out of bits")
+ }
+ idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
+ return huffTable[idx], nil
+ }
+
+ t1, err := fetchHuff(&rbr1)
+ if err != nil {
+ return nil, err
+ }
+
+ t2, err := fetchHuff(&rbr2)
+ if err != nil {
+ return nil, err
+ }
+
+ t3, err := fetchHuff(&rbr3)
+ if err != nil {
+ return nil, err
+ }
+
+ if use4 {
+ t4, err := fetchHuff(&rbr4)
+ if err != nil {
+ return nil, err
+ }
+ outbuf[out4] = byte(t4 >> 8)
+ out4++
+ rbr4.cnt -= uint32(t4 & 0xff)
+ }
+
+ outbuf[out1] = byte(t1 >> 8)
+ out1++
+ rbr1.cnt -= uint32(t1 & 0xff)
+
+ outbuf[out2] = byte(t2 >> 8)
+ out2++
+ rbr2.cnt -= uint32(t2 & 0xff)
+
+ outbuf[out3] = byte(t3 >> 8)
+ out3++
+ rbr3.cnt -= uint32(t3 & 0xff)
+ }
+
+ return outbuf, nil
+}
diff --git a/src/internal/zstd/xxhash.go b/src/internal/zstd/xxhash.go
new file mode 100644
index 0000000..4d579ee
--- /dev/null
+++ b/src/internal/zstd/xxhash.go
@@ -0,0 +1,148 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "encoding/binary"
+ "math/bits"
+)
+
+const (
+ xxhPrime64c1 = 0x9e3779b185ebca87
+ xxhPrime64c2 = 0xc2b2ae3d27d4eb4f
+ xxhPrime64c3 = 0x165667b19e3779f9
+ xxhPrime64c4 = 0x85ebca77c2b2ae63
+ xxhPrime64c5 = 0x27d4eb2f165667c5
+)
+
+// xxhash64 is the state of a xxHash-64 checksum.
+type xxhash64 struct {
+ len uint64 // total length hashed
+ v [4]uint64 // accumulators
+ buf [32]byte // buffer
+ cnt int // number of bytes in buffer
+}
+
+// reset discards the current state and prepares to compute a new hash.
+// We assume a seed of 0 since that is what zstd uses.
+func (xh *xxhash64) reset() {
+ xh.len = 0
+
+ // Separate addition for awkward constant overflow.
+ xh.v[0] = xxhPrime64c1
+ xh.v[0] += xxhPrime64c2
+
+ xh.v[1] = xxhPrime64c2
+ xh.v[2] = 0
+
+ // Separate negation for awkward constant overflow.
+ xh.v[3] = xxhPrime64c1
+ xh.v[3] = -xh.v[3]
+
+ for i := range xh.buf {
+ xh.buf[i] = 0
+ }
+ xh.cnt = 0
+}
+
+// update adds a buffer to the has.
+func (xh *xxhash64) update(b []byte) {
+ xh.len += uint64(len(b))
+
+ if xh.cnt+len(b) < len(xh.buf) {
+ copy(xh.buf[xh.cnt:], b)
+ xh.cnt += len(b)
+ return
+ }
+
+ if xh.cnt > 0 {
+ n := copy(xh.buf[xh.cnt:], b)
+ b = b[n:]
+ xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(xh.buf[:]))
+ xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(xh.buf[8:]))
+ xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(xh.buf[16:]))
+ xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(xh.buf[24:]))
+ xh.cnt = 0
+ }
+
+ for len(b) >= 32 {
+ xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(b))
+ xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(b[8:]))
+ xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(b[16:]))
+ xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(b[24:]))
+ b = b[32:]
+ }
+
+ if len(b) > 0 {
+ copy(xh.buf[:], b)
+ xh.cnt = len(b)
+ }
+}
+
+// digest returns the final hash value.
+func (xh *xxhash64) digest() uint64 {
+ var h64 uint64
+ if xh.len < 32 {
+ h64 = xh.v[2] + xxhPrime64c5
+ } else {
+ h64 = bits.RotateLeft64(xh.v[0], 1) +
+ bits.RotateLeft64(xh.v[1], 7) +
+ bits.RotateLeft64(xh.v[2], 12) +
+ bits.RotateLeft64(xh.v[3], 18)
+ h64 = xh.mergeRound(h64, xh.v[0])
+ h64 = xh.mergeRound(h64, xh.v[1])
+ h64 = xh.mergeRound(h64, xh.v[2])
+ h64 = xh.mergeRound(h64, xh.v[3])
+ }
+
+ h64 += xh.len
+
+ len := xh.len
+ len &= 31
+ buf := xh.buf[:]
+ for len >= 8 {
+ k1 := xh.round(0, binary.LittleEndian.Uint64(buf))
+ buf = buf[8:]
+ h64 ^= k1
+ h64 = bits.RotateLeft64(h64, 27)*xxhPrime64c1 + xxhPrime64c4
+ len -= 8
+ }
+ if len >= 4 {
+ h64 ^= uint64(binary.LittleEndian.Uint32(buf)) * xxhPrime64c1
+ buf = buf[4:]
+ h64 = bits.RotateLeft64(h64, 23)*xxhPrime64c2 + xxhPrime64c3
+ len -= 4
+ }
+ for len > 0 {
+ h64 ^= uint64(buf[0]) * xxhPrime64c5
+ buf = buf[1:]
+ h64 = bits.RotateLeft64(h64, 11) * xxhPrime64c1
+ len--
+ }
+
+ h64 ^= h64 >> 33
+ h64 *= xxhPrime64c2
+ h64 ^= h64 >> 29
+ h64 *= xxhPrime64c3
+ h64 ^= h64 >> 32
+
+ return h64
+}
+
+// round updates a value.
+func (xh *xxhash64) round(v, n uint64) uint64 {
+ v += n * xxhPrime64c2
+ v = bits.RotateLeft64(v, 31)
+ v *= xxhPrime64c1
+ return v
+}
+
+// mergeRound updates a value in the final round.
+func (xh *xxhash64) mergeRound(v, n uint64) uint64 {
+ n = xh.round(0, n)
+ v ^= n
+ v = v*xxhPrime64c1 + xxhPrime64c4
+ return v
+}
diff --git a/src/internal/zstd/xxhash_test.go b/src/internal/zstd/xxhash_test.go
new file mode 100644
index 0000000..646cee8
--- /dev/null
+++ b/src/internal/zstd/xxhash_test.go
@@ -0,0 +1,105 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "bytes"
+ "os"
+ "os/exec"
+ "strconv"
+ "testing"
+)
+
+var xxHashTests = []struct {
+ data string
+ hash uint64
+}{
+ {
+ "hello, world",
+ 0xb33a384e6d1b1242,
+ },
+ {
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$",
+ 0x1032d841e824f998,
+ },
+}
+
+func TestXXHash(t *testing.T) {
+ var xh xxhash64
+ for i, test := range xxHashTests {
+ xh.reset()
+ xh.update([]byte(test.data))
+ if got := xh.digest(); got != test.hash {
+ t.Errorf("#%d: got %#x want %#x", i, got, test.hash)
+ }
+ }
+}
+
+func TestLargeXXHash(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping expensive test in short mode")
+ }
+
+ data := bigData(t)
+ var xh xxhash64
+ xh.reset()
+ i := 0
+ for i < len(data) {
+ // Write varying amounts to test buffering.
+ c := i%4094 + 1
+ if i+c > len(data) {
+ c = len(data) - i
+ }
+ xh.update(data[i : i+c])
+ i += c
+ }
+
+ got := xh.digest()
+ want := uint64(0xf0dd39fd7e063f82)
+ if got != want {
+ t.Errorf("got %#x want %#x", got, want)
+ }
+}
+
+func FuzzXXHash(f *testing.F) {
+ if _, err := os.Stat("/usr/bin/xxhsum"); err != nil {
+ f.Skip("skipping because /usr/bin/xxhsum does not exist")
+ }
+
+ for _, test := range xxHashTests {
+ f.Add([]byte(test.data))
+ }
+ f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
+ var buf bytes.Buffer
+ for i := 0; i < 256; i++ {
+ buf.WriteByte(byte(i))
+ }
+ f.Add(bytes.Repeat(buf.Bytes(), 64))
+ f.Add(bigData(f))
+
+ f.Fuzz(func(t *testing.T, b []byte) {
+ cmd := exec.Command("/usr/bin/xxhsum", "-H64")
+ cmd.Stdin = bytes.NewReader(b)
+ var hhsumHash bytes.Buffer
+ cmd.Stdout = &hhsumHash
+ if err := cmd.Run(); err != nil {
+ t.Fatalf("running hhsum failed: %v", err)
+ }
+ hhHashBytes := bytes.Fields(bytes.TrimSpace(hhsumHash.Bytes()))[0]
+ hhHash, err := strconv.ParseUint(string(hhHashBytes), 16, 64)
+ if err != nil {
+ t.Fatalf("could not parse hash %q: %v", hhHashBytes, err)
+ }
+
+ var xh xxhash64
+ xh.reset()
+ xh.update(b)
+ goHash := xh.digest()
+
+ if goHash != hhHash {
+ t.Errorf("Go hash %#x != xxhsum hash %#x", goHash, hhHash)
+ }
+ })
+}
diff --git a/src/internal/zstd/zstd.go b/src/internal/zstd/zstd.go
new file mode 100644
index 0000000..a860789
--- /dev/null
+++ b/src/internal/zstd/zstd.go
@@ -0,0 +1,508 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package zstd provides a decompressor for zstd streams,
+// described in RFC 8878. It does not support dictionaries.
+package zstd
+
+import (
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+)
+
+// fuzzing is a fuzzer hook set to true when fuzzing.
+// This is used to reject cases where we don't match zstd.
+var fuzzing = false
+
+// Reader implements [io.Reader] to read a zstd compressed stream.
+type Reader struct {
+ // The underlying Reader.
+ r io.Reader
+
+ // Whether we have read the frame header.
+ // This is of interest when buffer is empty.
+ // If true we expect to see a new block.
+ sawFrameHeader bool
+
+ // Whether the current frame expects a checksum.
+ hasChecksum bool
+
+ // Whether we have read at least one frame.
+ readOneFrame bool
+
+ // True if the frame size is not known.
+ frameSizeUnknown bool
+
+ // The number of uncompressed bytes remaining in the current frame.
+ // If frameSizeUnknown is true, this is not valid.
+ remainingFrameSize uint64
+
+ // The number of bytes read from r up to the start of the current
+ // block, for error reporting.
+ blockOffset int64
+
+ // Buffered decompressed data.
+ buffer []byte
+ // Current read offset in buffer.
+ off int
+
+ // The current repeated offsets.
+ repeatedOffset1 uint32
+ repeatedOffset2 uint32
+ repeatedOffset3 uint32
+
+ // The current Huffman tree used for compressing literals.
+ huffmanTable []uint16
+ huffmanTableBits int
+
+ // The window for back references.
+ windowSize int // maximum required window size
+ window []byte // window data
+
+ // A buffer available to hold a compressed block.
+ compressedBuf []byte
+
+ // A buffer for literals.
+ literals []byte
+
+ // Sequence decode FSE tables.
+ seqTables [3][]fseBaselineEntry
+ seqTableBits [3]uint8
+
+ // Buffers for sequence decode FSE tables.
+ seqTableBuffers [3][]fseBaselineEntry
+
+ // Scratch space used for small reads, to avoid allocation.
+ scratch [16]byte
+
+ // A scratch table for reading an FSE. Only temporarily valid.
+ fseScratch []fseEntry
+
+ // For checksum computation.
+ checksum xxhash64
+}
+
+// NewReader creates a new Reader that decompresses data from the given reader.
+func NewReader(input io.Reader) *Reader {
+ r := new(Reader)
+ r.Reset(input)
+ return r
+}
+
+// Reset discards the current state and starts reading a new stream from r.
+// This permits reusing a Reader rather than allocating a new one.
+func (r *Reader) Reset(input io.Reader) {
+ r.r = input
+
+ // Several fields are preserved to avoid allocation.
+ // Others are always set before they are used.
+ r.sawFrameHeader = false
+ r.hasChecksum = false
+ r.readOneFrame = false
+ r.frameSizeUnknown = false
+ r.remainingFrameSize = 0
+ r.blockOffset = 0
+ // buffer
+ r.off = 0
+ // repeatedOffset1
+ // repeatedOffset2
+ // repeatedOffset3
+ // huffmanTable
+ // huffmanTableBits
+ // windowSize
+ // window
+ // compressedBuf
+ // literals
+ // seqTables
+ // seqTableBits
+ // seqTableBuffers
+ // scratch
+ // fseScratch
+}
+
+// Read implements [io.Reader].
+func (r *Reader) Read(p []byte) (int, error) {
+ if err := r.refillIfNeeded(); err != nil {
+ return 0, err
+ }
+ n := copy(p, r.buffer[r.off:])
+ r.off += n
+ return n, nil
+}
+
+// ReadByte implements [io.ByteReader].
+func (r *Reader) ReadByte() (byte, error) {
+ if err := r.refillIfNeeded(); err != nil {
+ return 0, err
+ }
+ ret := r.buffer[r.off]
+ r.off++
+ return ret, nil
+}
+
+// refillIfNeeded reads the next block if necessary.
+func (r *Reader) refillIfNeeded() error {
+ for r.off >= len(r.buffer) {
+ if err := r.refill(); err != nil {
+ return err
+ }
+ r.off = 0
+ }
+ return nil
+}
+
+// refill reads and decompresses the next block.
+func (r *Reader) refill() error {
+ if !r.sawFrameHeader {
+ if err := r.readFrameHeader(); err != nil {
+ return err
+ }
+ }
+ return r.readBlock()
+}
+
+// readFrameHeader reads the frame header and prepares to read a block.
+func (r *Reader) readFrameHeader() error {
+retry:
+ relativeOffset := 0
+
+ // Read magic number. RFC 3.1.1.
+ if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+ // We require that the stream contain at least one frame.
+ if err == io.EOF && !r.readOneFrame {
+ err = io.ErrUnexpectedEOF
+ }
+ return r.wrapError(relativeOffset, err)
+ }
+
+ if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
+ if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
+ // This is a skippable frame.
+ r.blockOffset += int64(relativeOffset) + 4
+ if err := r.skipFrame(); err != nil {
+ return err
+ }
+ goto retry
+ }
+
+ return r.makeError(relativeOffset, "invalid magic number")
+ }
+
+ relativeOffset += 4
+
+ // Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
+ if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+ descriptor := r.scratch[0]
+
+ singleSegment := descriptor&(1<<5) != 0
+
+ fcsFieldSize := 1 << (descriptor >> 6)
+ if fcsFieldSize == 1 && !singleSegment {
+ fcsFieldSize = 0
+ }
+
+ var windowDescriptorSize int
+ if singleSegment {
+ windowDescriptorSize = 0
+ } else {
+ windowDescriptorSize = 1
+ }
+
+ if descriptor&(1<<3) != 0 {
+ return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
+ }
+
+ r.hasChecksum = descriptor&(1<<2) != 0
+ if r.hasChecksum {
+ r.checksum.reset()
+ }
+
+ if descriptor&3 != 0 {
+ return r.makeError(relativeOffset, "dictionaries are not supported")
+ }
+
+ relativeOffset++
+
+ headerSize := windowDescriptorSize + fcsFieldSize
+
+ if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+
+ // Figure out the maximum amount of data we need to retain
+ // for backreferences.
+
+ if singleSegment {
+ // No window required, as all the data is in a single buffer.
+ r.windowSize = 0
+ } else {
+ // Window descriptor. RFC 3.1.1.1.2.
+ windowDescriptor := r.scratch[0]
+ exponent := uint64(windowDescriptor >> 3)
+ mantissa := uint64(windowDescriptor & 7)
+ windowLog := exponent + 10
+ windowBase := uint64(1) << windowLog
+ windowAdd := (windowBase / 8) * mantissa
+ windowSize := windowBase + windowAdd
+
+ // Default zstd sets limits on the window size.
+ if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
+ return r.makeError(relativeOffset, "windowSize too large")
+ }
+
+ // RFC 8878 permits us to set an 8M max on window size.
+ if windowSize > 8<<20 {
+ windowSize = 8 << 20
+ }
+
+ r.windowSize = int(windowSize)
+ }
+
+ // Frame_Content_Size. RFC 3.1.1.4.
+ r.frameSizeUnknown = false
+ r.remainingFrameSize = 0
+ fb := r.scratch[windowDescriptorSize:]
+ switch fcsFieldSize {
+ case 0:
+ r.frameSizeUnknown = true
+ case 1:
+ r.remainingFrameSize = uint64(fb[0])
+ case 2:
+ r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
+ case 4:
+ r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
+ case 8:
+ r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
+ default:
+ panic("unreachable")
+ }
+
+ relativeOffset += headerSize
+
+ r.sawFrameHeader = true
+ r.readOneFrame = true
+ r.blockOffset += int64(relativeOffset)
+
+ // Prepare to read blocks from the frame.
+ r.repeatedOffset1 = 1
+ r.repeatedOffset2 = 4
+ r.repeatedOffset3 = 8
+ r.huffmanTableBits = 0
+ r.window = r.window[:0]
+ r.seqTables[0] = nil
+ r.seqTables[1] = nil
+ r.seqTables[2] = nil
+
+ return nil
+}
+
+// skipFrame skips a skippable frame. RFC 3.1.2.
+func (r *Reader) skipFrame() error {
+ relativeOffset := 0
+
+ if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+
+ relativeOffset += 4
+
+ size := binary.LittleEndian.Uint32(r.scratch[:4])
+
+ if seeker, ok := r.r.(io.Seeker); ok {
+ if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil {
+ return err
+ }
+ r.blockOffset += int64(relativeOffset) + int64(size)
+ return nil
+ }
+
+ var skip []byte
+ const chunk = 1 << 20 // 1M
+ for size >= chunk {
+ if len(skip) == 0 {
+ skip = make([]byte, chunk)
+ }
+ if _, err := io.ReadFull(r.r, skip); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+ relativeOffset += chunk
+ size -= chunk
+ }
+ if size > 0 {
+ if len(skip) == 0 {
+ skip = make([]byte, size)
+ }
+ if _, err := io.ReadFull(r.r, skip); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+ relativeOffset += int(size)
+ }
+
+ r.blockOffset += int64(relativeOffset)
+
+ return nil
+}
+
+// readBlock reads the next block from a frame.
+func (r *Reader) readBlock() error {
+ relativeOffset := 0
+
+ // Read Block_Header. RFC 3.1.1.2.
+ if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+
+ relativeOffset += 3
+
+ header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
+
+ lastBlock := header&1 != 0
+ blockType := (header >> 1) & 3
+ blockSize := int(header >> 3)
+
+ // Maximum block size is smaller of window size and 128K.
+ // We don't record the window size for a single segment frame,
+ // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
+ if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) {
+ return r.makeError(relativeOffset, "block size too large")
+ }
+
+ // Handle different block types. RFC 3.1.1.2.2.
+ switch blockType {
+ case 0:
+ r.setBufferSize(blockSize)
+ if _, err := io.ReadFull(r.r, r.buffer); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+ relativeOffset += blockSize
+ r.blockOffset += int64(relativeOffset)
+ case 1:
+ r.setBufferSize(blockSize)
+ if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
+ return r.wrapNonEOFError(relativeOffset, err)
+ }
+ relativeOffset++
+ v := r.scratch[0]
+ for i := range r.buffer {
+ r.buffer[i] = v
+ }
+ r.blockOffset += int64(relativeOffset)
+ case 2:
+ r.blockOffset += int64(relativeOffset)
+ if err := r.compressedBlock(blockSize); err != nil {
+ return err
+ }
+ r.blockOffset += int64(blockSize)
+ case 3:
+ return r.makeError(relativeOffset, "invalid block type")
+ }
+
+ if !r.frameSizeUnknown {
+ if uint64(len(r.buffer)) > r.remainingFrameSize {
+ return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
+ }
+ r.remainingFrameSize -= uint64(len(r.buffer))
+ }
+
+ if r.hasChecksum {
+ r.checksum.update(r.buffer)
+ }
+
+ if !lastBlock {
+ r.saveWindow(r.buffer)
+ } else {
+ if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
+ return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
+ }
+ // Check for checksum at end of frame. RFC 3.1.1.
+ if r.hasChecksum {
+ if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
+ return r.wrapNonEOFError(0, err)
+ }
+
+ inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
+ dataChecksum := uint32(r.checksum.digest())
+ if inputChecksum != dataChecksum {
+ return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
+ }
+
+ r.blockOffset += 4
+ }
+ r.sawFrameHeader = false
+ }
+
+ return nil
+}
+
+// setBufferSize sets the decompressed buffer size.
+// When this is called the buffer is empty.
+func (r *Reader) setBufferSize(size int) {
+ if cap(r.buffer) < size {
+ need := size - cap(r.buffer)
+ r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
+ }
+ r.buffer = r.buffer[:size]
+}
+
+// saveWindow saves bytes in the backreference window.
+// TODO: use a circular buffer for less data movement.
+func (r *Reader) saveWindow(buf []byte) {
+ if r.windowSize == 0 {
+ return
+ }
+
+ if len(buf) >= r.windowSize {
+ from := len(buf) - r.windowSize
+ r.window = append(r.window[:0], buf[from:]...)
+ return
+ }
+
+ keep := r.windowSize - len(buf) // must be positive
+ if keep < len(r.window) {
+ remove := len(r.window) - keep
+ copy(r.window[:], r.window[remove:])
+ }
+
+ r.window = append(r.window, buf...)
+}
+
+// zstdError is an error while decompressing.
+type zstdError struct {
+ offset int64
+ err error
+}
+
+func (ze *zstdError) Error() string {
+ return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
+}
+
+func (ze *zstdError) Unwrap() error {
+ return ze.err
+}
+
+func (r *Reader) makeEOFError(off int) error {
+ return r.wrapError(off, io.ErrUnexpectedEOF)
+}
+
+func (r *Reader) wrapNonEOFError(off int, err error) error {
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ return r.wrapError(off, err)
+}
+
+func (r *Reader) makeError(off int, msg string) error {
+ return r.wrapError(off, errors.New(msg))
+}
+
+func (r *Reader) wrapError(off int, err error) error {
+ if err == io.EOF {
+ return err
+ }
+ return &zstdError{r.blockOffset + int64(off), err}
+}
diff --git a/src/internal/zstd/zstd_test.go b/src/internal/zstd/zstd_test.go
new file mode 100644
index 0000000..bc75e0f
--- /dev/null
+++ b/src/internal/zstd/zstd_test.go
@@ -0,0 +1,249 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package zstd
+
+import (
+ "bytes"
+ "fmt"
+ "internal/race"
+ "internal/testenv"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+ "sync"
+ "testing"
+)
+
+// tests holds some simple test cases, including some found by fuzzing.
+var tests = []struct {
+ name, uncompressed, compressed string
+}{
+ {
+ "hello",
+ "hello, world\n",
+ "\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1",
+ },
+ {
+ // a small compressed .debug_ranges section.
+ "ranges",
+ "\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" +
+ "\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" +
+ "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
+ "\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
+ "\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" +
+ "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
+ "\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
+ "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" +
+ "\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" +
+ "\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" +
+ "\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" +
+ "\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" +
+ "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
+ "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" +
+ "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
+ "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" +
+ "\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" +
+ "\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" +
+ "\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" +
+ "\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" +
+ "\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" +
+ "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
+ "\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" +
+ "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
+ "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+
+ "\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" +
+ "\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" +
+ "\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" +
+ "\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" +
+ "\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" +
+ "\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" +
+ "\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" +
+ "\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" +
+ "\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" +
+ "\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" +
+ "\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" +
+ "\xcb\xe0\xce",
+ },
+ {
+ "fuzz1",
+ "0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
+ "(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
+ },
+}
+
+func TestSamples(t *testing.T) {
+ for _, test := range tests {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ r := NewReader(strings.NewReader(test.compressed))
+ got, err := io.ReadAll(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotstr := string(got)
+ if gotstr != test.uncompressed {
+ t.Errorf("got %q want %q", gotstr, test.uncompressed)
+ }
+ })
+ }
+}
+
+var (
+ bigDataOnce sync.Once
+ bigDataBytes []byte
+ bigDataErr error
+)
+
+// bigData returns the contents of our large test file.
+func bigData(t testing.TB) []byte {
+ bigDataOnce.Do(func() {
+ bigDataBytes, bigDataErr = os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
+ })
+ if bigDataErr != nil {
+ t.Fatal(bigDataErr)
+ }
+ return bigDataBytes
+}
+
+var (
+ zstdBigOnce sync.Once
+ zstdBigBytes []byte
+ zstdBigSkip bool
+ zstdBigErr error
+)
+
+// zstdBigData returns the compressed contents of our large test file.
+// This will only run on Unix systems with zstd installed.
+// That's OK as the package is GOOS-independent.
+func zstdBigData(t testing.TB) []byte {
+ input := bigData(t)
+
+ zstdBigOnce.Do(func() {
+ if _, err := os.Stat("/usr/bin/zstd"); err != nil {
+ zstdBigSkip = true
+ return
+ }
+
+ cmd := exec.Command("/usr/bin/zstd", "-z")
+ cmd.Stdin = bytes.NewReader(input)
+ var compressed bytes.Buffer
+ cmd.Stdout = &compressed
+ cmd.Stderr = os.Stderr
+ if err := cmd.Run(); err != nil {
+ zstdBigErr = fmt.Errorf("running zstd failed: %v", err)
+ return
+ }
+
+ zstdBigBytes = compressed.Bytes()
+ })
+ if zstdBigSkip {
+ t.Skip("skipping because /usr/bin/zstd does not exist")
+ }
+ if zstdBigErr != nil {
+ t.Fatal(zstdBigErr)
+ }
+ return zstdBigBytes
+}
+
+// Test decompressing a large file. We don't have a compressor,
+// so this test only runs on systems with zstd installed.
+func TestLarge(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping expensive test in short mode")
+ }
+
+ data := bigData(t)
+ compressed := zstdBigData(t)
+
+ t.Logf("/usr/bin/zstd compressed %d bytes to %d", len(data), len(compressed))
+
+ r := NewReader(bytes.NewReader(compressed))
+ got, err := io.ReadAll(r)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(got, data) {
+ showDiffs(t, got, data)
+ }
+}
+
+// showDiffs reports the first few differences in two []byte.
+func showDiffs(t *testing.T, got, want []byte) {
+ t.Error("data mismatch")
+ if len(got) != len(want) {
+ t.Errorf("got data length %d, want %d", len(got), len(want))
+ }
+ diffs := 0
+ for i, b := range got {
+ if i >= len(want) {
+ break
+ }
+ if b != want[i] {
+ diffs++
+ if diffs > 20 {
+ break
+ }
+ t.Logf("%d: %#x != %#x", i, b, want[i])
+ }
+ }
+}
+
+func TestAlloc(t *testing.T) {
+ testenv.SkipIfOptimizationOff(t)
+ if race.Enabled {
+ t.Skip("skipping allocation test under race detector")
+ }
+
+ compressed := zstdBigData(t)
+ input := bytes.NewReader(compressed)
+ r := NewReader(input)
+ c := testing.AllocsPerRun(10, func() {
+ input.Reset(compressed)
+ r.Reset(input)
+ io.Copy(io.Discard, r)
+ })
+ if c != 0 {
+ t.Errorf("got %v allocs, want 0", c)
+ }
+}
+
+func BenchmarkLarge(b *testing.B) {
+ b.StopTimer()
+ b.ReportAllocs()
+
+ compressed := zstdBigData(b)
+
+ b.SetBytes(int64(len(compressed)))
+
+ input := bytes.NewReader(compressed)
+ r := NewReader(input)
+
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ input.Reset(compressed)
+ r.Reset(input)
+ io.Copy(io.Discard, r)
+ }
+}