diff options
Diffstat (limited to 'src/zstd/doc/educational_decoder/zstd_decompress.c')
-rw-r--r-- | src/zstd/doc/educational_decoder/zstd_decompress.c | 2303 |
1 files changed, 2303 insertions, 0 deletions
diff --git a/src/zstd/doc/educational_decoder/zstd_decompress.c b/src/zstd/doc/educational_decoder/zstd_decompress.c new file mode 100644 index 00000000..bea0e0ce --- /dev/null +++ b/src/zstd/doc/educational_decoder/zstd_decompress.c @@ -0,0 +1,2303 @@ +/* + * Copyright (c) 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + */ + +/// Zstandard educational decoder implementation +/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md + +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include "zstd_decompress.h" + +/******* UTILITY MACROS AND TYPES *********************************************/ +// Max block size decompressed size is 128 KB and literal blocks can't be +// larger than their block +#define MAX_LITERALS_SIZE ((size_t)128 * 1024) + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +/// This decoder calls exit(1) when it encounters an error, however a production +/// library should propagate error codes +#define ERROR(s) \ + do { \ + fprintf(stderr, "Error: %s\n", s); \ + exit(1); \ + } while (0) +#define INP_SIZE() \ + ERROR("Input buffer smaller than it should be or input is " \ + "corrupted") +#define OUT_SIZE() ERROR("Output buffer too small for output") +#define CORRUPTION() ERROR("Corruption detected while decompressing") +#define BAD_ALLOC() ERROR("Memory allocation error") +#define IMPOSSIBLE() ERROR("An impossibility has occurred") + +typedef uint8_t u8; +typedef uint16_t u16; +typedef uint32_t u32; +typedef uint64_t u64; + +typedef int8_t i8; +typedef int16_t i16; +typedef int32_t i32; +typedef int64_t i64; +/******* END UTILITY MACROS AND TYPES *****************************************/ + +/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/ +/// The implementations for these functions can be found at the bottom of this +/// file. They implement low-level functionality needed for the higher level +/// decompression functions. + +/*** IO STREAM OPERATIONS *************/ + +/// ostream_t/istream_t are used to wrap the pointers/length data passed into +/// ZSTD_decompress, so that all IO operations are safely bounds checked +/// They are written/read forward, and reads are treated as little-endian +/// They should be used opaquely to ensure safety +typedef struct { + u8 *ptr; + size_t len; +} ostream_t; + +typedef struct { + const u8 *ptr; + size_t len; + + // Input often reads a few bits at a time, so maintain an internal offset + int bit_offset; +} istream_t; + +/// The following two functions are the only ones that allow the istream to be +/// non-byte aligned + +/// Reads `num` bits from a bitstream, and updates the internal offset +static inline u64 IO_read_bits(istream_t *const in, const int num_bits); +/// Backs-up the stream by `num` bits so they can be read again +static inline void IO_rewind_bits(istream_t *const in, const int num_bits); +/// If the remaining bits in a byte will be unused, advance to the end of the +/// byte +static inline void IO_align_stream(istream_t *const in); + +/// Write the given byte into the output stream +static inline void IO_write_byte(ostream_t *const out, u8 symb); + +/// Returns the number of bytes left to be read in this stream. The stream must +/// be byte aligned. +static inline size_t IO_istream_len(const istream_t *const in); + +/// Advances the stream by `len` bytes, and returns a pointer to the chunk that +/// was skipped. The stream must be byte aligned. +static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len); +/// Advances the stream by `len` bytes, and returns a pointer to the chunk that +/// was skipped so it can be written to. +static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len); + +/// Advance the inner state by `len` bytes. The stream must be byte aligned. +static inline void IO_advance_input(istream_t *const in, size_t len); + +/// Returns an `ostream_t` constructed from the given pointer and length. +static inline ostream_t IO_make_ostream(u8 *out, size_t len); +/// Returns an `istream_t` constructed from the given pointer and length. +static inline istream_t IO_make_istream(const u8 *in, size_t len); + +/// Returns an `istream_t` with the same base as `in`, and length `len`. +/// Then, advance `in` to account for the consumed bytes. +/// `in` must be byte aligned. +static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len); +/*** END IO STREAM OPERATIONS *********/ + +/*** BITSTREAM OPERATIONS *************/ +/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits, +/// and return them interpreted as a little-endian unsigned integer. +static inline u64 read_bits_LE(const u8 *src, const int num_bits, + const size_t offset); + +/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so +/// it updates `offset` to `offset - bits`, and then reads `bits` bits from +/// `src + offset`. If the offset becomes negative, the extra bits at the +/// bottom are filled in with `0` bits instead of reading from before `src`. +static inline u64 STREAM_read_bits(const u8 *src, const int bits, + i64 *const offset); +/*** END BITSTREAM OPERATIONS *********/ + +/*** BIT COUNTING OPERATIONS **********/ +/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0` +static inline int highest_set_bit(const u64 num); +/*** END BIT COUNTING OPERATIONS ******/ + +/*** HUFFMAN PRIMITIVES ***************/ +// Table decode method uses exponential memory, so we need to limit depth +#define HUF_MAX_BITS (16) + +// Limit the maximum number of symbols to 256 so we can store a symbol in a byte +#define HUF_MAX_SYMBS (256) + +/// Structure containing all tables necessary for efficient Huffman decoding +typedef struct { + u8 *symbols; + u8 *num_bits; + int max_bits; +} HUF_dtable; + +/// Decode a single symbol and read in enough bits to refresh the state +static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); +/// Read in a full state's worth of bits to initialize it +static inline void HUF_init_state(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); + +/// Decompresses a single Huffman stream, returns the number of bytes decoded. +/// `src_len` must be the exact length of the Huffman-coded block. +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in); +/// Same as previous but decodes 4 streams, formatted as in the Zstandard +/// specification. +/// `src_len` must be the exact length of the Huffman-coded block. +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in); + +/// Initialize a Huffman decoding table using the table of bit counts provided +static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, + const int num_symbs); +/// Initialize a Huffman decoding table using the table of weights provided +/// Weights follow the definition provided in the Zstandard specification +static void HUF_init_dtable_usingweights(HUF_dtable *const table, + const u8 *const weights, + const int num_symbs); + +/// Free the malloc'ed parts of a decoding table +static void HUF_free_dtable(HUF_dtable *const dtable); + +/// Deep copy a decoding table, so that it can be used and free'd without +/// impacting the source table. +static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src); +/*** END HUFFMAN PRIMITIVES ***********/ + +/*** FSE PRIMITIVES *******************/ +/// For more description of FSE see +/// https://github.com/Cyan4973/FiniteStateEntropy/ + +// FSE table decoding uses exponential memory, so limit the maximum accuracy +#define FSE_MAX_ACCURACY_LOG (15) +// Limit the maximum number of symbols so they can be stored in a single byte +#define FSE_MAX_SYMBS (256) + +/// The tables needed to decode FSE encoded streams +typedef struct { + u8 *symbols; + u8 *num_bits; + u16 *new_state_base; + int accuracy_log; +} FSE_dtable; + +/// Return the symbol for the current state +static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, + const u16 state); +/// Read the number of bits necessary to update state, update, and shift offset +/// back to reflect the bits read +static inline void FSE_update_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); + +/// Combine peek and update: decode a symbol and update the state +static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); + +/// Read bits from the stream to initialize the state and shift offset back +static inline void FSE_init_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset); + +/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights) +/// using an FSE decoding table. `src_len` must be the exact length of the +/// block. +static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, + ostream_t *const out, + istream_t *const in); + +/// Initialize a decoding table using normalized frequencies. +static void FSE_init_dtable(FSE_dtable *const dtable, + const i16 *const norm_freqs, const int num_symbs, + const int accuracy_log); + +/// Decode an FSE header as defined in the Zstandard format specification and +/// use the decoded frequencies to initialize a decoding table. +static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, + const int max_accuracy_log); + +/// Initialize an FSE table that will always return the same symbol and consume +/// 0 bits per symbol, to be used for RLE mode in sequence commands +static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb); + +/// Free the malloc'ed parts of a decoding table +static void FSE_free_dtable(FSE_dtable *const dtable); + +/// Deep copy a decoding table, so that it can be used and free'd without +/// impacting the source table. +static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src); +/*** END FSE PRIMITIVES ***************/ + +/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/ + +/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/ + +/// A small structure that can be reused in various places that need to access +/// frame header information +typedef struct { + // The size of window that we need to be able to contiguously store for + // references + size_t window_size; + // The total output size of this compressed frame + size_t frame_content_size; + + // The dictionary id if this frame uses one + u32 dictionary_id; + + // Whether or not the content of this frame has a checksum + int content_checksum_flag; + // Whether or not the output for this frame is in a single segment + int single_segment_flag; +} frame_header_t; + +/// The context needed to decode blocks in a frame +typedef struct { + frame_header_t header; + + // The total amount of data available for backreferences, to determine if an + // offset too large to be correct + size_t current_total_output; + + const u8 *dict_content; + size_t dict_content_len; + + // Entropy encoding tables so they can be repeated by future blocks instead + // of retransmitting + HUF_dtable literals_dtable; + FSE_dtable ll_dtable; + FSE_dtable ml_dtable; + FSE_dtable of_dtable; + + // The last 3 offsets for the special "repeat offsets". + u64 previous_offsets[3]; +} frame_context_t; + +/// The decoded contents of a dictionary so that it doesn't have to be repeated +/// for each frame that uses it +struct dictionary_s { + // Entropy tables + HUF_dtable literals_dtable; + FSE_dtable ll_dtable; + FSE_dtable ml_dtable; + FSE_dtable of_dtable; + + // Raw content for backreferences + u8 *content; + size_t content_size; + + // Offset history to prepopulate the frame's history + u64 previous_offsets[3]; + + u32 dictionary_id; +}; + +/// A tuple containing the parts necessary to decode and execute a ZSTD sequence +/// command +typedef struct { + u32 literal_length; + u32 match_length; + u32 offset; +} sequence_command_t; + +/// The decoder works top-down, starting at the high level like Zstd frames, and +/// working down to lower more technical levels such as blocks, literals, and +/// sequences. The high-level functions roughly follow the outline of the +/// format specification: +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md + +/// Before the implementation of each high-level function declared here, the +/// prototypes for their helper functions are defined and explained + +/// Decode a single Zstd frame, or error if the input is not a valid frame. +/// Accepts a dict argument, which may be NULL indicating no dictionary. +/// See +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation +static void decode_frame(ostream_t *const out, istream_t *const in, + const dictionary_t *const dict); + +// Decode data in a compressed block +static void decompress_block(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in); + +// Decode the literals section of a block +static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, + u8 **const literals); + +// Decode the sequences part of a block +static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in, + sequence_command_t **const sequences); + +// Execute the decoded sequences on the literals block +static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, + const u8 *const literals, + const size_t literals_len, + const sequence_command_t *const sequences, + const size_t num_sequences); + +// Copies literals and returns the total literal length that was copied +static u32 copy_literals(const size_t seq, istream_t *litstream, + ostream_t *const out); + +// Given an offset code from a sequence command (either an actual offset value +// or an index for previous offset), computes the correct offset and udpates +// the offset history +static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist); + +// Given an offset, match length, and total output, as well as the frame +// context for the dictionary, determines if the dictionary is used and +// executes the copy operation +static void execute_match_copy(frame_context_t *const ctx, size_t offset, + size_t match_length, size_t total_output, + ostream_t *const out); + +/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/ + +size_t ZSTD_decompress(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len) { + dictionary_t* uninit_dict = create_dictionary(); + size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src, + src_len, uninit_dict); + free_dictionary(uninit_dict); + return decomp_size; +} + +size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len, + const void *const src, const size_t src_len, + dictionary_t* parsed_dict) { + + istream_t in = IO_make_istream(src, src_len); + ostream_t out = IO_make_ostream(dst, dst_len); + + // "A content compressed by Zstandard is transformed into a Zstandard frame. + // Multiple frames can be appended into a single file or stream. A frame is + // totally independent, has a defined beginning and end, and a set of + // parameters which tells the decoder how to decompress it." + + /* this decoder assumes decompression of a single frame */ + decode_frame(&out, &in, parsed_dict); + + return out.ptr - (u8 *)dst; +} + +/******* FRAME DECODING ******************************************************/ + +static void decode_data_frame(ostream_t *const out, istream_t *const in, + const dictionary_t *const dict); +static void init_frame_context(frame_context_t *const context, + istream_t *const in, + const dictionary_t *const dict); +static void free_frame_context(frame_context_t *const context); +static void parse_frame_header(frame_header_t *const header, + istream_t *const in); +static void frame_context_apply_dict(frame_context_t *const ctx, + const dictionary_t *const dict); + +static void decompress_data(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in); + +static void decode_frame(ostream_t *const out, istream_t *const in, + const dictionary_t *const dict) { + const u32 magic_number = IO_read_bits(in, 32); + // Zstandard frame + // + // "Magic_Number + // + // 4 Bytes, little-endian format. Value : 0xFD2FB528" + if (magic_number == 0xFD2FB528U) { + // ZSTD frame + decode_data_frame(out, in, dict); + + return; + } + + // not a real frame or a skippable frame + ERROR("Tried to decode non-ZSTD frame"); +} + +/// Decode a frame that contains compressed data. Not all frames do as there +/// are skippable frames. +/// See +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format +static void decode_data_frame(ostream_t *const out, istream_t *const in, + const dictionary_t *const dict) { + frame_context_t ctx; + + // Initialize the context that needs to be carried from block to block + init_frame_context(&ctx, in, dict); + + if (ctx.header.frame_content_size != 0 && + ctx.header.frame_content_size > out->len) { + OUT_SIZE(); + } + + decompress_data(&ctx, out, in); + + free_frame_context(&ctx); +} + +/// Takes the information provided in the header and dictionary, and initializes +/// the context for this frame +static void init_frame_context(frame_context_t *const context, + istream_t *const in, + const dictionary_t *const dict) { + // Most fields in context are correct when initialized to 0 + memset(context, 0, sizeof(frame_context_t)); + + // Parse data from the frame header + parse_frame_header(&context->header, in); + + // Set up the offset history for the repeat offset commands + context->previous_offsets[0] = 1; + context->previous_offsets[1] = 4; + context->previous_offsets[2] = 8; + + // Apply details from the dict if it exists + frame_context_apply_dict(context, dict); +} + +static void free_frame_context(frame_context_t *const context) { + HUF_free_dtable(&context->literals_dtable); + + FSE_free_dtable(&context->ll_dtable); + FSE_free_dtable(&context->ml_dtable); + FSE_free_dtable(&context->of_dtable); + + memset(context, 0, sizeof(frame_context_t)); +} + +static void parse_frame_header(frame_header_t *const header, + istream_t *const in) { + // "The first header's byte is called the Frame_Header_Descriptor. It tells + // which other fields are present. Decoding this byte is enough to tell the + // size of Frame_Header. + // + // Bit number Field name + // 7-6 Frame_Content_Size_flag + // 5 Single_Segment_flag + // 4 Unused_bit + // 3 Reserved_bit + // 2 Content_Checksum_flag + // 1-0 Dictionary_ID_flag" + const u8 descriptor = IO_read_bits(in, 8); + + // decode frame header descriptor into flags + const u8 frame_content_size_flag = descriptor >> 6; + const u8 single_segment_flag = (descriptor >> 5) & 1; + const u8 reserved_bit = (descriptor >> 3) & 1; + const u8 content_checksum_flag = (descriptor >> 2) & 1; + const u8 dictionary_id_flag = descriptor & 3; + + if (reserved_bit != 0) { + CORRUPTION(); + } + + header->single_segment_flag = single_segment_flag; + header->content_checksum_flag = content_checksum_flag; + + // decode window size + if (!single_segment_flag) { + // "Provides guarantees on maximum back-reference distance that will be + // used within compressed data. This information is important for + // decoders to allocate enough memory. + // + // Bit numbers 7-3 2-0 + // Field name Exponent Mantissa" + u8 window_descriptor = IO_read_bits(in, 8); + u8 exponent = window_descriptor >> 3; + u8 mantissa = window_descriptor & 7; + + // Use the algorithm from the specification to compute window size + // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor + size_t window_base = (size_t)1 << (10 + exponent); + size_t window_add = (window_base / 8) * mantissa; + header->window_size = window_base + window_add; + } + + // decode dictionary id if it exists + if (dictionary_id_flag) { + // "This is a variable size field, which contains the ID of the + // dictionary required to properly decode the frame. Note that this + // field is optional. When it's not present, it's up to the caller to + // make sure it uses the correct dictionary. Format is little-endian." + const int bytes_array[] = {0, 1, 2, 4}; + const int bytes = bytes_array[dictionary_id_flag]; + + header->dictionary_id = IO_read_bits(in, bytes * 8); + } else { + header->dictionary_id = 0; + } + + // decode frame content size if it exists + if (single_segment_flag || frame_content_size_flag) { + // "This is the original (uncompressed) size. This information is + // optional. The Field_Size is provided according to value of + // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not + // present), 1, 2, 4 or 8 bytes. Format is little-endian." + // + // if frame_content_size_flag == 0 but single_segment_flag is set, we + // still have a 1 byte field + const int bytes_array[] = {1, 2, 4, 8}; + const int bytes = bytes_array[frame_content_size_flag]; + + header->frame_content_size = IO_read_bits(in, bytes * 8); + if (bytes == 2) { + // "When Field_Size is 2, the offset of 256 is added." + header->frame_content_size += 256; + } + } else { + header->frame_content_size = 0; + } + + if (single_segment_flag) { + // "The Window_Descriptor byte is optional. It is absent when + // Single_Segment_flag is set. In this case, the maximum back-reference + // distance is the content size itself, which can be any value from 1 to + // 2^64-1 bytes (16 EB)." + header->window_size = header->frame_content_size; + } +} + +/// A dictionary acts as initializing values for the frame context before +/// decompression, so we implement it by applying it's predetermined +/// tables and content to the context before beginning decompression +static void frame_context_apply_dict(frame_context_t *const ctx, + const dictionary_t *const dict) { + // If the content pointer is NULL then it must be an empty dict + if (!dict || !dict->content) + return; + + // If the requested dictionary_id is non-zero, the correct dictionary must + // be present + if (ctx->header.dictionary_id != 0 && + ctx->header.dictionary_id != dict->dictionary_id) { + ERROR("Wrong dictionary provided"); + } + + // Copy the dict content to the context for references during sequence + // execution + ctx->dict_content = dict->content; + ctx->dict_content_len = dict->content_size; + + // If it's a formatted dict copy the precomputed tables in so they can + // be used in the table repeat modes + if (dict->dictionary_id != 0) { + // Deep copy the entropy tables so they can be freed independently of + // the dictionary struct + HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable); + FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable); + FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable); + FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable); + + // Copy the repeated offsets + memcpy(ctx->previous_offsets, dict->previous_offsets, + sizeof(ctx->previous_offsets)); + } +} + +/// Decompress the data from a frame block by block +static void decompress_data(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in) { + // "A frame encapsulates one or multiple blocks. Each block can be + // compressed or not, and has a guaranteed maximum content size, which + // depends on frame parameters. Unlike frames, each block depends on + // previous blocks for proper decoding. However, each block can be + // decompressed without waiting for its successor, allowing streaming + // operations." + int last_block = 0; + do { + // "Last_Block + // + // The lowest bit signals if this block is the last one. Frame ends + // right after this block. + // + // Block_Type and Block_Size + // + // The next 2 bits represent the Block_Type, while the remaining 21 bits + // represent the Block_Size. Format is little-endian." + last_block = IO_read_bits(in, 1); + const int block_type = IO_read_bits(in, 2); + const size_t block_len = IO_read_bits(in, 21); + + switch (block_type) { + case 0: { + // "Raw_Block - this is an uncompressed block. Block_Size is the + // number of bytes to read and copy." + const u8 *const read_ptr = IO_get_read_ptr(in, block_len); + u8 *const write_ptr = IO_get_write_ptr(out, block_len); + + // Copy the raw data into the output + memcpy(write_ptr, read_ptr, block_len); + + ctx->current_total_output += block_len; + break; + } + case 1: { + // "RLE_Block - this is a single byte, repeated N times. In which + // case, Block_Size is the size to regenerate, while the + // "compressed" block is just 1 byte (the byte to repeat)." + const u8 *const read_ptr = IO_get_read_ptr(in, 1); + u8 *const write_ptr = IO_get_write_ptr(out, block_len); + + // Copy `block_len` copies of `read_ptr[0]` to the output + memset(write_ptr, read_ptr[0], block_len); + + ctx->current_total_output += block_len; + break; + } + case 2: { + // "Compressed_Block - this is a Zstandard compressed block, + // detailed in another section of this specification. Block_Size is + // the compressed size. + + // Create a sub-stream for the block + istream_t block_stream = IO_make_sub_istream(in, block_len); + decompress_block(ctx, out, &block_stream); + break; + } + case 3: + // "Reserved - this is not a block. This value cannot be used with + // current version of this specification." + CORRUPTION(); + break; + default: + IMPOSSIBLE(); + } + } while (!last_block); + + if (ctx->header.content_checksum_flag) { + // This program does not support checking the checksum, so skip over it + // if it's present + IO_advance_input(in, 4); + } +} +/******* END FRAME DECODING ***************************************************/ + +/******* BLOCK DECOMPRESSION **************************************************/ +static void decompress_block(frame_context_t *const ctx, ostream_t *const out, + istream_t *const in) { + // "A compressed block consists of 2 sections : + // + // Literals_Section + // Sequences_Section" + + + // Part 1: decode the literals block + u8 *literals = NULL; + const size_t literals_size = decode_literals(ctx, in, &literals); + + // Part 2: decode the sequences block + sequence_command_t *sequences = NULL; + const size_t num_sequences = + decode_sequences(ctx, in, &sequences); + + // Part 3: combine literals and sequence commands to generate output + execute_sequences(ctx, out, literals, literals_size, sequences, + num_sequences); + free(literals); + free(sequences); +} +/******* END BLOCK DECOMPRESSION **********************************************/ + +/******* LITERALS DECODING ****************************************************/ +static size_t decode_literals_simple(istream_t *const in, u8 **const literals, + const int block_type, + const int size_format); +static size_t decode_literals_compressed(frame_context_t *const ctx, + istream_t *const in, + u8 **const literals, + const int block_type, + const int size_format); +static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in); +static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, + int *const num_symbs); + +static size_t decode_literals(frame_context_t *const ctx, istream_t *const in, + u8 **const literals) { + // "Literals can be stored uncompressed or compressed using Huffman prefix + // codes. When compressed, an optional tree description can be present, + // followed by 1 or 4 streams." + // + // "Literals_Section_Header + // + // Header is in charge of describing how literals are packed. It's a + // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using + // little-endian convention." + // + // "Literals_Block_Type + // + // This field uses 2 lowest bits of first byte, describing 4 different block + // types" + // + // size_format takes between 1 and 2 bits + int block_type = IO_read_bits(in, 2); + int size_format = IO_read_bits(in, 2); + + if (block_type <= 1) { + // Raw or RLE literals block + return decode_literals_simple(in, literals, block_type, + size_format); + } else { + // Huffman compressed literals + return decode_literals_compressed(ctx, in, literals, block_type, + size_format); + } +} + +/// Decodes literals blocks in raw or RLE form +static size_t decode_literals_simple(istream_t *const in, u8 **const literals, + const int block_type, + const int size_format) { + size_t size; + switch (size_format) { + // These cases are in the form ?0 + // In this case, the ? bit is actually part of the size field + case 0: + case 2: + // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)." + IO_rewind_bits(in, 1); + size = IO_read_bits(in, 5); + break; + case 1: + // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)." + size = IO_read_bits(in, 12); + break; + case 3: + // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)." + size = IO_read_bits(in, 20); + break; + default: + // Size format is in range 0-3 + IMPOSSIBLE(); + } + + if (size > MAX_LITERALS_SIZE) { + CORRUPTION(); + } + + *literals = malloc(size); + if (!*literals) { + BAD_ALLOC(); + } + + switch (block_type) { + case 0: { + // "Raw_Literals_Block - Literals are stored uncompressed." + const u8 *const read_ptr = IO_get_read_ptr(in, size); + memcpy(*literals, read_ptr, size); + break; + } + case 1: { + // "RLE_Literals_Block - Literals consist of a single byte value repeated N times." + const u8 *const read_ptr = IO_get_read_ptr(in, 1); + memset(*literals, read_ptr[0], size); + break; + } + default: + IMPOSSIBLE(); + } + + return size; +} + +/// Decodes Huffman compressed literals +static size_t decode_literals_compressed(frame_context_t *const ctx, + istream_t *const in, + u8 **const literals, + const int block_type, + const int size_format) { + size_t regenerated_size, compressed_size; + // Only size_format=0 has 1 stream, so default to 4 + int num_streams = 4; + switch (size_format) { + case 0: + // "A single stream. Both Compressed_Size and Regenerated_Size use 10 + // bits (0-1023)." + num_streams = 1; + // Fall through as it has the same size format + case 1: + // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits + // (0-1023)." + regenerated_size = IO_read_bits(in, 10); + compressed_size = IO_read_bits(in, 10); + break; + case 2: + // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits + // (0-16383)." + regenerated_size = IO_read_bits(in, 14); + compressed_size = IO_read_bits(in, 14); + break; + case 3: + // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits + // (0-262143)." + regenerated_size = IO_read_bits(in, 18); + compressed_size = IO_read_bits(in, 18); + break; + default: + // Impossible + IMPOSSIBLE(); + } + if (regenerated_size > MAX_LITERALS_SIZE || + compressed_size >= regenerated_size) { + CORRUPTION(); + } + + *literals = malloc(regenerated_size); + if (!*literals) { + BAD_ALLOC(); + } + + ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size); + istream_t huf_stream = IO_make_sub_istream(in, compressed_size); + + if (block_type == 2) { + // Decode the provided Huffman table + // "This section is only present when Literals_Block_Type type is + // Compressed_Literals_Block (2)." + + HUF_free_dtable(&ctx->literals_dtable); + decode_huf_table(&ctx->literals_dtable, &huf_stream); + } else { + // If the previous Huffman table is being repeated, ensure it exists + if (!ctx->literals_dtable.symbols) { + CORRUPTION(); + } + } + + size_t symbols_decoded; + if (num_streams == 1) { + symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream); + } else { + symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream); + } + + if (symbols_decoded != regenerated_size) { + CORRUPTION(); + } + + return regenerated_size; +} + +// Decode the Huffman table description +static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) { + // "All literal values from zero (included) to last present one (excluded) + // are represented by Weight with values from 0 to Max_Number_of_Bits." + + // "This is a single byte value (0-255), which describes how to decode the list of weights." + const u8 header = IO_read_bits(in, 8); + + u8 weights[HUF_MAX_SYMBS]; + memset(weights, 0, sizeof(weights)); + + int num_symbs; + + if (header >= 128) { + // "This is a direct representation, where each Weight is written + // directly as a 4 bits field (0-15). The full representation occupies + // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte + // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte - + // 127" + num_symbs = header - 127; + const size_t bytes = (num_symbs + 1) / 2; + + const u8 *const weight_src = IO_get_read_ptr(in, bytes); + + for (int i = 0; i < num_symbs; i++) { + // "They are encoded forward, 2 + // weights to a byte with the first weight taking the top four bits + // and the second taking the bottom four (e.g. the following + // operations could be used to read the weights: Weight[0] = + // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)." + if (i % 2 == 0) { + weights[i] = weight_src[i / 2] >> 4; + } else { + weights[i] = weight_src[i / 2] & 0xf; + } + } + } else { + // The weights are FSE encoded, decode them before we can construct the + // table + istream_t fse_stream = IO_make_sub_istream(in, header); + ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS); + fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs); + } + + // Construct the table using the decoded weights + HUF_init_dtable_usingweights(dtable, weights, num_symbs); +} + +static void fse_decode_hufweights(ostream_t *weights, istream_t *const in, + int *const num_symbs) { + const int MAX_ACCURACY_LOG = 7; + + FSE_dtable dtable; + + // "An FSE bitstream starts by a header, describing probabilities + // distribution. It will create a Decoding Table. For a list of Huffman + // weights, maximum accuracy is 7 bits." + FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG); + + // Decode the weights + *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in); + + FSE_free_dtable(&dtable); +} +/******* END LITERALS DECODING ************************************************/ + +/******* SEQUENCE DECODING ****************************************************/ +/// The combination of FSE states needed to decode sequences +typedef struct { + FSE_dtable ll_table; + FSE_dtable of_table; + FSE_dtable ml_table; + + u16 ll_state; + u16 of_state; + u16 ml_state; +} sequence_states_t; + +/// Different modes to signal to decode_seq_tables what to do +typedef enum { + seq_literal_length = 0, + seq_offset = 1, + seq_match_length = 2, +} seq_part_t; + +typedef enum { + seq_predefined = 0, + seq_rle = 1, + seq_fse = 2, + seq_repeat = 3, +} seq_mode_t; + +/// The predefined FSE distribution tables for `seq_predefined` mode +static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = { + 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1}; +static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = { + 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1}; +static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = { + 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1}; + +/// The sequence decoding baseline and number of additional bits to read/add +/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets +static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, + 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538}; +static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = { + 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, + 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539}; +static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, + 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + +/// Offset decoding is simpler so we just need a maximum code value +static const u8 SEQ_MAX_CODES[3] = {35, -1, 52}; + +static void decompress_sequences(frame_context_t *const ctx, + istream_t *const in, + sequence_command_t *const sequences, + const size_t num_sequences); +static sequence_command_t decode_sequence(sequence_states_t *const state, + const u8 *const src, + i64 *const offset); +static void decode_seq_table(FSE_dtable *const table, istream_t *const in, + const seq_part_t type, const seq_mode_t mode); + +static size_t decode_sequences(frame_context_t *const ctx, istream_t *in, + sequence_command_t **const sequences) { + // "A compressed block is a succession of sequences . A sequence is a + // literal copy command, followed by a match copy command. A literal copy + // command specifies a length. It is the number of bytes to be copied (or + // extracted) from the literal section. A match copy command specifies an + // offset and a length. The offset gives the position to copy from, which + // can be within a previous block." + + size_t num_sequences; + + // "Number_of_Sequences + // + // This is a variable size field using between 1 and 3 bytes. Let's call its + // first byte byte0." + u8 header = IO_read_bits(in, 8); + if (header == 0) { + // "There are no sequences. The sequence section stops there. + // Regenerated content is defined entirely by literals section." + *sequences = NULL; + return 0; + } else if (header < 128) { + // "Number_of_Sequences = byte0 . Uses 1 byte." + num_sequences = header; + } else if (header < 255) { + // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes." + num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8); + } else { + // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes." + num_sequences = IO_read_bits(in, 16) + 0x7F00; + } + + *sequences = malloc(num_sequences * sizeof(sequence_command_t)); + if (!*sequences) { + BAD_ALLOC(); + } + + decompress_sequences(ctx, in, *sequences, num_sequences); + return num_sequences; +} + +/// Decompress the FSE encoded sequence commands +static void decompress_sequences(frame_context_t *const ctx, istream_t *in, + sequence_command_t *const sequences, + const size_t num_sequences) { + // "The Sequences_Section regroup all symbols required to decode commands. + // There are 3 symbol types : literals lengths, offsets and match lengths. + // They are encoded together, interleaved, in a single bitstream." + + // "Symbol compression modes + // + // This is a single byte, defining the compression mode of each symbol + // type." + // + // Bit number : Field name + // 7-6 : Literals_Lengths_Mode + // 5-4 : Offsets_Mode + // 3-2 : Match_Lengths_Mode + // 1-0 : Reserved + u8 compression_modes = IO_read_bits(in, 8); + + if ((compression_modes & 3) != 0) { + // Reserved bits set + CORRUPTION(); + } + + // "Following the header, up to 3 distribution tables can be described. When + // present, they are in this order : + // + // Literals lengths + // Offsets + // Match Lengths" + // Update the tables we have stored in the context + decode_seq_table(&ctx->ll_dtable, in, seq_literal_length, + (compression_modes >> 6) & 3); + + decode_seq_table(&ctx->of_dtable, in, seq_offset, + (compression_modes >> 4) & 3); + + decode_seq_table(&ctx->ml_dtable, in, seq_match_length, + (compression_modes >> 2) & 3); + + + sequence_states_t states; + + // Initialize the decoding tables + { + states.ll_table = ctx->ll_dtable; + states.of_table = ctx->of_dtable; + states.ml_table = ctx->ml_dtable; + } + + const size_t len = IO_istream_len(in); + const u8 *const src = IO_get_read_ptr(in, len); + + // "After writing the last bit containing information, the compressor writes + // a single 1-bit and then fills the byte with 0-7 0 bits of padding." + const int padding = 8 - highest_set_bit(src[len - 1]); + // The offset starts at the end because FSE streams are read backwards + i64 bit_offset = len * 8 - padding; + + // "The bitstream starts with initial state values, each using the required + // number of bits in their respective accuracy, decoded previously from + // their normalized distribution. + // + // It starts by Literals_Length_State, followed by Offset_State, and finally + // Match_Length_State." + FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset); + FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset); + FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset); + + for (size_t i = 0; i < num_sequences; i++) { + // Decode sequences one by one + sequences[i] = decode_sequence(&states, src, &bit_offset); + } + + if (bit_offset != 0) { + CORRUPTION(); + } +} + +// Decode a single sequence and update the state +static sequence_command_t decode_sequence(sequence_states_t *const states, + const u8 *const src, + i64 *const offset) { + // "Each symbol is a code in its own context, which specifies Baseline and + // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw + // additional bits in the same bitstream." + + // Decode symbols, but don't update states + const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state); + const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state); + const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state); + + // Offset doesn't need a max value as it's not decoded using a table + if (ll_code > SEQ_MAX_CODES[seq_literal_length] || + ml_code > SEQ_MAX_CODES[seq_match_length]) { + CORRUPTION(); + } + + // Read the interleaved bits + sequence_command_t seq; + // "Decoding starts by reading the Number_of_Bits required to decode Offset. + // It then does the same for Match_Length, and then for Literals_Length." + seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset); + + seq.match_length = + SEQ_MATCH_LENGTH_BASELINES[ml_code] + + STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset); + + seq.literal_length = + SEQ_LITERAL_LENGTH_BASELINES[ll_code] + + STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset); + + // "If it is not the last sequence in the block, the next operation is to + // update states. Using the rules pre-calculated in the decoding tables, + // Literals_Length_State is updated, followed by Match_Length_State, and + // then Offset_State." + // If the stream is complete don't read bits to update state + if (*offset != 0) { + FSE_update_state(&states->ll_table, &states->ll_state, src, offset); + FSE_update_state(&states->ml_table, &states->ml_state, src, offset); + FSE_update_state(&states->of_table, &states->of_state, src, offset); + } + + return seq; +} + +/// Given a sequence part and table mode, decode the FSE distribution +/// Errors if the mode is `seq_repeat` without a pre-existing table in `table` +static void decode_seq_table(FSE_dtable *const table, istream_t *const in, + const seq_part_t type, const seq_mode_t mode) { + // Constant arrays indexed by seq_part_t + const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST, + SEQ_OFFSET_DEFAULT_DIST, + SEQ_MATCH_LENGTH_DEFAULT_DIST}; + const size_t default_distribution_lengths[] = {36, 29, 53}; + const size_t default_distribution_accuracies[] = {6, 5, 6}; + + const size_t max_accuracies[] = {9, 8, 9}; + + if (mode != seq_repeat) { + // Free old one before overwriting + FSE_free_dtable(table); + } + + switch (mode) { + case seq_predefined: { + // "Predefined_Mode : uses a predefined distribution table." + const i16 *distribution = default_distributions[type]; + const size_t symbs = default_distribution_lengths[type]; + const size_t accuracy_log = default_distribution_accuracies[type]; + + FSE_init_dtable(table, distribution, symbs, accuracy_log); + break; + } + case seq_rle: { + // "RLE_Mode : it's a single code, repeated Number_of_Sequences times." + const u8 symb = IO_get_read_ptr(in, 1)[0]; + FSE_init_dtable_rle(table, symb); + break; + } + case seq_fse: { + // "FSE_Compressed_Mode : standard FSE compression. A distribution table + // will be present " + FSE_decode_header(table, in, max_accuracies[type]); + break; + } + case seq_repeat: + // "Repeat_Mode : re-use distribution table from previous compressed + // block." + // Nothing to do here, table will be unchanged + if (!table->symbols) { + // This mode is invalid if we don't already have a table + CORRUPTION(); + } + break; + default: + // Impossible, as mode is from 0-3 + IMPOSSIBLE(); + break; + } + +} +/******* END SEQUENCE DECODING ************************************************/ + +/******* SEQUENCE EXECUTION ***************************************************/ +static void execute_sequences(frame_context_t *const ctx, ostream_t *const out, + const u8 *const literals, + const size_t literals_len, + const sequence_command_t *const sequences, + const size_t num_sequences) { + istream_t litstream = IO_make_istream(literals, literals_len); + + u64 *const offset_hist = ctx->previous_offsets; + size_t total_output = ctx->current_total_output; + + for (size_t i = 0; i < num_sequences; i++) { + const sequence_command_t seq = sequences[i]; + { + const u32 literals_size = copy_literals(seq.literal_length, &litstream, out); + total_output += literals_size; + } + + size_t const offset = compute_offset(seq, offset_hist); + + size_t const match_length = seq.match_length; + + execute_match_copy(ctx, offset, match_length, total_output, out); + + total_output += match_length; + } + + // Copy any leftover literals + { + size_t len = IO_istream_len(&litstream); + copy_literals(len, &litstream, out); + total_output += len; + } + + ctx->current_total_output = total_output; +} + +static u32 copy_literals(const size_t literal_length, istream_t *litstream, + ostream_t *const out) { + // If the sequence asks for more literals than are left, the + // sequence must be corrupted + if (literal_length > IO_istream_len(litstream)) { + CORRUPTION(); + } + + u8 *const write_ptr = IO_get_write_ptr(out, literal_length); + const u8 *const read_ptr = + IO_get_read_ptr(litstream, literal_length); + // Copy literals to output + memcpy(write_ptr, read_ptr, literal_length); + + return literal_length; +} + +static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) { + size_t offset; + // Offsets are special, we need to handle the repeat offsets + if (seq.offset <= 3) { + // "The first 3 values define a repeated offset and we will call + // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3. + // They are sorted in recency order, with Repeated_Offset1 meaning + // 'most recent one'". + + // Use 0 indexing for the array + u32 idx = seq.offset - 1; + if (seq.literal_length == 0) { + // "There is an exception though, when current sequence's + // literals length is 0. In this case, repeated offsets are + // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2, + // Repeated_Offset2 becomes Repeated_Offset3, and + // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte." + idx++; + } + + if (idx == 0) { + offset = offset_hist[0]; + } else { + // If idx == 3 then literal length was 0 and the offset was 3, + // as per the exception listed above + offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1; + + // If idx == 1 we don't need to modify offset_hist[2], since + // we're using the second-most recent code + if (idx > 1) { + offset_hist[2] = offset_hist[1]; + } + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; + } + } else { + // When it's not a repeat offset: + // "if (Offset_Value > 3) offset = Offset_Value - 3;" + offset = seq.offset - 3; + + // Shift back history + offset_hist[2] = offset_hist[1]; + offset_hist[1] = offset_hist[0]; + offset_hist[0] = offset; + } + return offset; +} + +static void execute_match_copy(frame_context_t *const ctx, size_t offset, + size_t match_length, size_t total_output, + ostream_t *const out) { + u8 *write_ptr = IO_get_write_ptr(out, match_length); + if (total_output <= ctx->header.window_size) { + // In this case offset might go back into the dictionary + if (offset > total_output + ctx->dict_content_len) { + // The offset goes beyond even the dictionary + CORRUPTION(); + } + + if (offset > total_output) { + // "The rest of the dictionary is its content. The content act + // as a "past" in front of data to compress or decompress, so it + // can be referenced in sequence commands." + const size_t dict_copy = + MIN(offset - total_output, match_length); + const size_t dict_offset = + ctx->dict_content_len - (offset - total_output); + + memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy); + write_ptr += dict_copy; + match_length -= dict_copy; + } + } else if (offset > ctx->header.window_size) { + CORRUPTION(); + } + + // We must copy byte by byte because the match length might be larger + // than the offset + // ex: if the output so far was "abc", a command with offset=3 and + // match_length=6 would produce "abcabcabc" as the new output + for (size_t j = 0; j < match_length; j++) { + *write_ptr = *(write_ptr - offset); + write_ptr++; + } +} +/******* END SEQUENCE EXECUTION ***********************************************/ + +/******* OUTPUT SIZE COUNTING *************************************************/ +/// Get the decompressed size of an input stream so memory can be allocated in +/// advance. +/// This implementation assumes `src` points to a single ZSTD-compressed frame +size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) { + istream_t in = IO_make_istream(src, src_len); + + // get decompressed size from ZSTD frame header + { + const u32 magic_number = IO_read_bits(&in, 32); + + if (magic_number == 0xFD2FB528U) { + // ZSTD frame + frame_header_t header; + parse_frame_header(&header, &in); + + if (header.frame_content_size == 0 && !header.single_segment_flag) { + // Content size not provided, we can't tell + return -1; + } + + return header.frame_content_size; + } else { + // not a real frame or skippable frame + ERROR("ZSTD frame magic number did not match"); + } + } +} +/******* END OUTPUT SIZE COUNTING *********************************************/ + +/******* DICTIONARY PARSING ***************************************************/ +#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes") +#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src"); + +dictionary_t* create_dictionary() { + dictionary_t* dict = calloc(1, sizeof(dictionary_t)); + if (!dict) { + BAD_ALLOC(); + } + return dict; +} + +static void init_dictionary_content(dictionary_t *const dict, + istream_t *const in); + +void parse_dictionary(dictionary_t *const dict, const void *src, + size_t src_len) { + const u8 *byte_src = (const u8 *)src; + memset(dict, 0, sizeof(dictionary_t)); + if (src == NULL) { /* cannot initialize dictionary with null src */ + NULL_SRC(); + } + if (src_len < 8) { + DICT_SIZE_ERROR(); + } + + istream_t in = IO_make_istream(byte_src, src_len); + + const u32 magic_number = IO_read_bits(&in, 32); + if (magic_number != 0xEC30A437) { + // raw content dict + IO_rewind_bits(&in, 32); + init_dictionary_content(dict, &in); + return; + } + + dict->dictionary_id = IO_read_bits(&in, 32); + + // "Entropy_Tables : following the same format as the tables in compressed + // blocks. They are stored in following order : Huffman tables for literals, + // FSE table for offsets, FSE table for match lengths, and FSE table for + // literals lengths. It's finally followed by 3 offset values, populating + // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes + // little-endian each, for a total of 12 bytes. Each recent offset must have + // a value < dictionary size." + decode_huf_table(&dict->literals_dtable, &in); + decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse); + decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse); + decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse); + + // Read in the previous offset history + dict->previous_offsets[0] = IO_read_bits(&in, 32); + dict->previous_offsets[1] = IO_read_bits(&in, 32); + dict->previous_offsets[2] = IO_read_bits(&in, 32); + + // Ensure the provided offsets aren't too large + // "Each recent offset must have a value < dictionary size." + for (int i = 0; i < 3; i++) { + if (dict->previous_offsets[i] > src_len) { + ERROR("Dictionary corrupted"); + } + } + + // "Content : The rest of the dictionary is its content. The content act as + // a "past" in front of data to compress or decompress, so it can be + // referenced in sequence commands." + init_dictionary_content(dict, &in); +} + +static void init_dictionary_content(dictionary_t *const dict, + istream_t *const in) { + // Copy in the content + dict->content_size = IO_istream_len(in); + dict->content = malloc(dict->content_size); + if (!dict->content) { + BAD_ALLOC(); + } + + const u8 *const content = IO_get_read_ptr(in, dict->content_size); + + memcpy(dict->content, content, dict->content_size); +} + +/// Free an allocated dictionary +void free_dictionary(dictionary_t *const dict) { + HUF_free_dtable(&dict->literals_dtable); + FSE_free_dtable(&dict->ll_dtable); + FSE_free_dtable(&dict->of_dtable); + FSE_free_dtable(&dict->ml_dtable); + + free(dict->content); + + memset(dict, 0, sizeof(dictionary_t)); + + free(dict); +} +/******* END DICTIONARY PARSING ***********************************************/ + +/******* IO STREAM OPERATIONS *************************************************/ +#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream") +/// Reads `num` bits from a bitstream, and updates the internal offset +static inline u64 IO_read_bits(istream_t *const in, const int num_bits) { + if (num_bits > 64 || num_bits <= 0) { + ERROR("Attempt to read an invalid number of bits"); + } + + const size_t bytes = (num_bits + in->bit_offset + 7) / 8; + const size_t full_bytes = (num_bits + in->bit_offset) / 8; + if (bytes > in->len) { + INP_SIZE(); + } + + const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset); + + in->bit_offset = (num_bits + in->bit_offset) % 8; + in->ptr += full_bytes; + in->len -= full_bytes; + + return result; +} + +/// If a non-zero number of bits have been read from the current byte, advance +/// the offset to the next byte +static inline void IO_rewind_bits(istream_t *const in, int num_bits) { + if (num_bits < 0) { + ERROR("Attempting to rewind stream by a negative number of bits"); + } + + // move the offset back by `num_bits` bits + const int new_offset = in->bit_offset - num_bits; + // determine the number of whole bytes we have to rewind, rounding up to an + // integer number (e.g. if `new_offset == -5`, `bytes == 1`) + const i64 bytes = -(new_offset - 7) / 8; + + in->ptr -= bytes; + in->len += bytes; + // make sure the resulting `bit_offset` is positive, as mod in C does not + // convert numbers from negative to positive (e.g. -22 % 8 == -6) + in->bit_offset = ((new_offset % 8) + 8) % 8; +} + +/// If the remaining bits in a byte will be unused, advance to the end of the +/// byte +static inline void IO_align_stream(istream_t *const in) { + if (in->bit_offset != 0) { + if (in->len == 0) { + INP_SIZE(); + } + in->ptr++; + in->len--; + in->bit_offset = 0; + } +} + +/// Write the given byte into the output stream +static inline void IO_write_byte(ostream_t *const out, u8 symb) { + if (out->len == 0) { + OUT_SIZE(); + } + + out->ptr[0] = symb; + out->ptr++; + out->len--; +} + +/// Returns the number of bytes left to be read in this stream. The stream must +/// be byte aligned. +static inline size_t IO_istream_len(const istream_t *const in) { + return in->len; +} + +/// Returns a pointer where `len` bytes can be read, and advances the internal +/// state. The stream must be byte aligned. +static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) { + if (len > in->len) { + INP_SIZE(); + } + if (in->bit_offset != 0) { + UNALIGNED(); + } + const u8 *const ptr = in->ptr; + in->ptr += len; + in->len -= len; + + return ptr; +} +/// Returns a pointer to write `len` bytes to, and advances the internal state +static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) { + if (len > out->len) { + OUT_SIZE(); + } + u8 *const ptr = out->ptr; + out->ptr += len; + out->len -= len; + + return ptr; +} + +/// Advance the inner state by `len` bytes +static inline void IO_advance_input(istream_t *const in, size_t len) { + if (len > in->len) { + INP_SIZE(); + } + if (in->bit_offset != 0) { + UNALIGNED(); + } + + in->ptr += len; + in->len -= len; +} + +/// Returns an `ostream_t` constructed from the given pointer and length +static inline ostream_t IO_make_ostream(u8 *out, size_t len) { + return (ostream_t) { out, len }; +} + +/// Returns an `istream_t` constructed from the given pointer and length +static inline istream_t IO_make_istream(const u8 *in, size_t len) { + return (istream_t) { in, len, 0 }; +} + +/// Returns an `istream_t` with the same base as `in`, and length `len` +/// Then, advance `in` to account for the consumed bytes +/// `in` must be byte aligned +static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) { + // Consume `len` bytes of the parent stream + const u8 *const ptr = IO_get_read_ptr(in, len); + + // Make a substream using the pointer to those `len` bytes + return IO_make_istream(ptr, len); +} +/******* END IO STREAM OPERATIONS *********************************************/ + +/******* BITSTREAM OPERATIONS *************************************************/ +/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits +static inline u64 read_bits_LE(const u8 *src, const int num_bits, + const size_t offset) { + if (num_bits > 64) { + ERROR("Attempt to read an invalid number of bits"); + } + + // Skip over bytes that aren't in range + src += offset / 8; + size_t bit_offset = offset % 8; + u64 res = 0; + + int shift = 0; + int left = num_bits; + while (left > 0) { + u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1); + // Read the next byte, shift it to account for the offset, and then mask + // out the top part if we don't need all the bits + res += (((u64)*src++ >> bit_offset) & mask) << shift; + shift += 8 - bit_offset; + left -= 8 - bit_offset; + bit_offset = 0; + } + + return res; +} + +/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so +/// it updates `offset` to `offset - bits`, and then reads `bits` bits from +/// `src + offset`. If the offset becomes negative, the extra bits at the +/// bottom are filled in with `0` bits instead of reading from before `src`. +static inline u64 STREAM_read_bits(const u8 *const src, const int bits, + i64 *const offset) { + *offset = *offset - bits; + size_t actual_off = *offset; + size_t actual_bits = bits; + // Don't actually read bits from before the start of src, so if `*offset < + // 0` fix actual_off and actual_bits to reflect the quantity to read + if (*offset < 0) { + actual_bits += *offset; + actual_off = 0; + } + u64 res = read_bits_LE(src, actual_bits, actual_off); + + if (*offset < 0) { + // Fill in the bottom "overflowed" bits with 0's + res = -*offset >= 64 ? 0 : (res << -*offset); + } + return res; +} +/******* END BITSTREAM OPERATIONS *********************************************/ + +/******* BIT COUNTING OPERATIONS **********************************************/ +/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to +/// `num`, or `-1` if `num == 0`. +static inline int highest_set_bit(const u64 num) { + for (int i = 63; i >= 0; i--) { + if (((u64)1 << i) <= num) { + return i; + } + } + return -1; +} +/******* END BIT COUNTING OPERATIONS ******************************************/ + +/******* HUFFMAN PRIMITIVES ***************************************************/ +static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + // Look up the symbol and number of bits to read + const u8 symb = dtable->symbols[*state]; + const u8 bits = dtable->num_bits[*state]; + const u16 rest = STREAM_read_bits(src, bits, offset); + // Shift `bits` bits out of the state, keeping the low order bits that + // weren't necessary to determine this symbol. Then add in the new bits + // read from the stream. + *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1); + + return symb; +} + +static inline void HUF_init_state(const HUF_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + // Read in a full `dtable->max_bits` bits to initialize the state + const u8 bits = dtable->max_bits; + *state = STREAM_read_bits(src, bits, offset); +} + +static size_t HUF_decompress_1stream(const HUF_dtable *const dtable, + ostream_t *const out, + istream_t *const in) { + const size_t len = IO_istream_len(in); + if (len == 0) { + INP_SIZE(); + } + const u8 *const src = IO_get_read_ptr(in, len); + + // "Each bitstream must be read backward, that is starting from the end down + // to the beginning. Therefore it's necessary to know the size of each + // bitstream. + // + // It's also necessary to know exactly which bit is the latest. This is + // detected by a final bit flag : the highest bit of latest byte is a + // final-bit-flag. Consequently, a last byte of 0 is not possible. And the + // final-bit-flag itself is not part of the useful bitstream. Hence, the + // last byte contains between 0 and 7 useful bits." + const int padding = 8 - highest_set_bit(src[len - 1]); + + // Offset starts at the end because HUF streams are read backwards + i64 bit_offset = len * 8 - padding; + u16 state; + + HUF_init_state(dtable, &state, src, &bit_offset); + + size_t symbols_written = 0; + while (bit_offset > -dtable->max_bits) { + // Iterate over the stream, decoding one symbol at a time + IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset)); + symbols_written++; + } + // "The process continues up to reading the required number of symbols per + // stream. If a bitstream is not entirely and exactly consumed, hence + // reaching exactly its beginning position with all bits consumed, the + // decoding process is considered faulty." + + // When all symbols have been decoded, the final state value shouldn't have + // any data from the stream, so it should have "read" dtable->max_bits from + // before the start of `src` + // Therefore `offset`, the edge to start reading new bits at, should be + // dtable->max_bits before the start of the stream + if (bit_offset != -dtable->max_bits) { + CORRUPTION(); + } + + return symbols_written; +} + +static size_t HUF_decompress_4stream(const HUF_dtable *const dtable, + ostream_t *const out, istream_t *const in) { + // "Compressed size is provided explicitly : in the 4-streams variant, + // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each + // value represents the compressed size of one stream, in order. The last + // stream size is deducted from total compressed size and from previously + // decoded stream sizes" + const size_t csize1 = IO_read_bits(in, 16); + const size_t csize2 = IO_read_bits(in, 16); + const size_t csize3 = IO_read_bits(in, 16); + + istream_t in1 = IO_make_sub_istream(in, csize1); + istream_t in2 = IO_make_sub_istream(in, csize2); + istream_t in3 = IO_make_sub_istream(in, csize3); + istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in)); + + size_t total_output = 0; + // Decode each stream independently for simplicity + // If we wanted to we could decode all 4 at the same time for speed, + // utilizing more execution units + total_output += HUF_decompress_1stream(dtable, out, &in1); + total_output += HUF_decompress_1stream(dtable, out, &in2); + total_output += HUF_decompress_1stream(dtable, out, &in3); + total_output += HUF_decompress_1stream(dtable, out, &in4); + + return total_output; +} + +/// Initializes a Huffman table using canonical Huffman codes +/// For more explanation on canonical Huffman codes see +/// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html +/// Codes within a level are allocated in symbol order (i.e. smaller symbols get +/// earlier codes) +static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits, + const int num_symbs) { + memset(table, 0, sizeof(HUF_dtable)); + if (num_symbs > HUF_MAX_SYMBS) { + ERROR("Too many symbols for Huffman"); + } + + u8 max_bits = 0; + u16 rank_count[HUF_MAX_BITS + 1]; + memset(rank_count, 0, sizeof(rank_count)); + + // Count the number of symbols for each number of bits, and determine the + // depth of the tree + for (int i = 0; i < num_symbs; i++) { + if (bits[i] > HUF_MAX_BITS) { + ERROR("Huffman table depth too large"); + } + max_bits = MAX(max_bits, bits[i]); + rank_count[bits[i]]++; + } + + const size_t table_size = 1 << max_bits; + table->max_bits = max_bits; + table->symbols = malloc(table_size); + table->num_bits = malloc(table_size); + + if (!table->symbols || !table->num_bits) { + free(table->symbols); + free(table->num_bits); + BAD_ALLOC(); + } + + // "Symbols are sorted by Weight. Within same Weight, symbols keep natural + // order. Symbols with a Weight of zero are removed. Then, starting from + // lowest weight, prefix codes are distributed in order." + + u32 rank_idx[HUF_MAX_BITS + 1]; + // Initialize the starting codes for each rank (number of bits) + rank_idx[max_bits] = 0; + for (int i = max_bits; i >= 1; i--) { + rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i)); + // The entire range takes the same number of bits so we can memset it + memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]); + } + + if (rank_idx[0] != table_size) { + CORRUPTION(); + } + + // Allocate codes and fill in the table + for (int i = 0; i < num_symbs; i++) { + if (bits[i] != 0) { + // Allocate a code for this symbol and set its range in the table + const u16 code = rank_idx[bits[i]]; + // Since the code doesn't care about the bottom `max_bits - bits[i]` + // bits of state, it gets a range that spans all possible values of + // the lower bits + const u16 len = 1 << (max_bits - bits[i]); + memset(&table->symbols[code], i, len); + rank_idx[bits[i]] += len; + } + } +} + +static void HUF_init_dtable_usingweights(HUF_dtable *const table, + const u8 *const weights, + const int num_symbs) { + // +1 because the last weight is not transmitted in the header + if (num_symbs + 1 > HUF_MAX_SYMBS) { + ERROR("Too many symbols for Huffman"); + } + + u8 bits[HUF_MAX_SYMBS]; + + u64 weight_sum = 0; + for (int i = 0; i < num_symbs; i++) { + // Weights are in the same range as bit count + if (weights[i] > HUF_MAX_BITS) { + CORRUPTION(); + } + weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0; + } + + // Find the first power of 2 larger than the sum + const int max_bits = highest_set_bit(weight_sum) + 1; + const u64 left_over = ((u64)1 << max_bits) - weight_sum; + // If the left over isn't a power of 2, the weights are invalid + if (left_over & (left_over - 1)) { + CORRUPTION(); + } + + // left_over is used to find the last weight as it's not transmitted + // by inverting 2^(weight - 1) we can determine the value of last_weight + const int last_weight = highest_set_bit(left_over) + 1; + + for (int i = 0; i < num_symbs; i++) { + // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0" + bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0; + } + bits[num_symbs] = + max_bits + 1 - last_weight; // Last weight is always non-zero + + HUF_init_dtable(table, bits, num_symbs + 1); +} + +static void HUF_free_dtable(HUF_dtable *const dtable) { + free(dtable->symbols); + free(dtable->num_bits); + memset(dtable, 0, sizeof(HUF_dtable)); +} + +static void HUF_copy_dtable(HUF_dtable *const dst, + const HUF_dtable *const src) { + if (src->max_bits == 0) { + memset(dst, 0, sizeof(HUF_dtable)); + return; + } + + const size_t size = (size_t)1 << src->max_bits; + dst->max_bits = src->max_bits; + + dst->symbols = malloc(size); + dst->num_bits = malloc(size); + if (!dst->symbols || !dst->num_bits) { + BAD_ALLOC(); + } + + memcpy(dst->symbols, src->symbols, size); + memcpy(dst->num_bits, src->num_bits, size); +} +/******* END HUFFMAN PRIMITIVES ***********************************************/ + +/******* FSE PRIMITIVES *******************************************************/ +/// For more description of FSE see +/// https://github.com/Cyan4973/FiniteStateEntropy/ + +/// Allow a symbol to be decoded without updating state +static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable, + const u16 state) { + return dtable->symbols[state]; +} + +/// Consumes bits from the input and uses the current state to determine the +/// next state +static inline void FSE_update_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + const u8 bits = dtable->num_bits[*state]; + const u16 rest = STREAM_read_bits(src, bits, offset); + *state = dtable->new_state_base[*state] + rest; +} + +/// Decodes a single FSE symbol and updates the offset +static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + const u8 symb = FSE_peek_symbol(dtable, *state); + FSE_update_state(dtable, state, src, offset); + return symb; +} + +static inline void FSE_init_state(const FSE_dtable *const dtable, + u16 *const state, const u8 *const src, + i64 *const offset) { + // Read in a full `accuracy_log` bits to initialize the state + const u8 bits = dtable->accuracy_log; + *state = STREAM_read_bits(src, bits, offset); +} + +static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable, + ostream_t *const out, + istream_t *const in) { + const size_t len = IO_istream_len(in); + if (len == 0) { + INP_SIZE(); + } + const u8 *const src = IO_get_read_ptr(in, len); + + // "Each bitstream must be read backward, that is starting from the end down + // to the beginning. Therefore it's necessary to know the size of each + // bitstream. + // + // It's also necessary to know exactly which bit is the latest. This is + // detected by a final bit flag : the highest bit of latest byte is a + // final-bit-flag. Consequently, a last byte of 0 is not possible. And the + // final-bit-flag itself is not part of the useful bitstream. Hence, the + // last byte contains between 0 and 7 useful bits." + const int padding = 8 - highest_set_bit(src[len - 1]); + i64 offset = len * 8 - padding; + + u16 state1, state2; + // "The first state (State1) encodes the even indexed symbols, and the + // second (State2) encodes the odd indexes. State1 is initialized first, and + // then State2, and they take turns decoding a single symbol and updating + // their state." + FSE_init_state(dtable, &state1, src, &offset); + FSE_init_state(dtable, &state2, src, &offset); + + // Decode until we overflow the stream + // Since we decode in reverse order, overflowing the stream is offset going + // negative + size_t symbols_written = 0; + while (1) { + // "The number of symbols to decode is determined by tracking bitStream + // overflow condition: If updating state after decoding a symbol would + // require more bits than remain in the stream, it is assumed the extra + // bits are 0. Then, the symbols for each of the final states are + // decoded and the process is complete." + IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset)); + symbols_written++; + if (offset < 0) { + // There's still a symbol to decode in state2 + IO_write_byte(out, FSE_peek_symbol(dtable, state2)); + symbols_written++; + break; + } + + IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset)); + symbols_written++; + if (offset < 0) { + // There's still a symbol to decode in state1 + IO_write_byte(out, FSE_peek_symbol(dtable, state1)); + symbols_written++; + break; + } + } + + return symbols_written; +} + +static void FSE_init_dtable(FSE_dtable *const dtable, + const i16 *const norm_freqs, const int num_symbs, + const int accuracy_log) { + if (accuracy_log > FSE_MAX_ACCURACY_LOG) { + ERROR("FSE accuracy too large"); + } + if (num_symbs > FSE_MAX_SYMBS) { + ERROR("Too many symbols for FSE"); + } + + dtable->accuracy_log = accuracy_log; + + const size_t size = (size_t)1 << accuracy_log; + dtable->symbols = malloc(size * sizeof(u8)); + dtable->num_bits = malloc(size * sizeof(u8)); + dtable->new_state_base = malloc(size * sizeof(u16)); + + if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { + BAD_ALLOC(); + } + + // Used to determine how many bits need to be read for each state, + // and where the destination range should start + // Needs to be u16 because max value is 2 * max number of symbols, + // which can be larger than a byte can store + u16 state_desc[FSE_MAX_SYMBS]; + + // "Symbols are scanned in their natural order for "less than 1" + // probabilities. Symbols with this probability are being attributed a + // single cell, starting from the end of the table. These symbols define a + // full state reset, reading Accuracy_Log bits." + int high_threshold = size; + for (int s = 0; s < num_symbs; s++) { + // Scan for low probability symbols to put at the top + if (norm_freqs[s] == -1) { + dtable->symbols[--high_threshold] = s; + state_desc[s] = 1; + } + } + + // "All remaining symbols are sorted in their natural order. Starting from + // symbol 0 and table position 0, each symbol gets attributed as many cells + // as its probability. Cell allocation is spreaded, not linear." + // Place the rest in the table + const u16 step = (size >> 1) + (size >> 3) + 3; + const u16 mask = size - 1; + u16 pos = 0; + for (int s = 0; s < num_symbs; s++) { + if (norm_freqs[s] <= 0) { + continue; + } + + state_desc[s] = norm_freqs[s]; + + for (int i = 0; i < norm_freqs[s]; i++) { + // Give `norm_freqs[s]` states to symbol s + dtable->symbols[pos] = s; + // "A position is skipped if already occupied, typically by a "less + // than 1" probability symbol." + do { + pos = (pos + step) & mask; + } while (pos >= + high_threshold); + // Note: no other collision checking is necessary as `step` is + // coprime to `size`, so the cycle will visit each position exactly + // once + } + } + if (pos != 0) { + CORRUPTION(); + } + + // Now we can fill baseline and num bits + for (size_t i = 0; i < size; i++) { + u8 symbol = dtable->symbols[i]; + u16 next_state_desc = state_desc[symbol]++; + // Fills in the table appropriately, next_state_desc increases by symbol + // over time, decreasing number of bits + dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc)); + // Baseline increases until the bit threshold is passed, at which point + // it resets to 0 + dtable->new_state_base[i] = + ((u16)next_state_desc << dtable->num_bits[i]) - size; + } +} + +/// Decode an FSE header as defined in the Zstandard format specification and +/// use the decoded frequencies to initialize a decoding table. +static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in, + const int max_accuracy_log) { + // "An FSE distribution table describes the probabilities of all symbols + // from 0 to the last present one (included) on a normalized scale of 1 << + // Accuracy_Log . + // + // It's a bitstream which is read forward, in little-endian fashion. It's + // not necessary to know its exact size, since it will be discovered and + // reported by the decoding process. + if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) { + ERROR("FSE accuracy too large"); + } + + // The bitstream starts by reporting on which scale it operates. + // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal + // and match lengths is 9, and for offsets is 8. Higher values are + // considered errors." + const int accuracy_log = 5 + IO_read_bits(in, 4); + if (accuracy_log > max_accuracy_log) { + ERROR("FSE accuracy too large"); + } + + // "Then follows each symbol value, from 0 to last present one. The number + // of bits used by each field is variable. It depends on : + // + // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8, + // and presuming 100 probabilities points have already been distributed, the + // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive). + // Therefore, it must read log2sup(156) == 8 bits. + // + // Value decoded : small values use 1 less bit : example : Presuming values + // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining + // in an 8-bits field. They are used this way : first 99 values (hence from + // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. " + + i32 remaining = 1 << accuracy_log; + i16 frequencies[FSE_MAX_SYMBS]; + + int symb = 0; + while (remaining > 0 && symb < FSE_MAX_SYMBS) { + // Log of the number of possible values we could read + int bits = highest_set_bit(remaining + 1) + 1; + + u16 val = IO_read_bits(in, bits); + + // Try to mask out the lower bits to see if it qualifies for the "small + // value" threshold + const u16 lower_mask = ((u16)1 << (bits - 1)) - 1; + const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1); + + if ((val & lower_mask) < threshold) { + IO_rewind_bits(in, 1); + val = val & lower_mask; + } else if (val > lower_mask) { + val = val - threshold; + } + + // "Probability is obtained from Value decoded by following formula : + // Proba = value - 1" + const i16 proba = (i16)val - 1; + + // "It means value 0 becomes negative probability -1. -1 is a special + // probability, which means "less than 1". Its effect on distribution + // table is described in next paragraph. For the purpose of calculating + // cumulated distribution, it counts as one." + remaining -= proba < 0 ? -proba : proba; + + frequencies[symb] = proba; + symb++; + + // "When a symbol has a probability of zero, it is followed by a 2-bits + // repeat flag. This repeat flag tells how many probabilities of zeroes + // follow the current one. It provides a number ranging from 0 to 3. If + // it is a 3, another 2-bits repeat flag follows, and so on." + if (proba == 0) { + // Read the next two bits to see how many more 0s + int repeat = IO_read_bits(in, 2); + + while (1) { + for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) { + frequencies[symb++] = 0; + } + if (repeat == 3) { + repeat = IO_read_bits(in, 2); + } else { + break; + } + } + } + } + IO_align_stream(in); + + // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding + // is complete. If the last symbol makes cumulated total go above 1 << + // Accuracy_Log, distribution is considered corrupted." + if (remaining != 0 || symb >= FSE_MAX_SYMBS) { + CORRUPTION(); + } + + // Initialize the decoding table using the determined weights + FSE_init_dtable(dtable, frequencies, symb, accuracy_log); +} + +static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) { + dtable->symbols = malloc(sizeof(u8)); + dtable->num_bits = malloc(sizeof(u8)); + dtable->new_state_base = malloc(sizeof(u16)); + + if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) { + BAD_ALLOC(); + } + + // This setup will always have a state of 0, always return symbol `symb`, + // and never consume any bits + dtable->symbols[0] = symb; + dtable->num_bits[0] = 0; + dtable->new_state_base[0] = 0; + dtable->accuracy_log = 0; +} + +static void FSE_free_dtable(FSE_dtable *const dtable) { + free(dtable->symbols); + free(dtable->num_bits); + free(dtable->new_state_base); + memset(dtable, 0, sizeof(FSE_dtable)); +} + +static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) { + if (src->accuracy_log == 0) { + memset(dst, 0, sizeof(FSE_dtable)); + return; + } + + size_t size = (size_t)1 << src->accuracy_log; + dst->accuracy_log = src->accuracy_log; + + dst->symbols = malloc(size); + dst->num_bits = malloc(size); + dst->new_state_base = malloc(size * sizeof(u16)); + if (!dst->symbols || !dst->num_bits || !dst->new_state_base) { + BAD_ALLOC(); + } + + memcpy(dst->symbols, src->symbols, size); + memcpy(dst->num_bits, src->num_bits, size); + memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16)); +} +/******* END FSE PRIMITIVES ***************************************************/ |