diff options
Diffstat (limited to 'src/internal/zstd/block.go')
-rw-r--r-- | src/internal/zstd/block.go | 425 |
1 files changed, 425 insertions, 0 deletions
diff --git a/src/internal/zstd/block.go b/src/internal/zstd/block.go new file mode 100644 index 0000000..11a99cd --- /dev/null +++ b/src/internal/zstd/block.go @@ -0,0 +1,425 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zstd + +import ( + "io" +) + +// debug can be set in the source to print debug info using println. +const debug = false + +// compressedBlock decompresses a compressed block, storing the decompressed +// data in r.buffer. The blockSize argument is the compressed size. +// RFC 3.1.1.3. +func (r *Reader) compressedBlock(blockSize int) error { + if len(r.compressedBuf) >= blockSize { + r.compressedBuf = r.compressedBuf[:blockSize] + } else { + // We know that blockSize <= 128K, + // so this won't allocate an enormous amount. + need := blockSize - len(r.compressedBuf) + r.compressedBuf = append(r.compressedBuf, make([]byte, need)...) + } + + if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil { + return r.wrapNonEOFError(0, err) + } + + data := block(r.compressedBuf) + off := 0 + r.buffer = r.buffer[:0] + + litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0]) + if err != nil { + return err + } + r.literals = litbuf + + off = litoff + + seqCount, off, err := r.initSeqs(data, off) + if err != nil { + return err + } + + if seqCount == 0 { + // No sequences, just literals. + if off < len(data) { + return r.makeError(off, "extraneous data after no sequences") + } + + r.buffer = append(r.buffer, litbuf...) + + return nil + } + + return r.execSeqs(data, off, litbuf, seqCount) +} + +// seqCode is the kind of sequence codes we have to handle. +type seqCode int + +const ( + seqLiteral seqCode = iota + seqOffset + seqMatch +) + +// seqCodeInfoData is the information needed to set up seqTables and +// seqTableBits for a particular kind of sequence code. +type seqCodeInfoData struct { + predefTable []fseBaselineEntry // predefined FSE + predefTableBits int // number of bits in predefTable + maxSym int // max symbol value in FSE + maxBits int // max bits for FSE + + // toBaseline converts from an FSE table to an FSE baseline table. + toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error +} + +// seqCodeInfo is the seqCodeInfoData for each kind of sequence code. +var seqCodeInfo = [3]seqCodeInfoData{ + seqLiteral: { + predefTable: predefinedLiteralTable[:], + predefTableBits: 6, + maxSym: 35, + maxBits: 9, + toBaseline: (*Reader).makeLiteralBaselineFSE, + }, + seqOffset: { + predefTable: predefinedOffsetTable[:], + predefTableBits: 5, + maxSym: 31, + maxBits: 8, + toBaseline: (*Reader).makeOffsetBaselineFSE, + }, + seqMatch: { + predefTable: predefinedMatchTable[:], + predefTableBits: 6, + maxSym: 52, + maxBits: 9, + toBaseline: (*Reader).makeMatchBaselineFSE, + }, +} + +// initSeqs reads the Sequences_Section_Header and sets up the FSE +// tables used to read the sequence codes. It returns the number of +// sequences and the new offset. RFC 3.1.1.3.2.1. +func (r *Reader) initSeqs(data block, off int) (int, int, error) { + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + + seqHdr := data[off] + off++ + if seqHdr == 0 { + return 0, off, nil + } + + var seqCount int + if seqHdr < 128 { + seqCount = int(seqHdr) + } else if seqHdr < 255 { + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + seqCount = ((int(seqHdr) - 128) << 8) + int(data[off]) + off++ + } else { + if off+1 >= len(data) { + return 0, 0, r.makeEOFError(off) + } + seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00 + off += 2 + } + + // Read the Symbol_Compression_Modes byte. + + if off >= len(data) { + return 0, 0, r.makeEOFError(off) + } + symMode := data[off] + if symMode&3 != 0 { + return 0, 0, r.makeError(off, "invalid symbol compression mode") + } + off++ + + // Set up the FSE tables used to decode the sequence codes. + + var err error + off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3) + if err != nil { + return 0, 0, err + } + + off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3) + if err != nil { + return 0, 0, err + } + + off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3) + if err != nil { + return 0, 0, err + } + + return seqCount, off, nil +} + +// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and +// r.seqTableBits for kind. We store these in the Reader because one of +// the modes simply reuses the value from the last block in the frame. +func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) { + info := &seqCodeInfo[kind] + switch mode { + case 0: + // Predefined_Mode + r.seqTables[kind] = info.predefTable + r.seqTableBits[kind] = uint8(info.predefTableBits) + return off, nil + + case 1: + // RLE_Mode + if off >= len(data) { + return 0, r.makeEOFError(off) + } + rle := data[off] + off++ + + // Build a simple baseline table that always returns rle. + + entry := []fseEntry{ + { + sym: rle, + bits: 0, + base: 0, + }, + } + if cap(r.seqTableBuffers[kind]) == 0 { + r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) + } + r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1] + if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil { + return 0, err + } + + r.seqTables[kind] = r.seqTableBuffers[kind] + r.seqTableBits[kind] = 0 + return off, nil + + case 2: + // FSE_Compressed_Mode + if cap(r.fseScratch) < 1<<info.maxBits { + r.fseScratch = make([]fseEntry, 1<<info.maxBits) + } + r.fseScratch = r.fseScratch[:1<<info.maxBits] + + tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch) + if err != nil { + return 0, err + } + r.fseScratch = r.fseScratch[:1<<tableBits] + + if cap(r.seqTableBuffers[kind]) == 0 { + r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits) + } + r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits] + + if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil { + return 0, err + } + + r.seqTables[kind] = r.seqTableBuffers[kind] + r.seqTableBits[kind] = uint8(tableBits) + return roff, nil + + case 3: + // Repeat_Mode + if len(r.seqTables[kind]) == 0 { + return 0, r.makeError(off, "missing repeat sequence FSE table") + } + return off, nil + } + panic("unreachable") +} + +// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2. +func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error { + // Set up the initial states for the sequence code readers. + + rbr, err := r.makeReverseBitReader(data, len(data)-1, off) + if err != nil { + return err + } + + literalState, err := rbr.val(r.seqTableBits[seqLiteral]) + if err != nil { + return err + } + + offsetState, err := rbr.val(r.seqTableBits[seqOffset]) + if err != nil { + return err + } + + matchState, err := rbr.val(r.seqTableBits[seqMatch]) + if err != nil { + return err + } + + // Read and perform all the sequences. RFC 3.1.1.4. + + seq := 0 + for seq < seqCount { + if len(r.buffer)+len(litbuf) > 128<<10 { + return rbr.makeError("uncompressed size too big") + } + + ptoffset := &r.seqTables[seqOffset][offsetState] + ptmatch := &r.seqTables[seqMatch][matchState] + ptliteral := &r.seqTables[seqLiteral][literalState] + + add, err := rbr.val(ptoffset.basebits) + if err != nil { + return err + } + offset := ptoffset.baseline + add + + add, err = rbr.val(ptmatch.basebits) + if err != nil { + return err + } + match := ptmatch.baseline + add + + add, err = rbr.val(ptliteral.basebits) + if err != nil { + return err + } + literal := ptliteral.baseline + add + + // Handle repeat offsets. RFC 3.1.1.5. + // See the comment in makeOffsetBaselineFSE. + if ptoffset.basebits > 1 { + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + } else { + if literal == 0 { + offset++ + } + switch offset { + case 1: + offset = r.repeatedOffset1 + case 2: + offset = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + case 3: + offset = r.repeatedOffset3 + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + case 4: + offset = r.repeatedOffset1 - 1 + r.repeatedOffset3 = r.repeatedOffset2 + r.repeatedOffset2 = r.repeatedOffset1 + r.repeatedOffset1 = offset + } + } + + seq++ + if seq < seqCount { + // Update the states. + add, err = rbr.val(ptliteral.bits) + if err != nil { + return err + } + literalState = uint32(ptliteral.base) + add + + add, err = rbr.val(ptmatch.bits) + if err != nil { + return err + } + matchState = uint32(ptmatch.base) + add + + add, err = rbr.val(ptoffset.bits) + if err != nil { + return err + } + offsetState = uint32(ptoffset.base) + add + } + + // The next sequence is now in literal, offset, match. + + if debug { + println("literal", literal, "offset", offset, "match", match) + } + + // Copy literal bytes from litbuf. + if literal > uint32(len(litbuf)) { + return rbr.makeError("literal byte overflow") + } + if literal > 0 { + r.buffer = append(r.buffer, litbuf[:literal]...) + litbuf = litbuf[literal:] + } + + if match > 0 { + if err := r.copyFromWindow(&rbr, offset, match); err != nil { + return err + } + } + } + + r.buffer = append(r.buffer, litbuf...) + + if rbr.cnt != 0 { + return r.makeError(off, "extraneous data after sequences") + } + + return nil +} + +// Copy match bytes from the decoded output, or the window, at offset. +func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error { + if offset == 0 { + return rbr.makeError("invalid zero offset") + } + + // Offset may point into the buffer or the window and + // match may extend past the end of the initial buffer. + // |--r.window--|--r.buffer--| + // |<-----offset------| + // |------match----------->| + bufferOffset := uint32(0) + lenBlock := uint32(len(r.buffer)) + if lenBlock < offset { + lenWindow := r.window.len() + copy := offset - lenBlock + if copy > lenWindow { + return rbr.makeError("offset past window") + } + windowOffset := lenWindow - copy + if copy > match { + copy = match + } + r.buffer = r.window.appendTo(r.buffer, windowOffset, windowOffset+copy) + match -= copy + } else { + bufferOffset = lenBlock - offset + } + + // We are being asked to copy data that we are adding to the + // buffer in the same copy. + for match > 0 { + copy := uint32(len(r.buffer)) - bufferOffset + if copy > match { + copy = match + } + r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+copy]...) + match -= copy + } + return nil +} |