diff options
Diffstat (limited to 'src/internal/zstd/zstd.go')
-rw-r--r-- | src/internal/zstd/zstd.go | 508 |
1 files changed, 508 insertions, 0 deletions
diff --git a/src/internal/zstd/zstd.go b/src/internal/zstd/zstd.go new file mode 100644 index 0000000..a860789 --- /dev/null +++ b/src/internal/zstd/zstd.go @@ -0,0 +1,508 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package zstd provides a decompressor for zstd streams, +// described in RFC 8878. It does not support dictionaries. +package zstd + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// fuzzing is a fuzzer hook set to true when fuzzing. +// This is used to reject cases where we don't match zstd. +var fuzzing = false + +// Reader implements [io.Reader] to read a zstd compressed stream. +type Reader struct { + // The underlying Reader. + r io.Reader + + // Whether we have read the frame header. + // This is of interest when buffer is empty. + // If true we expect to see a new block. + sawFrameHeader bool + + // Whether the current frame expects a checksum. + hasChecksum bool + + // Whether we have read at least one frame. + readOneFrame bool + + // True if the frame size is not known. + frameSizeUnknown bool + + // The number of uncompressed bytes remaining in the current frame. + // If frameSizeUnknown is true, this is not valid. + remainingFrameSize uint64 + + // The number of bytes read from r up to the start of the current + // block, for error reporting. + blockOffset int64 + + // Buffered decompressed data. + buffer []byte + // Current read offset in buffer. + off int + + // The current repeated offsets. + repeatedOffset1 uint32 + repeatedOffset2 uint32 + repeatedOffset3 uint32 + + // The current Huffman tree used for compressing literals. + huffmanTable []uint16 + huffmanTableBits int + + // The window for back references. + windowSize int // maximum required window size + window []byte // window data + + // A buffer available to hold a compressed block. + compressedBuf []byte + + // A buffer for literals. + literals []byte + + // Sequence decode FSE tables. + seqTables [3][]fseBaselineEntry + seqTableBits [3]uint8 + + // Buffers for sequence decode FSE tables. + seqTableBuffers [3][]fseBaselineEntry + + // Scratch space used for small reads, to avoid allocation. + scratch [16]byte + + // A scratch table for reading an FSE. Only temporarily valid. + fseScratch []fseEntry + + // For checksum computation. + checksum xxhash64 +} + +// NewReader creates a new Reader that decompresses data from the given reader. +func NewReader(input io.Reader) *Reader { + r := new(Reader) + r.Reset(input) + return r +} + +// Reset discards the current state and starts reading a new stream from r. +// This permits reusing a Reader rather than allocating a new one. +func (r *Reader) Reset(input io.Reader) { + r.r = input + + // Several fields are preserved to avoid allocation. + // Others are always set before they are used. + r.sawFrameHeader = false + r.hasChecksum = false + r.readOneFrame = false + r.frameSizeUnknown = false + r.remainingFrameSize = 0 + r.blockOffset = 0 + // buffer + r.off = 0 + // repeatedOffset1 + // repeatedOffset2 + // repeatedOffset3 + // huffmanTable + // huffmanTableBits + // windowSize + // window + // compressedBuf + // literals + // seqTables + // seqTableBits + // seqTableBuffers + // scratch + // fseScratch +} + +// Read implements [io.Reader]. +func (r *Reader) Read(p []byte) (int, error) { + if err := r.refillIfNeeded(); err != nil { + return 0, err + } + n := copy(p, r.buffer[r.off:]) + r.off += n + return n, nil +} + +// ReadByte implements [io.ByteReader]. +func (r *Reader) ReadByte() (byte, error) { + if err := r.refillIfNeeded(); err != nil { + return 0, err + } + ret := r.buffer[r.off] + r.off++ + return ret, nil +} + +// refillIfNeeded reads the next block if necessary. +func (r *Reader) refillIfNeeded() error { + for r.off >= len(r.buffer) { + if err := r.refill(); err != nil { + return err + } + r.off = 0 + } + return nil +} + +// refill reads and decompresses the next block. +func (r *Reader) refill() error { + if !r.sawFrameHeader { + if err := r.readFrameHeader(); err != nil { + return err + } + } + return r.readBlock() +} + +// readFrameHeader reads the frame header and prepares to read a block. +func (r *Reader) readFrameHeader() error { +retry: + relativeOffset := 0 + + // Read magic number. RFC 3.1.1. + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + // We require that the stream contain at least one frame. + if err == io.EOF && !r.readOneFrame { + err = io.ErrUnexpectedEOF + } + return r.wrapError(relativeOffset, err) + } + + if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { + if magic >= 0x184d2a50 && magic <= 0x184d2a5f { + // This is a skippable frame. + r.blockOffset += int64(relativeOffset) + 4 + if err := r.skipFrame(); err != nil { + return err + } + goto retry + } + + return r.makeError(relativeOffset, "invalid magic number") + } + + relativeOffset += 4 + + // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. + if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + descriptor := r.scratch[0] + + singleSegment := descriptor&(1<<5) != 0 + + fcsFieldSize := 1 << (descriptor >> 6) + if fcsFieldSize == 1 && !singleSegment { + fcsFieldSize = 0 + } + + var windowDescriptorSize int + if singleSegment { + windowDescriptorSize = 0 + } else { + windowDescriptorSize = 1 + } + + if descriptor&(1<<3) != 0 { + return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") + } + + r.hasChecksum = descriptor&(1<<2) != 0 + if r.hasChecksum { + r.checksum.reset() + } + + if descriptor&3 != 0 { + return r.makeError(relativeOffset, "dictionaries are not supported") + } + + relativeOffset++ + + headerSize := windowDescriptorSize + fcsFieldSize + + if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + // Figure out the maximum amount of data we need to retain + // for backreferences. + + if singleSegment { + // No window required, as all the data is in a single buffer. + r.windowSize = 0 + } else { + // Window descriptor. RFC 3.1.1.1.2. + windowDescriptor := r.scratch[0] + exponent := uint64(windowDescriptor >> 3) + mantissa := uint64(windowDescriptor & 7) + windowLog := exponent + 10 + windowBase := uint64(1) << windowLog + windowAdd := (windowBase / 8) * mantissa + windowSize := windowBase + windowAdd + + // Default zstd sets limits on the window size. + if fuzzing && (windowLog > 31 || windowSize > 1<<27) { + return r.makeError(relativeOffset, "windowSize too large") + } + + // RFC 8878 permits us to set an 8M max on window size. + if windowSize > 8<<20 { + windowSize = 8 << 20 + } + + r.windowSize = int(windowSize) + } + + // Frame_Content_Size. RFC 3.1.1.4. + r.frameSizeUnknown = false + r.remainingFrameSize = 0 + fb := r.scratch[windowDescriptorSize:] + switch fcsFieldSize { + case 0: + r.frameSizeUnknown = true + case 1: + r.remainingFrameSize = uint64(fb[0]) + case 2: + r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) + case 4: + r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) + case 8: + r.remainingFrameSize = binary.LittleEndian.Uint64(fb) + default: + panic("unreachable") + } + + relativeOffset += headerSize + + r.sawFrameHeader = true + r.readOneFrame = true + r.blockOffset += int64(relativeOffset) + + // Prepare to read blocks from the frame. + r.repeatedOffset1 = 1 + r.repeatedOffset2 = 4 + r.repeatedOffset3 = 8 + r.huffmanTableBits = 0 + r.window = r.window[:0] + r.seqTables[0] = nil + r.seqTables[1] = nil + r.seqTables[2] = nil + + return nil +} + +// skipFrame skips a skippable frame. RFC 3.1.2. +func (r *Reader) skipFrame() error { + relativeOffset := 0 + + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + relativeOffset += 4 + + size := binary.LittleEndian.Uint32(r.scratch[:4]) + + if seeker, ok := r.r.(io.Seeker); ok { + if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil { + return err + } + r.blockOffset += int64(relativeOffset) + int64(size) + return nil + } + + var skip []byte + const chunk = 1 << 20 // 1M + for size >= chunk { + if len(skip) == 0 { + skip = make([]byte, chunk) + } + if _, err := io.ReadFull(r.r, skip); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += chunk + size -= chunk + } + if size > 0 { + if len(skip) == 0 { + skip = make([]byte, size) + } + if _, err := io.ReadFull(r.r, skip); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += int(size) + } + + r.blockOffset += int64(relativeOffset) + + return nil +} + +// readBlock reads the next block from a frame. +func (r *Reader) readBlock() error { + relativeOffset := 0 + + // Read Block_Header. RFC 3.1.1.2. + if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + + relativeOffset += 3 + + header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) + + lastBlock := header&1 != 0 + blockType := (header >> 1) & 3 + blockSize := int(header >> 3) + + // Maximum block size is smaller of window size and 128K. + // We don't record the window size for a single segment frame, + // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. + if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) { + return r.makeError(relativeOffset, "block size too large") + } + + // Handle different block types. RFC 3.1.1.2.2. + switch blockType { + case 0: + r.setBufferSize(blockSize) + if _, err := io.ReadFull(r.r, r.buffer); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset += blockSize + r.blockOffset += int64(relativeOffset) + case 1: + r.setBufferSize(blockSize) + if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { + return r.wrapNonEOFError(relativeOffset, err) + } + relativeOffset++ + v := r.scratch[0] + for i := range r.buffer { + r.buffer[i] = v + } + r.blockOffset += int64(relativeOffset) + case 2: + r.blockOffset += int64(relativeOffset) + if err := r.compressedBlock(blockSize); err != nil { + return err + } + r.blockOffset += int64(blockSize) + case 3: + return r.makeError(relativeOffset, "invalid block type") + } + + if !r.frameSizeUnknown { + if uint64(len(r.buffer)) > r.remainingFrameSize { + return r.makeError(relativeOffset, "too many uncompressed bytes in frame") + } + r.remainingFrameSize -= uint64(len(r.buffer)) + } + + if r.hasChecksum { + r.checksum.update(r.buffer) + } + + if !lastBlock { + r.saveWindow(r.buffer) + } else { + if !r.frameSizeUnknown && r.remainingFrameSize != 0 { + return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") + } + // Check for checksum at end of frame. RFC 3.1.1. + if r.hasChecksum { + if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { + return r.wrapNonEOFError(0, err) + } + + inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) + dataChecksum := uint32(r.checksum.digest()) + if inputChecksum != dataChecksum { + return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) + } + + r.blockOffset += 4 + } + r.sawFrameHeader = false + } + + return nil +} + +// setBufferSize sets the decompressed buffer size. +// When this is called the buffer is empty. +func (r *Reader) setBufferSize(size int) { + if cap(r.buffer) < size { + need := size - cap(r.buffer) + r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) + } + r.buffer = r.buffer[:size] +} + +// saveWindow saves bytes in the backreference window. +// TODO: use a circular buffer for less data movement. +func (r *Reader) saveWindow(buf []byte) { + if r.windowSize == 0 { + return + } + + if len(buf) >= r.windowSize { + from := len(buf) - r.windowSize + r.window = append(r.window[:0], buf[from:]...) + return + } + + keep := r.windowSize - len(buf) // must be positive + if keep < len(r.window) { + remove := len(r.window) - keep + copy(r.window[:], r.window[remove:]) + } + + r.window = append(r.window, buf...) +} + +// zstdError is an error while decompressing. +type zstdError struct { + offset int64 + err error +} + +func (ze *zstdError) Error() string { + return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) +} + +func (ze *zstdError) Unwrap() error { + return ze.err +} + +func (r *Reader) makeEOFError(off int) error { + return r.wrapError(off, io.ErrUnexpectedEOF) +} + +func (r *Reader) wrapNonEOFError(off int, err error) error { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return r.wrapError(off, err) +} + +func (r *Reader) makeError(off int, msg string) error { + return r.wrapError(off, errors.New(msg)) +} + +func (r *Reader) wrapError(off int, err error) error { + if err == io.EOF { + return err + } + return &zstdError{r.blockOffset + int64(off), err} +} |