// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package zstd provides a decompressor for zstd streams, // described in RFC 8878. It does not support dictionaries. package zstd import ( "encoding/binary" "errors" "fmt" "io" ) // fuzzing is a fuzzer hook set to true when fuzzing. // This is used to reject cases where we don't match zstd. var fuzzing = false // Reader implements [io.Reader] to read a zstd compressed stream. type Reader struct { // The underlying Reader. r io.Reader // Whether we have read the frame header. // This is of interest when buffer is empty. // If true we expect to see a new block. sawFrameHeader bool // Whether the current frame expects a checksum. hasChecksum bool // Whether we have read at least one frame. readOneFrame bool // True if the frame size is not known. frameSizeUnknown bool // The number of uncompressed bytes remaining in the current frame. // If frameSizeUnknown is true, this is not valid. remainingFrameSize uint64 // The number of bytes read from r up to the start of the current // block, for error reporting. blockOffset int64 // Buffered decompressed data. buffer []byte // Current read offset in buffer. off int // The current repeated offsets. repeatedOffset1 uint32 repeatedOffset2 uint32 repeatedOffset3 uint32 // The current Huffman tree used for compressing literals. huffmanTable []uint16 huffmanTableBits int // The window for back references. windowSize int // maximum required window size window []byte // window data // A buffer available to hold a compressed block. compressedBuf []byte // A buffer for literals. literals []byte // Sequence decode FSE tables. seqTables [3][]fseBaselineEntry seqTableBits [3]uint8 // Buffers for sequence decode FSE tables. seqTableBuffers [3][]fseBaselineEntry // Scratch space used for small reads, to avoid allocation. scratch [16]byte // A scratch table for reading an FSE. Only temporarily valid. fseScratch []fseEntry // For checksum computation. checksum xxhash64 } // NewReader creates a new Reader that decompresses data from the given reader. func NewReader(input io.Reader) *Reader { r := new(Reader) r.Reset(input) return r } // Reset discards the current state and starts reading a new stream from r. // This permits reusing a Reader rather than allocating a new one. func (r *Reader) Reset(input io.Reader) { r.r = input // Several fields are preserved to avoid allocation. // Others are always set before they are used. r.sawFrameHeader = false r.hasChecksum = false r.readOneFrame = false r.frameSizeUnknown = false r.remainingFrameSize = 0 r.blockOffset = 0 // buffer r.off = 0 // repeatedOffset1 // repeatedOffset2 // repeatedOffset3 // huffmanTable // huffmanTableBits // windowSize // window // compressedBuf // literals // seqTables // seqTableBits // seqTableBuffers // scratch // fseScratch } // Read implements [io.Reader]. func (r *Reader) Read(p []byte) (int, error) { if err := r.refillIfNeeded(); err != nil { return 0, err } n := copy(p, r.buffer[r.off:]) r.off += n return n, nil } // ReadByte implements [io.ByteReader]. func (r *Reader) ReadByte() (byte, error) { if err := r.refillIfNeeded(); err != nil { return 0, err } ret := r.buffer[r.off] r.off++ return ret, nil } // refillIfNeeded reads the next block if necessary. func (r *Reader) refillIfNeeded() error { for r.off >= len(r.buffer) { if err := r.refill(); err != nil { return err } r.off = 0 } return nil } // refill reads and decompresses the next block. func (r *Reader) refill() error { if !r.sawFrameHeader { if err := r.readFrameHeader(); err != nil { return err } } return r.readBlock() } // readFrameHeader reads the frame header and prepares to read a block. func (r *Reader) readFrameHeader() error { retry: relativeOffset := 0 // Read magic number. RFC 3.1.1. if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { // We require that the stream contain at least one frame. if err == io.EOF && !r.readOneFrame { err = io.ErrUnexpectedEOF } return r.wrapError(relativeOffset, err) } if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { if magic >= 0x184d2a50 && magic <= 0x184d2a5f { // This is a skippable frame. r.blockOffset += int64(relativeOffset) + 4 if err := r.skipFrame(); err != nil { return err } goto retry } return r.makeError(relativeOffset, "invalid magic number") } relativeOffset += 4 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { return r.wrapNonEOFError(relativeOffset, err) } descriptor := r.scratch[0] singleSegment := descriptor&(1<<5) != 0 fcsFieldSize := 1 << (descriptor >> 6) if fcsFieldSize == 1 && !singleSegment { fcsFieldSize = 0 } var windowDescriptorSize int if singleSegment { windowDescriptorSize = 0 } else { windowDescriptorSize = 1 } if descriptor&(1<<3) != 0 { return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") } r.hasChecksum = descriptor&(1<<2) != 0 if r.hasChecksum { r.checksum.reset() } if descriptor&3 != 0 { return r.makeError(relativeOffset, "dictionaries are not supported") } relativeOffset++ headerSize := windowDescriptorSize + fcsFieldSize if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { return r.wrapNonEOFError(relativeOffset, err) } // Figure out the maximum amount of data we need to retain // for backreferences. if singleSegment { // No window required, as all the data is in a single buffer. r.windowSize = 0 } else { // Window descriptor. RFC 3.1.1.1.2. windowDescriptor := r.scratch[0] exponent := uint64(windowDescriptor >> 3) mantissa := uint64(windowDescriptor & 7) windowLog := exponent + 10 windowBase := uint64(1) << windowLog windowAdd := (windowBase / 8) * mantissa windowSize := windowBase + windowAdd // Default zstd sets limits on the window size. if fuzzing && (windowLog > 31 || windowSize > 1<<27) { return r.makeError(relativeOffset, "windowSize too large") } // RFC 8878 permits us to set an 8M max on window size. if windowSize > 8<<20 { windowSize = 8 << 20 } r.windowSize = int(windowSize) } // Frame_Content_Size. RFC 3.1.1.4. r.frameSizeUnknown = false r.remainingFrameSize = 0 fb := r.scratch[windowDescriptorSize:] switch fcsFieldSize { case 0: r.frameSizeUnknown = true case 1: r.remainingFrameSize = uint64(fb[0]) case 2: r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) case 4: r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) case 8: r.remainingFrameSize = binary.LittleEndian.Uint64(fb) default: panic("unreachable") } relativeOffset += headerSize r.sawFrameHeader = true r.readOneFrame = true r.blockOffset += int64(relativeOffset) // Prepare to read blocks from the frame. r.repeatedOffset1 = 1 r.repeatedOffset2 = 4 r.repeatedOffset3 = 8 r.huffmanTableBits = 0 r.window = r.window[:0] r.seqTables[0] = nil r.seqTables[1] = nil r.seqTables[2] = nil return nil } // skipFrame skips a skippable frame. RFC 3.1.2. func (r *Reader) skipFrame() error { relativeOffset := 0 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset += 4 size := binary.LittleEndian.Uint32(r.scratch[:4]) if seeker, ok := r.r.(io.Seeker); ok { if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil { return err } r.blockOffset += int64(relativeOffset) + int64(size) return nil } var skip []byte const chunk = 1 << 20 // 1M for size >= chunk { if len(skip) == 0 { skip = make([]byte, chunk) } if _, err := io.ReadFull(r.r, skip); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset += chunk size -= chunk } if size > 0 { if len(skip) == 0 { skip = make([]byte, size) } if _, err := io.ReadFull(r.r, skip); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset += int(size) } r.blockOffset += int64(relativeOffset) return nil } // readBlock reads the next block from a frame. func (r *Reader) readBlock() error { relativeOffset := 0 // Read Block_Header. RFC 3.1.1.2. if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset += 3 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) lastBlock := header&1 != 0 blockType := (header >> 1) & 3 blockSize := int(header >> 3) // Maximum block size is smaller of window size and 128K. // We don't record the window size for a single segment frame, // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) { return r.makeError(relativeOffset, "block size too large") } // Handle different block types. RFC 3.1.1.2.2. switch blockType { case 0: r.setBufferSize(blockSize) if _, err := io.ReadFull(r.r, r.buffer); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset += blockSize r.blockOffset += int64(relativeOffset) case 1: r.setBufferSize(blockSize) if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { return r.wrapNonEOFError(relativeOffset, err) } relativeOffset++ v := r.scratch[0] for i := range r.buffer { r.buffer[i] = v } r.blockOffset += int64(relativeOffset) case 2: r.blockOffset += int64(relativeOffset) if err := r.compressedBlock(blockSize); err != nil { return err } r.blockOffset += int64(blockSize) case 3: return r.makeError(relativeOffset, "invalid block type") } if !r.frameSizeUnknown { if uint64(len(r.buffer)) > r.remainingFrameSize { return r.makeError(relativeOffset, "too many uncompressed bytes in frame") } r.remainingFrameSize -= uint64(len(r.buffer)) } if r.hasChecksum { r.checksum.update(r.buffer) } if !lastBlock { r.saveWindow(r.buffer) } else { if !r.frameSizeUnknown && r.remainingFrameSize != 0 { return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") } // Check for checksum at end of frame. RFC 3.1.1. if r.hasChecksum { if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { return r.wrapNonEOFError(0, err) } inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) dataChecksum := uint32(r.checksum.digest()) if inputChecksum != dataChecksum { return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) } r.blockOffset += 4 } r.sawFrameHeader = false } return nil } // setBufferSize sets the decompressed buffer size. // When this is called the buffer is empty. func (r *Reader) setBufferSize(size int) { if cap(r.buffer) < size { need := size - cap(r.buffer) r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) } r.buffer = r.buffer[:size] } // saveWindow saves bytes in the backreference window. // TODO: use a circular buffer for less data movement. func (r *Reader) saveWindow(buf []byte) { if r.windowSize == 0 { return } if len(buf) >= r.windowSize { from := len(buf) - r.windowSize r.window = append(r.window[:0], buf[from:]...) return } keep := r.windowSize - len(buf) // must be positive if keep < len(r.window) { remove := len(r.window) - keep copy(r.window[:], r.window[remove:]) } r.window = append(r.window, buf...) } // zstdError is an error while decompressing. type zstdError struct { offset int64 err error } func (ze *zstdError) Error() string { return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) } func (ze *zstdError) Unwrap() error { return ze.err } func (r *Reader) makeEOFError(off int) error { return r.wrapError(off, io.ErrUnexpectedEOF) } func (r *Reader) wrapNonEOFError(off int, err error) error { if err == io.EOF { err = io.ErrUnexpectedEOF } return r.wrapError(off, err) } func (r *Reader) makeError(off int, msg string) error { return r.wrapError(off, errors.New(msg)) } func (r *Reader) wrapError(off int, err error) error { if err == io.EOF { return err } return &zstdError{r.blockOffset + int64(off), err} }