diff options
Diffstat (limited to 'src/internal/zstd')
-rw-r--r-- | src/internal/zstd/bits.go | 130 | ||||
-rw-r--r-- | src/internal/zstd/block.go | 425 | ||||
-rw-r--r-- | src/internal/zstd/fse.go | 437 | ||||
-rw-r--r-- | src/internal/zstd/fse_test.go | 89 | ||||
-rw-r--r-- | src/internal/zstd/fuzz_test.go | 139 | ||||
-rw-r--r-- | src/internal/zstd/huff.go | 204 | ||||
-rw-r--r-- | src/internal/zstd/literals.go | 336 | ||||
-rw-r--r-- | src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst | bin | 0 -> 826 bytes | |||
-rw-r--r-- | src/internal/zstd/testdata/README | 10 | ||||
-rw-r--r-- | src/internal/zstd/testdata/f2a8e35c.helloworld-11000x.zst | bin | 0 -> 47 bytes | |||
-rw-r--r-- | src/internal/zstd/testdata/fcf30b99.zero-dictionary-ids.zst | bin | 0 -> 64 bytes | |||
-rw-r--r-- | src/internal/zstd/window.go | 90 | ||||
-rw-r--r-- | src/internal/zstd/window_test.go | 72 | ||||
-rw-r--r-- | src/internal/zstd/xxhash.go | 148 | ||||
-rw-r--r-- | src/internal/zstd/xxhash_test.go | 115 | ||||
-rw-r--r-- | src/internal/zstd/zstd.go | 522 | ||||
-rw-r--r-- | src/internal/zstd/zstd_test.go | 335 |
17 files changed, 3052 insertions, 0 deletions
diff --git a/src/internal/zstd/bits.go b/src/internal/zstd/bits.go new file mode 100644 index 0000000..c9a2f70 --- /dev/null +++ b/src/internal/zstd/bits.go @@ -0,0 +1,130 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "math/bits" +) + +// block is the data for a single compressed block. +// The data starts immediately after the 3 byte block header, +// and is Block_Size bytes long. +type block []byte + +// bitReader reads a bit stream going forward. +type bitReader struct { + r *Reader // for error reporting + data block // the bits to read + off uint32 // current offset into data + bits uint32 // bits ready to be returned + cnt uint32 // number of valid bits in the bits field +} + +// makeBitReader makes a bit reader starting at off. +func (r *Reader) makeBitReader(data block, off int) bitReader { + return bitReader{ + r: r, + data: data, + off: uint32(off), + } +} + +// moreBits is called to read more bits. +// This ensures that at least 16 bits are available. +func (br *bitReader) moreBits() error { + for br.cnt < 16 { + if br.off >= uint32(len(br.data)) { + return br.r.makeEOFError(int(br.off)) + } + c := br.data[br.off] + br.off++ + br.bits |= uint32(c) << br.cnt + br.cnt += 8 + } + return nil +} + +// val is called to fetch a value of b bits. +func (br *bitReader) val(b uint8) uint32 { + r := br.bits & ((1 << b) - 1) + br.bits >>= b + br.cnt -= uint32(b) + return r +} + +// backup steps back to the last byte we used. +func (br *bitReader) backup() { + for br.cnt >= 8 { + br.off-- + br.cnt -= 8 + } +} + +// makeError returns an error at the current offset wrapping a string. +func (br *bitReader) makeError(msg string) error { + return br.r.makeError(int(br.off), msg) +} + +// reverseBitReader reads a bit stream in reverse. +type reverseBitReader struct { + r *Reader // for error reporting + data block // the bits to read + off uint32 // current offset into data + start uint32 // start in data; we read backward to start + bits uint32 // bits ready to be returned + cnt uint32 // number of valid bits in bits field +} + +// makeReverseBitReader makes a reverseBitReader reading backward +// from off to start. The bitstream starts with a 1 bit in the last +// byte, at off. +func (r *Reader) makeReverseBitReader(data block, off, start int) (reverseBitReader, error) { + streamStart := data[off] + if streamStart == 0 { + return reverseBitReader{}, r.makeError(off, "zero byte at reverse bit stream start") + } + rbr := reverseBitReader{ + r: r, + data: data, + off: uint32(off), + start: uint32(start), + bits: uint32(streamStart), + cnt: uint32(7 - bits.LeadingZeros8(streamStart)), + } + return rbr, nil +} + +// val is called to fetch a value of b bits. +func (rbr *reverseBitReader) val(b uint8) (uint32, error) { + if !rbr.fetch(b) { + return 0, rbr.r.makeEOFError(int(rbr.off)) + } + + rbr.cnt -= uint32(b) + v := (rbr.bits >> rbr.cnt) & ((1 << b) - 1) + return v, nil +} + +// fetch is called to ensure that at least b bits are available. +// It reports false if this can't be done, +// in which case only rbr.cnt bits are available. +func (rbr *reverseBitReader) fetch(b uint8) bool { + for rbr.cnt < uint32(b) { + if rbr.off <= rbr.start { + return false + } + rbr.off-- + c := rbr.data[rbr.off] + rbr.bits <<= 8 + rbr.bits |= uint32(c) + rbr.cnt += 8 + } + return true +} + +// makeError returns an error at the current offset wrapping a string. +func (rbr *reverseBitReader) makeError(msg string) error { + return rbr.r.makeError(int(rbr.off), msg) +} diff --git a/src/internal/zstd/block.go b/src/internal/zstd/block.go new file mode 100644 index 0000000..11a99cd --- /dev/null +++ b/src/internal/zstd/block.go @@ -0,0 +1,425 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "io" +) + +// debug can be set in the source to print debug info using println. +const debug = false + +// compressedBlock decompresses a compressed block, storing the decompressed +// data in r.buffer. The blockSize argument is the compressed size. +// RFC 3.1.1.3. +func (r *Reader) compressedBlock(blockSize int) error { + if len(r.compressedBuf) >= blockSize { + r.compressedBuf = r.compressedBuf[:blockSize] + } else { + // We know that blockSize <= 128K, + // so this won't allocate an enormous amount. + need := blockSize - len(r.compressedBuf) + r.compressedBuf = append(r.compressedBuf, make([]byte, need)...) + } + + if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil { + return r.wrapNonEOFError(0, err) + } + + data := block(r.compressedBuf) + off := 0 + r.buffer = r.buffer[:0] + + litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0]) + if err != nil { + return err + } + r.literals = litbuf + + off = litoff + + seqCount, off, err := r.initSeqs(data, off) + if err != nil { + return err + } + + if seqCount == 0 { + // No sequences, just literals. + if off < len(data) { + return r.makeError(off, "extraneous data after no sequences") + } + + r.buffer = append(r.buffer, litbuf...) + + return nil + } + + return r.execSeqs(data, off, litbuf, seqCount) +} + +// seqCode is the kind of sequence codes we have to handle. +type seqCode int + +const ( + seqLiteral seqCode = iota + seqOffset + seqMatch +) + +// seqCodeInfoData is the information needed to set up seqTables and +// seqTableBits for a particular kind of sequence code. +type seqCodeInfoData struct { + predefTable []fseBaselineEntry // predefined FSE + predefTableBits int // number of bits in predefTable + maxSym int // max symbol value in FSE + maxBits int // max bits for FSE + + // toBaseline converts from an FSE table to an FSE baseline table. + toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error +} + +// seqCodeInfo is the seqCodeInfoData for each kind of sequence code. +var seqCodeInfo = [3]seqCodeInfoData{ + seqLiteral: { + predefTable: predefinedLiteralTable[:], + predefTableBits: 6, + maxSym: 35, + maxBits: 9, + toBaseline: (*Reader).makeLiteralBaselineFSE, + }, + seqOffset: { + predefTable: predefinedOffsetTable[:], + predefTableBits: 5, + maxSym: 31, + maxBits: 8, + toBaseline: (*Reader).makeOffsetBaselineFSE, + }, + seqMatch: { + predefTable: predefinedMatchTable[:], + predefTableBits: 6, + maxSym: 52, + maxBits: 9, + toBaseline: (*Reader).makeMatchBaselineFSE, + }, +} + +// initSeqs reads the Sequences_Section_Header and sets up the FSE +// tables used to read the sequence codes. It returns the number of +// sequences and the new offset. RFC 3.1.1.3.2.1. +func (r *Reader) initSeqs(data block, off int) (int, int, error) { + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + + seqHdr := data[off] + off++ + if seqHdr == 0 { + return 0, off, nil + } + + var seqCount int + if seqHdr < 128 { + seqCount = int(seqHdr) + } else if seqHdr < 255 { + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + seqCount = ((int(seqHdr) - 128) << 8) + int(data[off]) + off++ + } else { + if off+1 >= len(data) { + return 0, 0, r.makeEOFError(off) + } + seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00 + off += 2 + } + + // Read the Symbol_Compression_Modes byte. + + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + symMode := data[off] + if symMode&3 != 0 { + return 0, 0, r.makeError(off, "invalid symbol compression mode") + } + off++ + + // Set up the FSE tables used to decode the sequence codes. + + var err error + off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3) + if err != nil { + return 0, 0, err + } + + off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3) + if err != nil { + return 0, 0, err + } + + off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3) + if err != nil { + return 0, 0, err + } + + return seqCount, off, nil +} + +// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and +// r.seqTableBits for kind. We store these in the Reader because one of +// the modes simply reuses the value from the last block in the frame. +func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) { + info := &seqCodeInfo[kind] + switch mode { + case 0: + // Predefined_Mode + r.seqTables[kind] = info.predefTable + r.seqTableBits[kind] = uint8(info.predefTableBits) + return off, nil + + case 1: + // RLE_Mode + if off >= len(data) { + return 0, r.makeEOFError(off) + } + rle := data[off] + off++ + + // Build a simple baseline table that always returns rle. + + entry := []fseEntry{ + { + sym: rle, + bits: 0, + base: 0, + }, + } + if cap(r.seqTableBuffers[kind]) == 0 { + r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) + } + r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1] + if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil { + return 0, err + } + + r.seqTables[kind] = r.seqTableBuffers[kind] + r.seqTableBits[kind] = 0 + return off, nil + + case 2: + // FSE_Compressed_Mode + if cap(r.fseScratch) < 1<<info.maxBits { + r.fseScratch = make([]fseEntry, 1<<info.maxBits) + } + r.fseScratch = r.fseScratch[:1<<info.maxBits] + + tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch) + if err != nil { + return 0, err + } + r.fseScratch = r.fseScratch[:1<<tableBits] + + if cap(r.seqTableBuffers[kind]) == 0 { + r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) + } + r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits] + + if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil { + return 0, err + } + + r.seqTables[kind] = r.seqTableBuffers[kind] + r.seqTableBits[kind] = uint8(tableBits) + return roff, nil + + case 3: + // Repeat_Mode + if len(r.seqTables[kind]) == 0 { + return 0, r.makeError(off, "missing repeat sequence FSE table") + } + return off, nil + } + panic("unreachable") +} + +// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2. +func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error { + // Set up the initial states for the sequence code readers. + + rbr, err := r.makeReverseBitReader(data, len(data)-1, off) + if err != nil { + return err + } + + literalState, err := rbr.val(r.seqTableBits[seqLiteral]) + if err != nil { + return err + } + + offsetState, err := rbr.val(r.seqTableBits[seqOffset]) + if err != nil { + return err + } + + matchState, err := rbr.val(r.seqTableBits[seqMatch]) + if err != nil { + return err + } + + // Read and perform all the sequences. RFC 3.1.1.4. + + seq := 0 + for seq < seqCount { + if len(r.buffer)+len(litbuf) > 128<<10 { + return rbr.makeError("uncompressed size too big") + } + + ptoffset := &r.seqTables[seqOffset][offsetState] + ptmatch := &r.seqTables[seqMatch][matchState] + ptliteral := &r.seqTables[seqLiteral][literalState] + + add, err := rbr.val(ptoffset.basebits) + if err != nil { + return err + } + offset := ptoffset.baseline + add + + add, err = rbr.val(ptmatch.basebits) + if err != nil { + return err + } + match := ptmatch.baseline + add + + add, err = rbr.val(ptliteral.basebits) + if err != nil { + return err + } + literal := ptliteral.baseline + add + + // Handle repeat offsets. RFC 3.1.1.5. + // See the comment in makeOffsetBaselineFSE. + if ptoffset.basebits > 1 { + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + } else { + if literal == 0 { + offset++ + } + switch offset { + case 1: + offset = r.repeatedOffset1 + case 2: + offset = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + case 3: + offset = r.repeatedOffset3 + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + case 4: + offset = r.repeatedOffset1 - 1 + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + } + } + + seq++ + if seq < seqCount { + // Update the states. + add, err = rbr.val(ptliteral.bits) + if err != nil { + return err + } + literalState = uint32(ptliteral.base) + add + + add, err = rbr.val(ptmatch.bits) + if err != nil { + return err + } + matchState = uint32(ptmatch.base) + add + + add, err = rbr.val(ptoffset.bits) + if err != nil { + return err + } + offsetState = uint32(ptoffset.base) + add + } + + // The next sequence is now in literal, offset, match. + + if debug { + println("literal", literal, "offset", offset, "match", match) + } + + // Copy literal bytes from litbuf. + if literal > uint32(len(litbuf)) { + return rbr.makeError("literal byte overflow") + } + if literal > 0 { + r.buffer = append(r.buffer, litbuf[:literal]...) + litbuf = litbuf[literal:] + } + + if match > 0 { + if err := r.copyFromWindow(&rbr, offset, match); err != nil { + return err + } + } + } + + r.buffer = append(r.buffer, litbuf...) + + if rbr.cnt != 0 { + return r.makeError(off, "extraneous data after sequences") + } + + return nil +} + +// Copy match bytes from the decoded output, or the window, at offset. +func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error { + if offset == 0 { + return rbr.makeError("invalid zero offset") + } + + // Offset may point into the buffer or the window and + // match may extend past the end of the initial buffer. + // |--r.window--|--r.buffer--| + // |<-----offset------| + // |------match----------->| + bufferOffset := uint32(0) + lenBlock := uint32(len(r.buffer)) + if lenBlock < offset { + lenWindow := r.window.len() + copy := offset - lenBlock + if copy > lenWindow { + return rbr.makeError("offset past window") + } + windowOffset := lenWindow - copy + if copy > match { + copy = match + } + r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy) + match -= copy + } else { + bufferOffset = lenBlock - offset + } + + // We are being asked to copy data that we are adding to the + // buffer in the same copy. + for match > 0 { + copy := uint32(len(r.buffer)) - bufferOffset + if copy > match { + copy = match + } + r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...) + match -= copy + } + return nil +} diff --git a/src/internal/zstd/fse.go b/src/internal/zstd/fse.go new file mode 100644 index 0000000..f03a792 --- /dev/null +++ b/src/internal/zstd/fse.go @@ -0,0 +1,437 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "math/bits" +) + +// fseEntry is one entry in an FSE table. +type fseEntry struct { + sym uint8 // value that this entry records + bits uint8 // number of bits to read to determine next state + base uint16 // add those bits to this state to get the next state +} + +// readFSE reads an FSE table from data starting at off. +// maxSym is the maximum symbol value. +// maxBits is the maximum number of bits permitted for symbols in the table. +// The FSE is written into table, which must be at least 1<<maxBits in size. +// This returns the number of bits in the FSE table and the new offset. +// RFC 4.1.1. +func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) { + br := r.makeBitReader(data, off) + if err := br.moreBits(); err != nil { + return 0, 0, err + } + + accuracyLog := int(br.val(4)) + 5 + if accuracyLog > maxBits { + return 0, 0, br.makeError("FSE accuracy log too large") + } + + // The number of remaining probabilities, plus 1. + // This determines the number of bits to be read for the next value. + remaining := (1 << accuracyLog) + 1 + + // The current difference between small and large values, + // which depends on the number of remaining values. + // Small values use 1 less bit. + threshold := 1 << accuracyLog + + // The number of bits needed to compute threshold. + bitsNeeded := accuracyLog + 1 + + // The next character value. + sym := 0 + + // Whether the last count was 0. + prev0 := false + + var norm [256]int16 + + for remaining > 1 && sym <= maxSym { + if err := br.moreBits(); err != nil { + return 0, 0, err + } + + if prev0 { + // Previous count was 0, so there is a 2-bit + // repeat flag. If the 2-bit flag is 0b11, + // it adds 3 and then there is another repeat flag. + zsym := sym + for (br.bits & 0xfff) == 0xfff { + zsym += 3 * 6 + br.bits >>= 12 + br.cnt -= 12 + if err := br.moreBits(); err != nil { + return 0, 0, err + } + } + for (br.bits & 3) == 3 { + zsym += 3 + br.bits >>= 2 + br.cnt -= 2 + if err := br.moreBits(); err != nil { + return 0, 0, err + } + } + + // We have at least 14 bits here, + // no need to call moreBits + + zsym += int(br.val(2)) + + if zsym > maxSym { + return 0, 0, br.makeError("FSE symbol index overflow") + } + + for ; sym < zsym; sym++ { + norm[uint8(sym)] = 0 + } + + prev0 = false + continue + } + + max := (2*threshold - 1) - remaining + var count int + if int(br.bits&uint32(threshold-1)) < max { + // A small value. + count = int(br.bits & uint32((threshold - 1))) + br.bits >>= bitsNeeded - 1 + br.cnt -= uint32(bitsNeeded - 1) + } else { + // A large value. + count = int(br.bits & uint32((2*threshold - 1))) + if count >= threshold { + count -= max + } + br.bits >>= bitsNeeded + br.cnt -= uint32(bitsNeeded) + } + + count-- + if count >= 0 { + remaining -= count + } else { + remaining-- + } + if sym >= 256 { + return 0, 0, br.makeError("FSE sym overflow") + } + norm[uint8(sym)] = int16(count) + sym++ + + prev0 = count == 0 + + for remaining < threshold { + bitsNeeded-- + threshold >>= 1 + } + } + + if remaining != 1 { + return 0, 0, br.makeError("too many symbols in FSE table") + } + + for ; sym <= maxSym; sym++ { + norm[uint8(sym)] = 0 + } + + br.backup() + + if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil { + return 0, 0, err + } + + return accuracyLog, int(br.off), nil +} + +// buildFSE builds an FSE decoding table from a list of probabilities. +// The probabilities are in norm. next is scratch space. The number of bits +// in the table is tableBits. +func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error { + tableSize := 1 << tableBits + highThreshold := tableSize - 1 + + var next [256]uint16 + + for i, n := range norm { + if n >= 0 { + next[uint8(i)] = uint16(n) + } else { + table[highThreshold].sym = uint8(i) + highThreshold-- + next[uint8(i)] = 1 + } + } + + pos := 0 + step := (tableSize >> 1) + (tableSize >> 3) + 3 + mask := tableSize - 1 + for i, n := range norm { + for j := 0; j < int(n); j++ { + table[pos].sym = uint8(i) + pos = (pos + step) & mask + for pos > highThreshold { + pos = (pos + step) & mask + } + } + } + if pos != 0 { + return r.makeError(off, "FSE count error") + } + + for i := 0; i < tableSize; i++ { + sym := table[i].sym + nextState := next[sym] + next[sym]++ + + if nextState == 0 { + return r.makeError(off, "FSE state error") + } + + highBit := 15 - bits.LeadingZeros16(nextState) + + bits := tableBits - highBit + table[i].bits = uint8(bits) + table[i].base = (nextState << bits) - uint16(tableSize) + } + + return nil +} + +// fseBaselineEntry is an entry in an FSE baseline table. +// We use these for literal/match/length values. +// Those require mapping the symbol to a baseline value, +// and then reading zero or more bits and adding the value to the baseline. +// Rather than looking these up in separate tables, +// we convert the FSE table to an FSE baseline table. +type fseBaselineEntry struct { + baseline uint32 // baseline for value that this entry represents + basebits uint8 // number of bits to read to add to baseline + bits uint8 // number of bits to read to determine next state + base uint16 // add the bits to this base to get the next state +} + +// Given a literal length code, we need to read a number of bits and +// add that to a baseline. For states 0 to 15 the baseline is the +// state and the number of bits is zero. RFC 3.1.1.3.2.1.1. + +const literalLengthOffset = 16 + +var literalLengthBase = []uint32{ + 16 | (1 << 24), + 18 | (1 << 24), + 20 | (1 << 24), + 22 | (1 << 24), + 24 | (2 << 24), + 28 | (2 << 24), + 32 | (3 << 24), + 40 | (3 << 24), + 48 | (4 << 24), + 64 | (6 << 24), + 128 | (7 << 24), + 256 | (8 << 24), + 512 | (9 << 24), + 1024 | (10 << 24), + 2048 | (11 << 24), + 4096 | (12 << 24), + 8192 | (13 << 24), + 16384 | (14 << 24), + 32768 | (15 << 24), + 65536 | (16 << 24), +} + +// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable. +func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error { + for i, e := range fseTable { + be := fseBaselineEntry{ + bits: e.bits, + base: e.base, + } + if e.sym < literalLengthOffset { + be.baseline = uint32(e.sym) + be.basebits = 0 + } else { + if e.sym > 35 { + return r.makeError(off, "FSE baseline symbol overflow") + } + idx := e.sym - literalLengthOffset + basebits := literalLengthBase[idx] + be.baseline = basebits & 0xffffff + be.basebits = uint8(basebits >> 24) + } + baselineTable[i] = be + } + return nil +} + +// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable. +func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error { + for i, e := range fseTable { + be := fseBaselineEntry{ + bits: e.bits, + base: e.base, + } + if e.sym > 31 { + return r.makeError(off, "FSE offset symbol overflow") + } + + // The simple way to write this is + // be.baseline = 1 << e.sym + // be.basebits = e.sym + // That would give us an offset value that corresponds to + // the one described in the RFC. However, for offsets > 3 + // we have to subtract 3. And for offset values 1, 2, 3 + // we use a repeated offset. + // + // The baseline is always a power of 2, and is never 0, + // so for those low values we will see one entry that is + // baseline 1, basebits 0, and one entry that is baseline 2, + // basebits 1. All other entries will have baseline >= 4 + // basebits >= 2. + // + // So we can check for RFC offset <= 3 by checking for + // basebits <= 1. That means that we can subtract 3 here + // and not worry about doing it in the hot loop. + + be.baseline = 1 << e.sym + if e.sym >= 2 { + be.baseline -= 3 + } + be.basebits = e.sym + baselineTable[i] = be + } + return nil +} + +// Given a match length code, we need to read a number of bits and add +// that to a baseline. For states 0 to 31 the baseline is state+3 and +// the number of bits is zero. RFC 3.1.1.3.2.1.1. + +const matchLengthOffset = 32 + +var matchLengthBase = []uint32{ + 35 | (1 << 24), + 37 | (1 << 24), + 39 | (1 << 24), + 41 | (1 << 24), + 43 | (2 << 24), + 47 | (2 << 24), + 51 | (3 << 24), + 59 | (3 << 24), + 67 | (4 << 24), + 83 | (4 << 24), + 99 | (5 << 24), + 131 | (7 << 24), + 259 | (8 << 24), + 515 | (9 << 24), + 1027 | (10 << 24), + 2051 | (11 << 24), + 4099 | (12 << 24), + 8195 | (13 << 24), + 16387 | (14 << 24), + 32771 | (15 << 24), + 65539 | (16 << 24), +} + +// makeMatchBaselineFSE converts the match length fseTable to baselineTable. +func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error { + for i, e := range fseTable { + be := fseBaselineEntry{ + bits: e.bits, + base: e.base, + } + if e.sym < matchLengthOffset { + be.baseline = uint32(e.sym) + 3 + be.basebits = 0 + } else { + if e.sym > 52 { + return r.makeError(off, "FSE baseline symbol overflow") + } + idx := e.sym - matchLengthOffset + basebits := matchLengthBase[idx] + be.baseline = basebits & 0xffffff + be.basebits = uint8(basebits >> 24) + } + baselineTable[i] = be + } + return nil +} + +// predefinedLiteralTable is the predefined table to use for literal lengths. +// Generated from table in RFC 3.1.1.3.2.2.1. +// Checked by TestPredefinedTables. +var predefinedLiteralTable = [...]fseBaselineEntry{ + {0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32}, + {3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0}, + {7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0}, + {12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0}, + {20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0}, + {32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32}, + {128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0}, + {4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0}, + {2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0}, + {7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32}, + {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32}, + {18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0}, + {32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0}, + {64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0}, + {2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16}, + {2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32}, + {6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32}, + {11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0}, + {18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32}, + {28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32}, + {65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0}, + {8192, 13, 6, 0}, +} + +// predefinedOffsetTable is the predefined table to use for offsets. +// Generated from table in RFC 3.1.1.3.2.2.3. +// Checked by TestPredefinedTables. +var predefinedOffsetTable = [...]fseBaselineEntry{ + {1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0}, + {32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0}, + {125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0}, + {8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0}, + {16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0}, + {125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0}, + {4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16}, + {8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0}, + {61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0}, + {268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0}, + {33554429, 25, 5, 0}, {16777213, 24, 5, 0}, +} + +// predefinedMatchTable is the predefined table to use for match lengths. +// Generated from table in RFC 3.1.1.3.2.2.2. +// Checked by TestPredefinedTables. +var predefinedMatchTable = [...]fseBaselineEntry{ + {3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32}, + {6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0}, + {11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0}, + {19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0}, + {28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0}, + {37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0}, + {59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0}, + {515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0}, + {6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32}, + {10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0}, + {18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0}, + {27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0}, + {35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0}, + {51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0}, + {259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48}, + {5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32}, + {10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0}, + {17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0}, + {26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0}, + {65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0}, + {8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0}, + {1027, 10, 6, 0}, +} diff --git a/src/internal/zstd/fse_test.go b/src/internal/zstd/fse_test.go new file mode 100644 index 0000000..6f106b6 --- /dev/null +++ b/src/internal/zstd/fse_test.go @@ -0,0 +1,89 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "slices" + "testing" +) + +// literalPredefinedDistribution is the predefined distribution table +// for literal lengths. RFC 3.1.1.3.2.2.1. +var literalPredefinedDistribution = []int16{ + 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, + -1, -1, -1, -1, +} + +// offsetPredefinedDistribution is the predefined distribution table +// for offsets. RFC 3.1.1.3.2.2.3. +var offsetPredefinedDistribution = []int16{ + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, +} + +// matchPredefinedDistribution is the predefined distribution table +// for match lengths. RFC 3.1.1.3.2.2.2. +var matchPredefinedDistribution = []int16{ + 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, + -1, -1, -1, -1, -1, +} + +// TestPredefinedTables verifies that we can generate the predefined +// literal/offset/match tables from the input data in RFC 8878. +// This serves as a test of the predefined tables, and also of buildFSE +// and the functions that make baseline FSE tables. +func TestPredefinedTables(t *testing.T) { + tests := []struct { + name string + distribution []int16 + tableBits int + toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error + predef []fseBaselineEntry + }{ + { + name: "literal", + distribution: literalPredefinedDistribution, + tableBits: 6, + toBaseline: (*Reader).makeLiteralBaselineFSE, + predef: predefinedLiteralTable[:], + }, + { + name: "offset", + distribution: offsetPredefinedDistribution, + tableBits: 5, + toBaseline: (*Reader).makeOffsetBaselineFSE, + predef: predefinedOffsetTable[:], + }, + { + name: "match", + distribution: matchPredefinedDistribution, + tableBits: 6, + toBaseline: (*Reader).makeMatchBaselineFSE, + predef: predefinedMatchTable[:], + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + var r Reader + table := make([]fseEntry, 1<<test.tableBits) + if err := r.buildFSE(0, test.distribution, table, test.tableBits); err != nil { + t.Fatal(err) + } + + baselineTable := make([]fseBaselineEntry, len(table)) + if err := test.toBaseline(&r, 0, table, baselineTable); err != nil { + t.Fatal(err) + } + + if !slices.Equal(baselineTable, test.predef) { + t.Errorf("got %v, want %v", baselineTable, test.predef) + } + }) + } +} diff --git a/src/internal/zstd/fuzz_test.go b/src/internal/zstd/fuzz_test.go new file mode 100644 index 0000000..4b5c996 --- /dev/null +++ b/src/internal/zstd/fuzz_test.go @@ -0,0 +1,139 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "bytes" + "io" + "os" + "os/exec" + "testing" +) + +// badStrings is some inputs that FuzzReader failed on earlier. +var badStrings = []string{ + "(\xb5/\xfdd00,\x05\x00\xc4\x0400000000000000000000000000000000000000000000000000000000000000000000000000000 \xa07100000000000000000000000000000000000000000000000000000000000000000000000000aM\x8a2y0B\b", + "(\xb5/\xfd00$\x05\x0020 00X70000a70000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "(\xb5/\xfd00$\x05\x0020 00B00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "(\xb5/\xfd00}\x00\x0020\x00\x9000000000000", + "(\xb5/\xfd00}\x00\x00&0\x02\x830!000000000", + "(\xb5/\xfd\x1002000$\x05\x0010\xcc0\xa8100000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", + "(\xb5/\xfd\x1002000$\x05\x0000\xcc0\xa8100d\x0000001000000000000000000000000000000000000000000000000000000000000000000000000\x000000000000000000000000000000000000000000000000000000000000000000000000000000", + "(\xb5/\xfd001\x00\x0000000000000000000", + "(\xb5/\xfd00\xec\x00\x00&@\x05\x05A7002\x02\x00\x02\x00\x02\x0000000000000000", + "(\xb5/\xfd00\xec\x00\x00V@\x05\x0517002\x02\x00\x02\x00\x02\x0000000000000000", + "\x50\x2a\x4d\x18\x02\x00\x00\x00", +} + +// This is a simple fuzzer to see if the decompressor panics. +func FuzzReader(f *testing.F) { + for _, test := range tests { + f.Add([]byte(test.compressed)) + } + for _, s := range badStrings { + f.Add([]byte(s)) + } + f.Fuzz(func(t *testing.T, b []byte) { + r := NewReader(bytes.NewReader(b)) + io.Copy(io.Discard, r) + }) +} + +// Fuzz test to verify that what we decompress is what we compress. +// This isn't a great fuzz test because the fuzzer can't efficiently +// explore the space of decompressor behavior, since it can't see +// what the compressor is doing. But it's better than nothing. +func FuzzDecompressor(f *testing.F) { + zstd := findZstd(f) + + for _, test := range tests { + f.Add([]byte(test.uncompressed)) + } + + // Add some larger data, as that has more interesting compression. + f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256)) + var buf bytes.Buffer + for i := 0; i < 256; i++ { + buf.WriteByte(byte(i)) + } + f.Add(bytes.Repeat(buf.Bytes(), 64)) + f.Add(bigData(f)) + + f.Fuzz(func(t *testing.T, b []byte) { + cmd := exec.Command(zstd, "-z") + cmd.Stdin = bytes.NewReader(b) + var compressed bytes.Buffer + cmd.Stdout = &compressed + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Errorf("running zstd failed: %v", err) + } + + r := NewReader(bytes.NewReader(compressed.Bytes())) + got, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, b) { + showDiffs(t, got, b) + } + }) +} + +// Fuzz test to check that if we can decompress some data, +// so can zstd, and that we get the same result. +func FuzzReverse(f *testing.F) { + zstd := findZstd(f) + + for _, test := range tests { + f.Add([]byte(test.compressed)) + } + + // Set a hook to reject some cases where we don't match zstd. + fuzzing = true + defer func() { fuzzing = false }() + + f.Fuzz(func(t *testing.T, b []byte) { + r := NewReader(bytes.NewReader(b)) + goExp, goErr := io.ReadAll(r) + + cmd := exec.Command(zstd, "-d") + cmd.Stdin = bytes.NewReader(b) + var uncompressed bytes.Buffer + cmd.Stdout = &uncompressed + cmd.Stderr = os.Stderr + zstdErr := cmd.Run() + zstdExp := uncompressed.Bytes() + + if goErr == nil && zstdErr == nil { + if !bytes.Equal(zstdExp, goExp) { + showDiffs(t, zstdExp, goExp) + } + } else { + // Ideally we should check that this package and + // the zstd program both fail or both succeed, + // and that if they both fail one byte sequence + // is an exact prefix of the other. + // Actually trying this proved to be frustrating, + // as the zstd program appears to accept invalid + // byte sequences using rules that are difficult + // to determine. + // So we just check the prefix. + + c := len(goExp) + if c > len(zstdExp) { + c = len(zstdExp) + } + goExp = goExp[:c] + zstdExp = zstdExp[:c] + if !bytes.Equal(goExp, zstdExp) { + t.Error("byte mismatch after error") + t.Logf("Go error: %v\n", goErr) + t.Logf("zstd error: %v\n", zstdErr) + showDiffs(t, zstdExp, goExp) + } + } + }) +} diff --git a/src/internal/zstd/huff.go b/src/internal/zstd/huff.go new file mode 100644 index 0000000..452e24b --- /dev/null +++ b/src/internal/zstd/huff.go @@ -0,0 +1,204 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "io" + "math/bits" +) + +// maxHuffmanBits is the largest possible Huffman table bits. +const maxHuffmanBits = 11 + +// readHuff reads Huffman table from data starting at off into table. +// Each entry in a Huffman table is a pair of bytes. +// The high byte is the encoded value. The low byte is the number +// of bits used to encode that value. We index into the table +// with a value of size tableBits. A value that requires fewer bits +// appear in the table multiple times. +// This returns the number of bits in the Huffman table and the new offset. +// RFC 4.2.1. +func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) { + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + + hdr := data[off] + off++ + + var weights [256]uint8 + var count int + if hdr < 128 { + // The table is compressed using an FSE. RFC 4.2.1.2. + if len(r.fseScratch) < 1<<6 { + r.fseScratch = make([]fseEntry, 1<<6) + } + fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch) + if err != nil { + return 0, 0, err + } + fseTable := r.fseScratch + + if off+int(hdr) > len(data) { + return 0, 0, r.makeEOFError(off) + } + + rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff) + if err != nil { + return 0, 0, err + } + + state1, err := rbr.val(uint8(fseBits)) + if err != nil { + return 0, 0, err + } + + state2, err := rbr.val(uint8(fseBits)) + if err != nil { + return 0, 0, err + } + + // There are two independent FSE streams, tracked by + // state1 and state2. We decode them alternately. + + for { + pt := &fseTable[state1] + if !rbr.fetch(pt.bits) { + if count >= 254 { + return 0, 0, rbr.makeError("Huffman count overflow") + } + weights[count] = pt.sym + weights[count+1] = fseTable[state2].sym + count += 2 + break + } + + v, err := rbr.val(pt.bits) + if err != nil { + return 0, 0, err + } + state1 = uint32(pt.base) + v + + if count >= 255 { + return 0, 0, rbr.makeError("Huffman count overflow") + } + + weights[count] = pt.sym + count++ + + pt = &fseTable[state2] + + if !rbr.fetch(pt.bits) { + if count >= 254 { + return 0, 0, rbr.makeError("Huffman count overflow") + } + weights[count] = pt.sym + weights[count+1] = fseTable[state1].sym + count += 2 + break + } + + v, err = rbr.val(pt.bits) + if err != nil { + return 0, 0, err + } + state2 = uint32(pt.base) + v + + if count >= 255 { + return 0, 0, rbr.makeError("Huffman count overflow") + } + + weights[count] = pt.sym + count++ + } + + off += int(hdr) + } else { + // The table is not compressed. Each weight is 4 bits. + + count = int(hdr) - 127 + if off+((count+1)/2) >= len(data) { + return 0, 0, io.ErrUnexpectedEOF + } + for i := 0; i < count; i += 2 { + b := data[off] + off++ + weights[i] = b >> 4 + weights[i+1] = b & 0xf + } + } + + // RFC 4.2.1.3. + + var weightMark [13]uint32 + weightMask := uint32(0) + for _, w := range weights[:count] { + if w > 12 { + return 0, 0, r.makeError(off, "Huffman weight overflow") + } + weightMark[w]++ + if w > 0 { + weightMask += 1 << (w - 1) + } + } + if weightMask == 0 { + return 0, 0, r.makeError(off, "bad Huffman weights") + } + + tableBits = 32 - bits.LeadingZeros32(weightMask) + if tableBits > maxHuffmanBits { + return 0, 0, r.makeError(off, "bad Huffman weights") + } + + if len(table) < 1<<tableBits { + return 0, 0, r.makeError(off, "Huffman table too small") + } + + // Work out the last weight value, which is omitted because + // the weights must sum to a power of two. + left := (uint32(1) << tableBits) - weightMask + if left == 0 { + return 0, 0, r.makeError(off, "bad Huffman weights") + } + highBit := 31 - bits.LeadingZeros32(left) + if uint32(1)<<highBit != left { + return 0, 0, r.makeError(off, "bad Huffman weights") + } + if count >= 256 { + return 0, 0, r.makeError(off, "Huffman weight overflow") + } + weights[count] = uint8(highBit + 1) + count++ + weightMark[highBit+1]++ + + if weightMark[1] < 2 || weightMark[1]&1 != 0 { + return 0, 0, r.makeError(off, "bad Huffman weights") + } + + // Change weightMark from a count of weights to the index of + // the first symbol for that weight. We shift the indexes to + // also store how many we have seen so far, + next := uint32(0) + for i := 0; i < tableBits; i++ { + cur := next + next += weightMark[i+1] << i + weightMark[i+1] = cur + } + + for i, w := range weights[:count] { + if w == 0 { + continue + } + length := uint32(1) << (w - 1) + tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w)) + start := weightMark[w] + for j := uint32(0); j < length; j++ { + table[start+j] = tval + } + weightMark[w] += length + } + + return tableBits, off, nil +} diff --git a/src/internal/zstd/literals.go b/src/internal/zstd/literals.go new file mode 100644 index 0000000..11ef859 --- /dev/null +++ b/src/internal/zstd/literals.go @@ -0,0 +1,336 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "encoding/binary" +) + +// readLiterals reads and decompresses the literals from data at off. +// The literals are appended to outbuf, which is returned. +// Also returns the new input offset. RFC 3.1.1.3.1. +func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) { + if off >= len(data) { + return 0, nil, r.makeEOFError(off) + } + + // Literals section header. RFC 3.1.1.3.1.1. + hdr := data[off] + off++ + + if (hdr&3) == 0 || (hdr&3) == 1 { + return r.readRawRLELiterals(data, off, hdr, outbuf) + } else { + return r.readHuffLiterals(data, off, hdr, outbuf) + } +} + +// readRawRLELiterals reads and decompresses a Raw_Literals_Block or +// a RLE_Literals_Block. RFC 3.1.1.3.1.1. +func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) { + raw := (hdr & 3) == 0 + + var regeneratedSize int + switch (hdr >> 2) & 3 { + case 0, 2: + regeneratedSize = int(hdr >> 3) + case 1: + if off >= len(data) { + return 0, nil, r.makeEOFError(off) + } + regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + off++ + case 3: + if off+1 >= len(data) { + return 0, nil, r.makeEOFError(off) + } + regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12) + off += 2 + } + + // We are going to use the entire literal block in the output. + // The maximum size of one decompressed block is 128K, + // so we can't have more literals than that. + if regeneratedSize > 128<<10 { + return 0, nil, r.makeError(off, "literal size too large") + } + + if raw { + // RFC 3.1.1.3.1.2. + if off+regeneratedSize > len(data) { + return 0, nil, r.makeError(off, "raw literal size too large") + } + outbuf = append(outbuf, data[off:off+regeneratedSize]...) + off += regeneratedSize + } else { + // RFC 3.1.1.3.1.3. + if off >= len(data) { + return 0, nil, r.makeError(off, "RLE literal missing") + } + rle := data[off] + off++ + for i := 0; i < regeneratedSize; i++ { + outbuf = append(outbuf, rle) + } + } + + return off, outbuf, nil +} + +// readHuffLiterals reads and decompresses a Compressed_Literals_Block or +// a Treeless_Literals_Block. RFC 3.1.1.3.1.4. +func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) { + var ( + regeneratedSize int + compressedSize int + streams int + ) + switch (hdr >> 2) & 3 { + case 0, 1: + if off+1 >= len(data) { + return 0, nil, r.makeEOFError(off) + } + regeneratedSize = (int(hdr) >> 4) | ((int(data[off]) & 0x3f) << 4) + compressedSize = (int(data[off]) >> 6) | (int(data[off+1]) << 2) + off += 2 + if ((hdr >> 2) & 3) == 0 { + streams = 1 + } else { + streams = 4 + } + case 2: + if off+2 >= len(data) { + return 0, nil, r.makeEOFError(off) + } + regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 3) << 12) + compressedSize = (int(data[off+1]) >> 2) | (int(data[off+2]) << 6) + off += 3 + streams = 4 + case 3: + if off+3 >= len(data) { + return 0, nil, r.makeEOFError(off) + } + regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 0x3f) << 12) + compressedSize = (int(data[off+1]) >> 6) | (int(data[off+2]) << 2) | (int(data[off+3]) << 10) + off += 4 + streams = 4 + } + + // We are going to use the entire literal block in the output. + // The maximum size of one decompressed block is 128K, + // so we can't have more literals than that. + if regeneratedSize > 128<<10 { + return 0, nil, r.makeError(off, "literal size too large") + } + + roff := off + compressedSize + if roff > len(data) || roff < 0 { + return 0, nil, r.makeEOFError(off) + } + + totalStreamsSize := compressedSize + if (hdr & 3) == 2 { + // Compressed_Literals_Block. + // Read new huffman tree. + + if len(r.huffmanTable) < 1<<maxHuffmanBits { + r.huffmanTable = make([]uint16, 1<<maxHuffmanBits) + } + + huffmanTableBits, hoff, err := r.readHuff(data, off, r.huffmanTable) + if err != nil { + return 0, nil, err + } + r.huffmanTableBits = huffmanTableBits + + if totalStreamsSize < hoff-off { + return 0, nil, r.makeError(off, "Huffman table too big") + } + totalStreamsSize -= hoff - off + off = hoff + } else { + // Treeless_Literals_Block + // Reuse previous Huffman tree. + if r.huffmanTableBits == 0 { + return 0, nil, r.makeError(off, "missing literals Huffman tree") + } + } + + // Decompress compressedSize bytes of data at off using the + // Huffman tree. + + var err error + if streams == 1 { + outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf) + } else { + outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf) + } + + if err != nil { + return 0, nil, err + } + + return roff, outbuf, nil +} + +// readLiteralsOneStream reads a single stream of compressed literals. +func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) { + // We let the reverse bit reader read earlier bytes, + // because the Huffman table ignores bits that it doesn't need. + rbr, err := r.makeReverseBitReader(data, off+compressedSize-1, off-2) + if err != nil { + return nil, err + } + + huffTable := r.huffmanTable + huffBits := uint32(r.huffmanTableBits) + huffMask := (uint32(1) << huffBits) - 1 + + for i := 0; i < regeneratedSize; i++ { + if !rbr.fetch(uint8(huffBits)) { + return nil, rbr.makeError("literals Huffman stream out of bits") + } + + var t uint16 + idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask + t = huffTable[idx] + outbuf = append(outbuf, byte(t>>8)) + rbr.cnt -= uint32(t & 0xff) + } + + return outbuf, nil +} + +// readLiteralsFourStreams reads four interleaved streams of +// compressed literals. +func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) { + // Read the jump table to find out where the streams are. + // RFC 3.1.1.3.1.6. + if off+5 >= len(data) { + return nil, r.makeEOFError(off) + } + if totalStreamsSize < 6 { + return nil, r.makeError(off, "total streams size too small for jump table") + } + // RFC 3.1.1.3.1.6. + // "The decompressed size of each stream is equal to (Regenerated_Size+3)/4, + // except for the last stream, which may be up to 3 bytes smaller, + // to reach a total decompressed size as specified in Regenerated_Size." + regeneratedStreamSize := (regeneratedSize + 3) / 4 + if regeneratedSize < regeneratedStreamSize*3 { + return nil, r.makeError(off, "regenerated size too small to decode streams") + } + + streamSize1 := binary.LittleEndian.Uint16(data[off:]) + streamSize2 := binary.LittleEndian.Uint16(data[off+2:]) + streamSize3 := binary.LittleEndian.Uint16(data[off+4:]) + off += 6 + + tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3) + if tot > uint64(totalStreamsSize)-6 { + return nil, r.makeEOFError(off) + } + streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot) + + off-- + off1 := off + int(streamSize1) + start1 := off + 1 + + off2 := off1 + int(streamSize2) + start2 := off1 + 1 + + off3 := off2 + int(streamSize3) + start3 := off2 + 1 + + off4 := off3 + int(streamSize4) + start4 := off3 + 1 + + // We let the reverse bit readers read earlier bytes, + // because the Huffman tables ignore bits that they don't need. + + rbr1, err := r.makeReverseBitReader(data, off1, start1-2) + if err != nil { + return nil, err + } + + rbr2, err := r.makeReverseBitReader(data, off2, start2-2) + if err != nil { + return nil, err + } + + rbr3, err := r.makeReverseBitReader(data, off3, start3-2) + if err != nil { + return nil, err + } + + rbr4, err := r.makeReverseBitReader(data, off4, start4-2) + if err != nil { + return nil, err + } + + out1 := len(outbuf) + out2 := out1 + regeneratedStreamSize + out3 := out2 + regeneratedStreamSize + out4 := out3 + regeneratedStreamSize + + regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3 + + outbuf = append(outbuf, make([]byte, regeneratedSize)...) + + huffTable := r.huffmanTable + huffBits := uint32(r.huffmanTableBits) + huffMask := (uint32(1) << huffBits) - 1 + + for i := 0; i < regeneratedStreamSize; i++ { + use4 := i < regeneratedStreamSize4 + + fetchHuff := func(rbr *reverseBitReader) (uint16, error) { + if !rbr.fetch(uint8(huffBits)) { + return 0, rbr.makeError("literals Huffman stream out of bits") + } + idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask + return huffTable[idx], nil + } + + t1, err := fetchHuff(&rbr1) + if err != nil { + return nil, err + } + + t2, err := fetchHuff(&rbr2) + if err != nil { + return nil, err + } + + t3, err := fetchHuff(&rbr3) + if err != nil { + return nil, err + } + + if use4 { + t4, err := fetchHuff(&rbr4) + if err != nil { + return nil, err + } + outbuf[out4] = byte(t4 >> 8) + out4++ + rbr4.cnt -= uint32(t4 & 0xff) + } + + outbuf[out1] = byte(t1 >> 8) + out1++ + rbr1.cnt -= uint32(t1 & 0xff) + + outbuf[out2] = byte(t2 >> 8) + out2++ + rbr2.cnt -= uint32(t2 & 0xff) + + outbuf[out3] = byte(t3 >> 8) + out3++ + rbr3.cnt -= uint32(t3 & 0xff) + } + + return outbuf, nil +} diff --git a/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst b/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst Binary files differnew file mode 100644 index 0000000..afb4a27 --- /dev/null +++ b/src/internal/zstd/testdata/1890a371.gettysburg.txt-100x.zst diff --git a/src/internal/zstd/testdata/README b/src/internal/zstd/testdata/README new file mode 100644 index 0000000..1a6dbb3 --- /dev/null +++ b/src/internal/zstd/testdata/README @@ -0,0 +1,10 @@ +This directory holds files for testing zstd.NewReader. + +Each one is a Zstandard compressed file named as hash.arbitrary-name.zst, +where hash is the first eight hexadecimal digits of the SHA256 hash +of the expected uncompressed content: + + zstd -d < 1890a371.gettysburg.txt-100x.zst | sha256sum | head -c 8 + 1890a371 + +The test uses hash value to verify decompression result. diff --git a/src/internal/zstd/testdata/f2a8e35c.helloworld-11000x.zst b/src/internal/zstd/testdata/f2a8e35c.helloworld-11000x.zst Binary files differnew file mode 100644 index 0000000..87a8aca --- /dev/null +++ b/src/internal/zstd/testdata/f2a8e35c.helloworld-11000x.zst diff --git a/src/internal/zstd/testdata/fcf30b99.zero-dictionary-ids.zst b/src/internal/zstd/testdata/fcf30b99.zero-dictionary-ids.zst Binary files differnew file mode 100644 index 0000000..1be89e8 --- /dev/null +++ b/src/internal/zstd/testdata/fcf30b99.zero-dictionary-ids.zst diff --git a/src/internal/zstd/window.go b/src/internal/zstd/window.go new file mode 100644 index 0000000..f9c5f04 --- /dev/null +++ b/src/internal/zstd/window.go @@ -0,0 +1,90 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +// window stores up to size bytes of data. +// It is implemented as a circular buffer: +// sequential save calls append to the data slice until +// its length reaches configured size and after that, +// save calls overwrite previously saved data at off +// and update off such that it always points at +// the byte stored before others. +type window struct { + size int + data []byte + off int +} + +// reset clears stored data and configures window size. +func (w *window) reset(size int) { + w.data = w.data[:0] + w.off = 0 + w.size = size +} + +// len returns the number of stored bytes. +func (w *window) len() uint32 { + return uint32(len(w.data)) +} + +// save stores up to size last bytes from the buf. +func (w *window) save(buf []byte) { + if w.size == 0 { + return + } + if len(buf) == 0 { + return + } + + if len(buf) >= w.size { + from := len(buf) - w.size + w.data = append(w.data[:0], buf[from:]...) + w.off = 0 + return + } + + // Update off to point to the oldest remaining byte. + free := w.size - len(w.data) + if free == 0 { + n := copy(w.data[w.off:], buf) + if n == len(buf) { + w.off += n + } else { + w.off = copy(w.data, buf[n:]) + } + } else { + if free >= len(buf) { + w.data = append(w.data, buf...) + } else { + w.data = append(w.data, buf[:free]...) + w.off = copy(w.data, buf[free:]) + } + } +} + +// appendTo appends stored bytes between from and to indices to the buf. +// Index from must be less or equal to index to and to must be less or equal to w.len(). +func (w *window) appendTo(buf []byte, from, to uint32) []byte { + dataLen := uint32(len(w.data)) + from += uint32(w.off) + to += uint32(w.off) + + wrap := false + if from > dataLen { + from -= dataLen + wrap = !wrap + } + if to > dataLen { + to -= dataLen + wrap = !wrap + } + + if wrap { + buf = append(buf, w.data[from:]...) + return append(buf, w.data[:to]...) + } else { + return append(buf, w.data[from:to]...) + } +} diff --git a/src/internal/zstd/window_test.go b/src/internal/zstd/window_test.go new file mode 100644 index 0000000..afa2eef --- /dev/null +++ b/src/internal/zstd/window_test.go @@ -0,0 +1,72 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "bytes" + "fmt" + "testing" +) + +func makeSequence(start, n int) (seq []byte) { + for i := 0; i < n; i++ { + seq = append(seq, byte(start+i)) + } + return +} + +func TestWindow(t *testing.T) { + for size := 0; size <= 3; size++ { + for i := 0; i <= 2*size; i++ { + a := makeSequence('a', i) + for j := 0; j <= 2*size; j++ { + b := makeSequence('a'+i, j) + for k := 0; k <= 2*size; k++ { + c := makeSequence('a'+i+j, k) + + t.Run(fmt.Sprintf("%d-%d-%d-%d", size, i, j, k), func(t *testing.T) { + testWindow(t, size, a, b, c) + }) + } + } + } + } +} + +// testWindow tests window by saving three sequences of bytes to it. +// Third sequence tests read offset that can become non-zero only after second save. +func testWindow(t *testing.T, size int, a, b, c []byte) { + var w window + w.reset(size) + + w.save(a) + w.save(b) + w.save(c) + + var tail []byte + tail = append(tail, a...) + tail = append(tail, b...) + tail = append(tail, c...) + + if len(tail) > size { + tail = tail[len(tail)-size:] + } + + if w.len() != uint32(len(tail)) { + t.Errorf("wrong data length: got: %d, want: %d", w.len(), len(tail)) + } + + var from, to uint32 + for from = 0; from <= uint32(len(tail)); from++ { + for to = from; to <= uint32(len(tail)); to++ { + got := w.appendTo(nil, from, to) + want := tail[from:to] + + if !bytes.Equal(got, want) { + t.Errorf("wrong data at [%d:%d]: got %q, want %q", from, to, got, want) + } + } + } +} diff --git a/src/internal/zstd/xxhash.go b/src/internal/zstd/xxhash.go new file mode 100644 index 0000000..4d579ee --- /dev/null +++ b/src/internal/zstd/xxhash.go @@ -0,0 +1,148 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "encoding/binary" + "math/bits" +) + +const ( + xxhPrime64c1 = 0x9e3779b185ebca87 + xxhPrime64c2 = 0xc2b2ae3d27d4eb4f + xxhPrime64c3 = 0x165667b19e3779f9 + xxhPrime64c4 = 0x85ebca77c2b2ae63 + xxhPrime64c5 = 0x27d4eb2f165667c5 +) + +// xxhash64 is the state of a xxHash-64 checksum. +type xxhash64 struct { + len uint64 // total length hashed + v [4]uint64 // accumulators + buf [32]byte // buffer + cnt int // number of bytes in buffer +} + +// reset discards the current state and prepares to compute a new hash. +// We assume a seed of 0 since that is what zstd uses. +func (xh *xxhash64) reset() { + xh.len = 0 + + // Separate addition for awkward constant overflow. + xh.v[0] = xxhPrime64c1 + xh.v[0] += xxhPrime64c2 + + xh.v[1] = xxhPrime64c2 + xh.v[2] = 0 + + // Separate negation for awkward constant overflow. + xh.v[3] = xxhPrime64c1 + xh.v[3] = -xh.v[3] + + for i := range xh.buf { + xh.buf[i] = 0 + } + xh.cnt = 0 +} + +// update adds a buffer to the has. +func (xh *xxhash64) update(b []byte) { + xh.len += uint64(len(b)) + + if xh.cnt+len(b) < len(xh.buf) { + copy(xh.buf[xh.cnt:], b) + xh.cnt += len(b) + return + } + + if xh.cnt > 0 { + n := copy(xh.buf[xh.cnt:], b) + b = b[n:] + xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(xh.buf[:])) + xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(xh.buf[8:])) + xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(xh.buf[16:])) + xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(xh.buf[24:])) + xh.cnt = 0 + } + + for len(b) >= 32 { + xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(b)) + xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(b[8:])) + xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(b[16:])) + xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(b[24:])) + b = b[32:] + } + + if len(b) > 0 { + copy(xh.buf[:], b) + xh.cnt = len(b) + } +} + +// digest returns the final hash value. +func (xh *xxhash64) digest() uint64 { + var h64 uint64 + if xh.len < 32 { + h64 = xh.v[2] + xxhPrime64c5 + } else { + h64 = bits.RotateLeft64(xh.v[0], 1) + + bits.RotateLeft64(xh.v[1], 7) + + bits.RotateLeft64(xh.v[2], 12) + + bits.RotateLeft64(xh.v[3], 18) + h64 = xh.mergeRound(h64, xh.v[0]) + h64 = xh.mergeRound(h64, xh.v[1]) + h64 = xh.mergeRound(h64, xh.v[2]) + h64 = xh.mergeRound(h64, xh.v[3]) + } + + h64 += xh.len + + len := xh.len + len &= 31 + buf := xh.buf[:] + for len >= 8 { + k1 := xh.round(0, binary.LittleEndian.Uint64(buf)) + buf = buf[8:] + h64 ^= k1 + h64 = bits.RotateLeft64(h64, 27)*xxhPrime64c1 + xxhPrime64c4 + len -= 8 + } + if len >= 4 { + h64 ^= uint64(binary.LittleEndian.Uint32(buf)) * xxhPrime64c1 + buf = buf[4:] + h64 = bits.RotateLeft64(h64, 23)*xxhPrime64c2 + xxhPrime64c3 + len -= 4 + } + for len > 0 { + h64 ^= uint64(buf[0]) * xxhPrime64c5 + buf = buf[1:] + h64 = bits.RotateLeft64(h64, 11) * xxhPrime64c1 + len-- + } + + h64 ^= h64 >> 33 + h64 *= xxhPrime64c2 + h64 ^= h64 >> 29 + h64 *= xxhPrime64c3 + h64 ^= h64 >> 32 + + return h64 +} + +// round updates a value. +func (xh *xxhash64) round(v, n uint64) uint64 { + v += n * xxhPrime64c2 + v = bits.RotateLeft64(v, 31) + v *= xxhPrime64c1 + return v +} + +// mergeRound updates a value in the final round. +func (xh *xxhash64) mergeRound(v, n uint64) uint64 { + n = xh.round(0, n) + v ^= n + v = v*xxhPrime64c1 + xxhPrime64c4 + return v +} diff --git a/src/internal/zstd/xxhash_test.go b/src/internal/zstd/xxhash_test.go new file mode 100644 index 0000000..68ca558 --- /dev/null +++ b/src/internal/zstd/xxhash_test.go @@ -0,0 +1,115 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "bytes" + "os" + "os/exec" + "strconv" + "testing" +) + +var xxHashTests = []struct { + data string + hash uint64 +}{ + { + "hello, world", + 0xb33a384e6d1b1242, + }, + { + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$", + 0x1032d841e824f998, + }, +} + +func TestXXHash(t *testing.T) { + var xh xxhash64 + for i, test := range xxHashTests { + xh.reset() + xh.update([]byte(test.data)) + if got := xh.digest(); got != test.hash { + t.Errorf("#%d: got %#x want %#x", i, got, test.hash) + } + } +} + +func TestLargeXXHash(t *testing.T) { + if testing.Short() { + t.Skip("skipping expensive test in short mode") + } + + data, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt") + if err != nil { + t.Fatal(err) + } + + var xh xxhash64 + xh.reset() + i := 0 + for i < len(data) { + // Write varying amounts to test buffering. + c := i%4094 + 1 + if i+c > len(data) { + c = len(data) - i + } + xh.update(data[i : i+c]) + i += c + } + + got := xh.digest() + want := uint64(0xf0dd39fd7e063f82) + if got != want { + t.Errorf("got %#x want %#x", got, want) + } +} + +func findXxhsum(t testing.TB) string { + xxhsum, err := exec.LookPath("xxhsum") + if err != nil { + t.Skip("skipping because xxhsum not found") + } + return xxhsum +} + +func FuzzXXHash(f *testing.F) { + xxhsum := findXxhsum(f) + + for _, test := range xxHashTests { + f.Add([]byte(test.data)) + } + f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256)) + var buf bytes.Buffer + for i := 0; i < 256; i++ { + buf.WriteByte(byte(i)) + } + f.Add(bytes.Repeat(buf.Bytes(), 64)) + f.Add(bigData(f)) + + f.Fuzz(func(t *testing.T, b []byte) { + cmd := exec.Command(xxhsum, "-H64") + cmd.Stdin = bytes.NewReader(b) + var hhsumHash bytes.Buffer + cmd.Stdout = &hhsumHash + if err := cmd.Run(); err != nil { + t.Fatalf("running hhsum failed: %v", err) + } + hhHashBytes := bytes.Fields(bytes.TrimSpace(hhsumHash.Bytes()))[0] + hhHash, err := strconv.ParseUint(string(hhHashBytes), 16, 64) + if err != nil { + t.Fatalf("could not parse hash %q: %v", hhHashBytes, err) + } + + var xh xxhash64 + xh.reset() + xh.update(b) + goHash := xh.digest() + + if goHash != hhHash { + t.Errorf("Go hash %#x != xxhsum hash %#x", goHash, hhHash) + } + }) +} diff --git a/src/internal/zstd/zstd.go b/src/internal/zstd/zstd.go new file mode 100644 index 0000000..0230076 --- /dev/null +++ b/src/internal/zstd/zstd.go @@ -0,0 +1,522 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package zstd provides a decompressor for zstd streams, +// described in RFC 8878. It does not support dictionaries. +package zstd + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// fuzzing is a fuzzer hook set to true when fuzzing. +// This is used to reject cases where we don't match zstd. +var fuzzing = false + +// Reader implements [io.Reader] to read a zstd compressed stream. +type Reader struct { + // The underlying Reader. + r io.Reader + + // Whether we have read the frame header. + // This is of interest when buffer is empty. + // If true we expect to see a new block. + sawFrameHeader bool + + // Whether the current frame expects a checksum. + hasChecksum bool + + // Whether we have read at least one frame. + readOneFrame bool + + // True if the frame size is not known. + frameSizeUnknown bool + + // The number of uncompressed bytes remaining in the current frame. + // If frameSizeUnknown is true, this is not valid. + remainingFrameSize uint64 + + // The number of bytes read from r up to the start of the current + // block, for error reporting. + blockOffset int64 + + // Buffered decompressed data. + buffer []byte + // Current read offset in buffer. + off int + + // The current repeated offsets. + repeatedOffset1 uint32 + repeatedOffset2 uint32 + repeatedOffset3 uint32 + + // The current Huffman tree used for compressing literals. + huffmanTable []uint16 + huffmanTableBits int + + // The window for back references. + window window + + // A buffer available to hold a compressed block. + compressedBuf []byte + + // A buffer for literals. + literals []byte + + // Sequence decode FSE tables. + seqTables [3][]fseBaselineEntry + seqTableBits [3]uint8 + + // Buffers for sequence decode FSE tables. + seqTableBuffers [3][]fseBaselineEntry + + // Scratch space used for small reads, to avoid allocation. + scratch [16]byte + + // A scratch table for reading an FSE. Only temporarily valid. + fseScratch []fseEntry + + // For checksum computation. + checksum xxhash64 +} + +// NewReader creates a new Reader that decompresses data from the given reader. +func NewReader(input io.Reader) *Reader { + r := new(Reader) + r.Reset(input) + return r +} + +// Reset discards the current state and starts reading a new stream from r. +// This permits reusing a Reader rather than allocating a new one. +func (r *Reader) Reset(input io.Reader) { + r.r = input + + // Several fields are preserved to avoid allocation. + // Others are always set before they are used. + r.sawFrameHeader = false + r.hasChecksum = false + r.readOneFrame = false + r.frameSizeUnknown = false + r.remainingFrameSize = 0 + r.blockOffset = 0 + r.buffer = r.buffer[:0] + r.off = 0 + // repeatedOffset1 + // repeatedOffset2 + // repeatedOffset3 + // huffmanTable + // huffmanTableBits + // window + // compressedBuf + // literals + // seqTables + // seqTableBits + // seqTableBuffers + // scratch + // fseScratch +} + +// Read implements [io.Reader]. +func (r *Reader) Read(p []byte) (int, error) { + if err := r.refillIfNeeded(); err != nil { + return 0, err + } + n := copy(p, r.buffer[r.off:]) + r.off += n + return n, nil +} + +// ReadByte implements [io.ByteReader]. +func (r *Reader) ReadByte() (byte, error) { + if err := r.refillIfNeeded(); err != nil { + return 0, err + } + ret := r.buffer[r.off] + r.off++ + return ret, nil +} + +// refillIfNeeded reads the next block if necessary. +func (r *Reader) refillIfNeeded() error { + for r.off >= len(r.buffer) { + if err := r.refill(); err != nil { + return err + } + r.off = 0 + } + return nil +} + +// refill reads and decompresses the next block. +func (r *Reader) refill() error { + if !r.sawFrameHeader { + if err := r.readFrameHeader(); err != nil { + return err + } + } + return r.readBlock() +} + +// readFrameHeader reads the frame header and prepares to read a block. +func (r *Reader) readFrameHeader() error { +retry: + relativeOffset := 0 + + // Read magic number. RFC 3.1.1. + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + // We require that the stream contains at least one frame. + if err == io.EOF && !r.readOneFrame { + err = io.ErrUnexpectedEOF + } + return r.wrapError(relativeOffset, err) + } + + if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { + if magic >= 0x184d2a50 && magic <= 0x184d2a5f { + // This is a skippable frame. + r.blockOffset += int64(relativeOffset) + 4 + if err := r.skipFrame(); err != nil { + return err + } + r.readOneFrame = true + goto retry + } + + return r.makeError(relativeOffset, "invalid magic number") + } + + relativeOffset += 4 + + // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. + if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + descriptor := r.scratch[0] + + singleSegment := descriptor&(1<<5) != 0 + + fcsFieldSize := 1 << (descriptor >> 6) + if fcsFieldSize == 1 && !singleSegment { + fcsFieldSize = 0 + } + + var windowDescriptorSize int + if singleSegment { + windowDescriptorSize = 0 + } else { + windowDescriptorSize = 1 + } + + if descriptor&(1<<3) != 0 { + return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") + } + + r.hasChecksum = descriptor&(1<<2) != 0 + if r.hasChecksum { + r.checksum.reset() + } + + // Dictionary_ID_Flag. RFC 3.1.1.1.1.6. + dictionaryIdSize := 0 + if dictIdFlag := descriptor & 3; dictIdFlag != 0 { + dictionaryIdSize = 1 << (dictIdFlag - 1) + } + + relativeOffset++ + + headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize + + if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + // Figure out the maximum amount of data we need to retain + // for backreferences. + var windowSize int + if !singleSegment { + // Window descriptor. RFC 3.1.1.1.2. + windowDescriptor := r.scratch[0] + exponent := uint64(windowDescriptor >> 3) + mantissa := uint64(windowDescriptor & 7) + windowLog := exponent + 10 + windowBase := uint64(1) << windowLog + windowAdd := (windowBase / 8) * mantissa + windowSize = int(windowBase + windowAdd) + + // Default zstd sets limits on the window size. + if fuzzing && (windowLog > 31 || windowSize > 1<<27) { + return r.makeError(relativeOffset, "windowSize too large") + } + } + + // Dictionary_ID. RFC 3.1.1.1.3. + if dictionaryIdSize != 0 { + dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize] + // Allow only zero Dictionary ID. + for _, b := range dictionaryId { + if b != 0 { + return r.makeError(relativeOffset, "dictionaries are not supported") + } + } + } + + // Frame_Content_Size. RFC 3.1.1.1.4. + r.frameSizeUnknown = false + r.remainingFrameSize = 0 + fb := r.scratch[windowDescriptorSize+dictionaryIdSize:] + switch fcsFieldSize { + case 0: + r.frameSizeUnknown = true + case 1: + r.remainingFrameSize = uint64(fb[0]) + case 2: + r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) + case 4: + r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) + case 8: + r.remainingFrameSize = binary.LittleEndian.Uint64(fb) + default: + panic("unreachable") + } + + // RFC 3.1.1.1.2. + // When Single_Segment_Flag is set, Window_Descriptor is not present. + // In this case, Window_Size is Frame_Content_Size. + if singleSegment { + windowSize = int(r.remainingFrameSize) + } + + // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size. + if windowSize > 8<<20 { + windowSize = 8 << 20 + } + + relativeOffset += headerSize + + r.sawFrameHeader = true + r.readOneFrame = true + r.blockOffset += int64(relativeOffset) + + // Prepare to read blocks from the frame. + r.repeatedOffset1 = 1 + r.repeatedOffset2 = 4 + r.repeatedOffset3 = 8 + r.huffmanTableBits = 0 + r.window.reset(windowSize) + r.seqTables[0] = nil + r.seqTables[1] = nil + r.seqTables[2] = nil + + return nil +} + +// skipFrame skips a skippable frame. RFC 3.1.2. +func (r *Reader) skipFrame() error { + relativeOffset := 0 + + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + relativeOffset += 4 + + size := binary.LittleEndian.Uint32(r.scratch[:4]) + if size == 0 { + r.blockOffset += int64(relativeOffset) + return nil + } + + if seeker, ok := r.r.(io.Seeker); ok { + r.blockOffset += int64(relativeOffset) + // Implementations of Seeker do not always detect invalid offsets, + // so check that the new offset is valid by comparing to the end. + prev, err := seeker.Seek(0, io.SeekCurrent) + if err != nil { + return r.wrapError(0, err) + } + end, err := seeker.Seek(0, io.SeekEnd) + if err != nil { + return r.wrapError(0, err) + } + if prev > end-int64(size) { + r.blockOffset += end - prev + return r.makeEOFError(0) + } + + // The new offset is valid, so seek to it. + _, err = seeker.Seek(prev+int64(size), io.SeekStart) + if err != nil { + return r.wrapError(0, err) + } + r.blockOffset += int64(size) + return nil + } + + var skip []byte + const chunk = 1 << 20 // 1M + for size >= chunk { + if len(skip) == 0 { + skip = make([]byte, chunk) + } + if _, err := io.ReadFull(r.r, skip); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += chunk + size -= chunk + } + if size > 0 { + if len(skip) == 0 { + skip = make([]byte, size) + } + if _, err := io.ReadFull(r.r, skip); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += int(size) + } + + r.blockOffset += int64(relativeOffset) + + return nil +} + +// readBlock reads the next block from a frame. +func (r *Reader) readBlock() error { + relativeOffset := 0 + + // Read Block_Header. RFC 3.1.1.2. + if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + relativeOffset += 3 + + header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) + + lastBlock := header&1 != 0 + blockType := (header >> 1) & 3 + blockSize := int(header >> 3) + + // Maximum block size is smaller of window size and 128K. + // We don't record the window size for a single segment frame, + // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. + if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) { + return r.makeError(relativeOffset, "block size too large") + } + + // Handle different block types. RFC 3.1.1.2.2. + switch blockType { + case 0: + r.setBufferSize(blockSize) + if _, err := io.ReadFull(r.r, r.buffer); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += blockSize + r.blockOffset += int64(relativeOffset) + case 1: + r.setBufferSize(blockSize) + if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset++ + v := r.scratch[0] + for i := range r.buffer { + r.buffer[i] = v + } + r.blockOffset += int64(relativeOffset) + case 2: + r.blockOffset += int64(relativeOffset) + if err := r.compressedBlock(blockSize); err != nil { + return err + } + r.blockOffset += int64(blockSize) + case 3: + return r.makeError(relativeOffset, "invalid block type") + } + + if !r.frameSizeUnknown { + if uint64(len(r.buffer)) > r.remainingFrameSize { + return r.makeError(relativeOffset, "too many uncompressed bytes in frame") + } + r.remainingFrameSize -= uint64(len(r.buffer)) + } + + if r.hasChecksum { + r.checksum.update(r.buffer) + } + + if !lastBlock { + r.window.save(r.buffer) + } else { + if !r.frameSizeUnknown && r.remainingFrameSize != 0 { + return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") + } + // Check for checksum at end of frame. RFC 3.1.1. + if r.hasChecksum { + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + return r.wrapNonEOFError(0, err) + } + + inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) + dataChecksum := uint32(r.checksum.digest()) + if inputChecksum != dataChecksum { + return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) + } + + r.blockOffset += 4 + } + r.sawFrameHeader = false + } + + return nil +} + +// setBufferSize sets the decompressed buffer size. +// When this is called the buffer is empty. +func (r *Reader) setBufferSize(size int) { + if cap(r.buffer) < size { + need := size - cap(r.buffer) + r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) + } + r.buffer = r.buffer[:size] +} + +// zstdError is an error while decompressing. +type zstdError struct { + offset int64 + err error +} + +func (ze *zstdError) Error() string { + return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) +} + +func (ze *zstdError) Unwrap() error { + return ze.err +} + +func (r *Reader) makeEOFError(off int) error { + return r.wrapError(off, io.ErrUnexpectedEOF) +} + +func (r *Reader) wrapNonEOFError(off int, err error) error { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return r.wrapError(off, err) +} + +func (r *Reader) makeError(off int, msg string) error { + return r.wrapError(off, errors.New(msg)) +} + +func (r *Reader) wrapError(off int, err error) error { + if err == io.EOF { + return err + } + return &zstdError{r.blockOffset + int64(off), err} +} diff --git a/src/internal/zstd/zstd_test.go b/src/internal/zstd/zstd_test.go new file mode 100644 index 0000000..f2a2e1b --- /dev/null +++ b/src/internal/zstd/zstd_test.go @@ -0,0 +1,335 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "bytes" + "crypto/sha256" + "fmt" + "internal/race" + "internal/testenv" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" +) + +// tests holds some simple test cases, including some found by fuzzing. +var tests = []struct { + name, uncompressed, compressed string +}{ + { + "hello", + "hello, world\n", + "\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1", + }, + { + // a small compressed .debug_ranges section. + "ranges", + "\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" + + "\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" + + "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" + + "\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" + + "\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" + + "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" + + "\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" + + "\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" + + "\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" + + "\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" + + "\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" + + "\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" + + "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" + + "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" + + "\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" + + "\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" + + "\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" + + "\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" + + "\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" + + "\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" + + "\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" + + "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" + + "\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + + "\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" + + "\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" + + "\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" + + "\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" + + "\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" + + "\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" + + "\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" + + "\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" + + "\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" + + "\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" + + "\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" + + "\xcb\xe0\xce", + }, + { + "fuzz1", + "0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000", + "(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G", + }, + { + "empty block", + "", + "\x28\xb5\x2f\xfd\x00\x00\x15\x00\x00\x00\x00", + }, + { + "single skippable frame", + "", + "\x50\x2a\x4d\x18\x00\x00\x00\x00", + }, + { + "two skippable frames", + "", + "\x50\x2a\x4d\x18\x00\x00\x00\x00" + + "\x50\x2a\x4d\x18\x00\x00\x00\x00", + }, +} + +func TestSamples(t *testing.T) { + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + r := NewReader(strings.NewReader(test.compressed)) + got, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + gotstr := string(got) + if gotstr != test.uncompressed { + t.Errorf("got %q want %q", gotstr, test.uncompressed) + } + }) + } +} + +func TestReset(t *testing.T) { + input := strings.NewReader("") + r := NewReader(input) + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + input.Reset(test.compressed) + r.Reset(input) + got, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + gotstr := string(got) + if gotstr != test.uncompressed { + t.Errorf("got %q want %q", gotstr, test.uncompressed) + } + }) + } +} + +var ( + bigDataOnce sync.Once + bigDataBytes []byte + bigDataErr error +) + +// bigData returns the contents of our large test file repeated multiple times. +func bigData(t testing.TB) []byte { + bigDataOnce.Do(func() { + bigDataBytes, bigDataErr = os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt") + if bigDataErr == nil { + bigDataBytes = bytes.Repeat(bigDataBytes, 20) + } + }) + if bigDataErr != nil { + t.Fatal(bigDataErr) + } + return bigDataBytes +} + +func findZstd(t testing.TB) string { + zstd, err := exec.LookPath("zstd") + if err != nil { + t.Skip("skipping because zstd not found") + } + return zstd +} + +var ( + zstdBigOnce sync.Once + zstdBigBytes []byte + zstdBigErr error +) + +// zstdBigData returns the compressed contents of our large test file. +// This will only run on Unix systems with zstd installed. +// That's OK as the package is GOOS-independent. +func zstdBigData(t testing.TB) []byte { + input := bigData(t) + + zstd := findZstd(t) + + zstdBigOnce.Do(func() { + cmd := exec.Command(zstd, "-z") + cmd.Stdin = bytes.NewReader(input) + var compressed bytes.Buffer + cmd.Stdout = &compressed + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + zstdBigErr = fmt.Errorf("running zstd failed: %v", err) + return + } + + zstdBigBytes = compressed.Bytes() + }) + if zstdBigErr != nil { + t.Fatal(zstdBigErr) + } + return zstdBigBytes +} + +// Test decompressing a large file. We don't have a compressor, +// so this test only runs on systems with zstd installed. +func TestLarge(t *testing.T) { + if testing.Short() { + t.Skip("skipping expensive test in short mode") + } + + data := bigData(t) + compressed := zstdBigData(t) + + t.Logf("zstd compressed %d bytes to %d", len(data), len(compressed)) + + r := NewReader(bytes.NewReader(compressed)) + got, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, data) { + showDiffs(t, got, data) + } +} + +// showDiffs reports the first few differences in two []byte. +func showDiffs(t *testing.T, got, want []byte) { + t.Error("data mismatch") + if len(got) != len(want) { + t.Errorf("got data length %d, want %d", len(got), len(want)) + } + diffs := 0 + for i, b := range got { + if i >= len(want) { + break + } + if b != want[i] { + diffs++ + if diffs > 20 { + break + } + t.Logf("%d: %#x != %#x", i, b, want[i]) + } + } +} + +func TestAlloc(t *testing.T) { + testenv.SkipIfOptimizationOff(t) + if race.Enabled { + t.Skip("skipping allocation test under race detector") + } + + compressed := zstdBigData(t) + input := bytes.NewReader(compressed) + r := NewReader(input) + c := testing.AllocsPerRun(10, func() { + input.Reset(compressed) + r.Reset(input) + io.Copy(io.Discard, r) + }) + if c != 0 { + t.Errorf("got %v allocs, want 0", c) + } +} + +func TestFileSamples(t *testing.T) { + samples, err := os.ReadDir("testdata") + if err != nil { + t.Fatal(err) + } + + for _, sample := range samples { + name := sample.Name() + if !strings.HasSuffix(name, ".zst") { + continue + } + + t.Run(name, func(t *testing.T) { + f, err := os.Open(filepath.Join("testdata", name)) + if err != nil { + t.Fatal(err) + } + + r := NewReader(f) + h := sha256.New() + if _, err := io.Copy(h, r); err != nil { + t.Fatal(err) + } + got := fmt.Sprintf("%x", h.Sum(nil))[:8] + + want, _, _ := strings.Cut(name, ".") + if got != want { + t.Errorf("Wrong uncompressed content hash: got %s, want %s", got, want) + } + }) + } +} + +func TestReaderBad(t *testing.T) { + for i, s := range badStrings { + t.Run(fmt.Sprintf("badStrings#%d", i), func(t *testing.T) { + _, err := io.Copy(io.Discard, NewReader(strings.NewReader(s))) + if err == nil { + t.Error("expected error") + } + }) + } +} + +func BenchmarkLarge(b *testing.B) { + b.StopTimer() + b.ReportAllocs() + + compressed := zstdBigData(b) + + b.SetBytes(int64(len(compressed))) + + input := bytes.NewReader(compressed) + r := NewReader(input) + + b.StartTimer() + for i := 0; i < b.N; i++ { + input.Reset(compressed) + r.Reset(input) + io.Copy(io.Discard, r) + } +} |