summaryrefslogtreecommitdiffstats
path: root/src/zstd/doc/educational_decoder
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-27 18:24:20 +0000
commit483eb2f56657e8e7f419ab1a4fab8dce9ade8609 (patch)
treee5d88d25d870d5dedacb6bbdbe2a966086a0a5cf /src/zstd/doc/educational_decoder
parentInitial commit. (diff)
downloadceph-upstream.tar.xz
ceph-upstream.zip
Adding upstream version 14.2.21.upstream/14.2.21upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--src/zstd/doc/educational_decoder/Makefile34
-rw-r--r--src/zstd/doc/educational_decoder/README.md29
-rw-r--r--src/zstd/doc/educational_decoder/harness.c125
-rw-r--r--src/zstd/doc/educational_decoder/zstd_decompress.c2303
-rw-r--r--src/zstd/doc/educational_decoder/zstd_decompress.h58
5 files changed, 2549 insertions, 0 deletions
diff --git a/src/zstd/doc/educational_decoder/Makefile b/src/zstd/doc/educational_decoder/Makefile
new file mode 100644
index 00000000..ace1294f
--- /dev/null
+++ b/src/zstd/doc/educational_decoder/Makefile
@@ -0,0 +1,34 @@
+HARNESS_FILES=*.c
+
+MULTITHREAD_LDFLAGS = -pthread
+DEBUGFLAGS= -g -DZSTD_DEBUG=1
+CPPFLAGS += -I$(ZSTDDIR) -I$(ZSTDDIR)/common -I$(ZSTDDIR)/compress \
+ -I$(ZSTDDIR)/dictBuilder -I$(ZSTDDIR)/deprecated -I$(PRGDIR)
+CFLAGS ?= -O3
+CFLAGS += -Wall -Wextra -Wcast-qual -Wcast-align -Wshadow \
+ -Wstrict-aliasing=1 -Wswitch-enum -Wdeclaration-after-statement \
+ -Wstrict-prototypes -Wundef -Wformat-security \
+ -Wvla -Wformat=2 -Winit-self -Wfloat-equal -Wwrite-strings \
+ -Wredundant-decls
+CFLAGS += $(DEBUGFLAGS)
+CFLAGS += $(MOREFLAGS)
+FLAGS = $(CPPFLAGS) $(CFLAGS) $(LDFLAGS) $(MULTITHREAD_LDFLAGS)
+
+harness: $(HARNESS_FILES)
+ $(CC) $(FLAGS) $^ -o $@
+
+clean:
+ @$(RM) -f harness
+ @$(RM) -rf harness.dSYM
+
+test: harness
+ @zstd README.md -o tmp.zst
+ @./harness tmp.zst tmp
+ @diff -s tmp README.md
+ @$(RM) -f tmp*
+ @zstd --train harness.c zstd_decompress.c zstd_decompress.h README.md
+ @zstd -D dictionary README.md -o tmp.zst
+ @./harness tmp.zst tmp dictionary
+ @diff -s tmp README.md
+ @$(RM) -f tmp* dictionary
+ @make clean
diff --git a/src/zstd/doc/educational_decoder/README.md b/src/zstd/doc/educational_decoder/README.md
new file mode 100644
index 00000000..e3b9bf58
--- /dev/null
+++ b/src/zstd/doc/educational_decoder/README.md
@@ -0,0 +1,29 @@
+Educational Decoder
+===================
+
+`zstd_decompress.c` is a self-contained implementation in C99 of a decoder,
+according to the [Zstandard format specification].
+While it does not implement as many features as the reference decoder,
+such as the streaming API or content checksums, it is written to be easy to
+follow and understand, to help understand how the Zstandard format works.
+It's laid out to match the [format specification],
+so it can be used to understand how complex segments could be implemented.
+It also contains implementations of Huffman and FSE table decoding.
+
+[Zstandard format specification]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
+[format specification]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
+
+`harness.c` provides a simple test harness around the decoder:
+
+ harness <input-file> <output-file> [dictionary]
+
+As an additional resource to be used with this decoder,
+see the `decodecorpus` tool in the [tests] directory.
+It generates valid Zstandard frames that can be used to verify
+a Zstandard decoder implementation.
+Note that to use the tool to verify this decoder implementation,
+the --content-size flag should be set,
+as this decoder does not handle streaming decoding,
+and so it must know the decompressed size in advance.
+
+[tests]: https://github.com/facebook/zstd/blob/dev/tests/
diff --git a/src/zstd/doc/educational_decoder/harness.c b/src/zstd/doc/educational_decoder/harness.c
new file mode 100644
index 00000000..47882b16
--- /dev/null
+++ b/src/zstd/doc/educational_decoder/harness.c
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "zstd_decompress.h"
+
+typedef unsigned char u8;
+
+// If the data doesn't have decompressed size with it, fallback on assuming the
+// compression ratio is at most 16
+#define MAX_COMPRESSION_RATIO (16)
+
+// Protect against allocating too much memory for output
+#define MAX_OUTPUT_SIZE ((size_t)1024 * 1024 * 1024)
+
+u8 *input;
+u8 *output;
+u8 *dict;
+
+size_t read_file(const char *path, u8 **ptr) {
+ FILE *f = fopen(path, "rb");
+ if (!f) {
+ fprintf(stderr, "failed to open file %s\n", path);
+ exit(1);
+ }
+
+ fseek(f, 0L, SEEK_END);
+ size_t size = ftell(f);
+ rewind(f);
+
+ *ptr = malloc(size);
+ if (!ptr) {
+ fprintf(stderr, "failed to allocate memory to hold %s\n", path);
+ exit(1);
+ }
+
+ size_t pos = 0;
+ while (!feof(f)) {
+ size_t read = fread(&(*ptr)[pos], 1, size, f);
+ if (ferror(f)) {
+ fprintf(stderr, "error while reading file %s\n", path);
+ exit(1);
+ }
+ pos += read;
+ }
+
+ fclose(f);
+
+ return pos;
+}
+
+void write_file(const char *path, const u8 *ptr, size_t size) {
+ FILE *f = fopen(path, "wb");
+
+ size_t written = 0;
+ while (written < size) {
+ written += fwrite(&ptr[written], 1, size, f);
+ if (ferror(f)) {
+ fprintf(stderr, "error while writing file %s\n", path);
+ exit(1);
+ }
+ }
+
+ fclose(f);
+}
+
+int main(int argc, char **argv) {
+ if (argc < 3) {
+ fprintf(stderr, "usage: %s <file.zst> <out_path> [dictionary]\n",
+ argv[0]);
+
+ return 1;
+ }
+
+ size_t input_size = read_file(argv[1], &input);
+ size_t dict_size = 0;
+ if (argc >= 4) {
+ dict_size = read_file(argv[3], &dict);
+ }
+
+ size_t decompressed_size = ZSTD_get_decompressed_size(input, input_size);
+ if (decompressed_size == (size_t)-1) {
+ decompressed_size = MAX_COMPRESSION_RATIO * input_size;
+ fprintf(stderr, "WARNING: Compressed data does not contain "
+ "decompressed size, going to assume the compression "
+ "ratio is at most %d (decompressed size of at most "
+ "%zu)\n",
+ MAX_COMPRESSION_RATIO, decompressed_size);
+ }
+ if (decompressed_size > MAX_OUTPUT_SIZE) {
+ fprintf(stderr,
+ "Required output size too large for this implementation\n");
+ return 1;
+ }
+ output = malloc(decompressed_size);
+ if (!output) {
+ fprintf(stderr, "failed to allocate memory\n");
+ return 1;
+ }
+
+ dictionary_t* const parsed_dict = create_dictionary();
+ if (dict) {
+ parse_dictionary(parsed_dict, dict, dict_size);
+ }
+ size_t decompressed =
+ ZSTD_decompress_with_dict(output, decompressed_size,
+ input, input_size, parsed_dict);
+
+ free_dictionary(parsed_dict);
+
+ write_file(argv[2], output, decompressed);
+
+ free(input);
+ free(output);
+ free(dict);
+ input = output = dict = NULL;
+}
diff --git a/src/zstd/doc/educational_decoder/zstd_decompress.c b/src/zstd/doc/educational_decoder/zstd_decompress.c
new file mode 100644
index 00000000..bea0e0ce
--- /dev/null
+++ b/src/zstd/doc/educational_decoder/zstd_decompress.c
@@ -0,0 +1,2303 @@
+/*
+ * Copyright (c) 2017-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+
+/// Zstandard educational decoder implementation
+/// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "zstd_decompress.h"
+
+/******* UTILITY MACROS AND TYPES *********************************************/
+// Max block size decompressed size is 128 KB and literal blocks can't be
+// larger than their block
+#define MAX_LITERALS_SIZE ((size_t)128 * 1024)
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+/// This decoder calls exit(1) when it encounters an error, however a production
+/// library should propagate error codes
+#define ERROR(s) \
+ do { \
+ fprintf(stderr, "Error: %s\n", s); \
+ exit(1); \
+ } while (0)
+#define INP_SIZE() \
+ ERROR("Input buffer smaller than it should be or input is " \
+ "corrupted")
+#define OUT_SIZE() ERROR("Output buffer too small for output")
+#define CORRUPTION() ERROR("Corruption detected while decompressing")
+#define BAD_ALLOC() ERROR("Memory allocation error")
+#define IMPOSSIBLE() ERROR("An impossibility has occurred")
+
+typedef uint8_t u8;
+typedef uint16_t u16;
+typedef uint32_t u32;
+typedef uint64_t u64;
+
+typedef int8_t i8;
+typedef int16_t i16;
+typedef int32_t i32;
+typedef int64_t i64;
+/******* END UTILITY MACROS AND TYPES *****************************************/
+
+/******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
+/// The implementations for these functions can be found at the bottom of this
+/// file. They implement low-level functionality needed for the higher level
+/// decompression functions.
+
+/*** IO STREAM OPERATIONS *************/
+
+/// ostream_t/istream_t are used to wrap the pointers/length data passed into
+/// ZSTD_decompress, so that all IO operations are safely bounds checked
+/// They are written/read forward, and reads are treated as little-endian
+/// They should be used opaquely to ensure safety
+typedef struct {
+ u8 *ptr;
+ size_t len;
+} ostream_t;
+
+typedef struct {
+ const u8 *ptr;
+ size_t len;
+
+ // Input often reads a few bits at a time, so maintain an internal offset
+ int bit_offset;
+} istream_t;
+
+/// The following two functions are the only ones that allow the istream to be
+/// non-byte aligned
+
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
+/// Backs-up the stream by `num` bits so they can be read again
+static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in);
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb);
+
+/// Returns the number of bytes left to be read in this stream. The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in);
+
+/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
+/// was skipped. The stream must be byte aligned.
+static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
+/// Advances the stream by `len` bytes, and returns a pointer to the chunk that
+/// was skipped so it can be written to.
+static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
+
+/// Advance the inner state by `len` bytes. The stream must be byte aligned.
+static inline void IO_advance_input(istream_t *const in, size_t len);
+
+/// Returns an `ostream_t` constructed from the given pointer and length.
+static inline ostream_t IO_make_ostream(u8 *out, size_t len);
+/// Returns an `istream_t` constructed from the given pointer and length.
+static inline istream_t IO_make_istream(const u8 *in, size_t len);
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`.
+/// Then, advance `in` to account for the consumed bytes.
+/// `in` must be byte aligned.
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
+/*** END IO STREAM OPERATIONS *********/
+
+/*** BITSTREAM OPERATIONS *************/
+/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
+/// and return them interpreted as a little-endian unsigned integer.
+static inline u64 read_bits_LE(const u8 *src, const int num_bits,
+ const size_t offset);
+
+/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
+/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
+/// `src + offset`. If the offset becomes negative, the extra bits at the
+/// bottom are filled in with `0` bits instead of reading from before `src`.
+static inline u64 STREAM_read_bits(const u8 *src, const int bits,
+ i64 *const offset);
+/*** END BITSTREAM OPERATIONS *********/
+
+/*** BIT COUNTING OPERATIONS **********/
+/// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
+static inline int highest_set_bit(const u64 num);
+/*** END BIT COUNTING OPERATIONS ******/
+
+/*** HUFFMAN PRIMITIVES ***************/
+// Table decode method uses exponential memory, so we need to limit depth
+#define HUF_MAX_BITS (16)
+
+// Limit the maximum number of symbols to 256 so we can store a symbol in a byte
+#define HUF_MAX_SYMBS (256)
+
+/// Structure containing all tables necessary for efficient Huffman decoding
+typedef struct {
+ u8 *symbols;
+ u8 *num_bits;
+ int max_bits;
+} HUF_dtable;
+
+/// Decode a single symbol and read in enough bits to refresh the state
+static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset);
+/// Read in a full state's worth of bits to initialize it
+static inline void HUF_init_state(const HUF_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset);
+
+/// Decompresses a single Huffman stream, returns the number of bytes decoded.
+/// `src_len` must be the exact length of the Huffman-coded block.
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in);
+/// Same as previous but decodes 4 streams, formatted as in the Zstandard
+/// specification.
+/// `src_len` must be the exact length of the Huffman-coded block.
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in);
+
+/// Initialize a Huffman decoding table using the table of bit counts provided
+static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
+ const int num_symbs);
+/// Initialize a Huffman decoding table using the table of weights provided
+/// Weights follow the definition provided in the Zstandard specification
+static void HUF_init_dtable_usingweights(HUF_dtable *const table,
+ const u8 *const weights,
+ const int num_symbs);
+
+/// Free the malloc'ed parts of a decoding table
+static void HUF_free_dtable(HUF_dtable *const dtable);
+
+/// Deep copy a decoding table, so that it can be used and free'd without
+/// impacting the source table.
+static void HUF_copy_dtable(HUF_dtable *const dst, const HUF_dtable *const src);
+/*** END HUFFMAN PRIMITIVES ***********/
+
+/*** FSE PRIMITIVES *******************/
+/// For more description of FSE see
+/// https://github.com/Cyan4973/FiniteStateEntropy/
+
+// FSE table decoding uses exponential memory, so limit the maximum accuracy
+#define FSE_MAX_ACCURACY_LOG (15)
+// Limit the maximum number of symbols so they can be stored in a single byte
+#define FSE_MAX_SYMBS (256)
+
+/// The tables needed to decode FSE encoded streams
+typedef struct {
+ u8 *symbols;
+ u8 *num_bits;
+ u16 *new_state_base;
+ int accuracy_log;
+} FSE_dtable;
+
+/// Return the symbol for the current state
+static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
+ const u16 state);
+/// Read the number of bits necessary to update state, update, and shift offset
+/// back to reflect the bits read
+static inline void FSE_update_state(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset);
+
+/// Combine peek and update: decode a symbol and update the state
+static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset);
+
+/// Read bits from the stream to initialize the state and shift offset back
+static inline void FSE_init_state(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset);
+
+/// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
+/// using an FSE decoding table. `src_len` must be the exact length of the
+/// block.
+static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
+ ostream_t *const out,
+ istream_t *const in);
+
+/// Initialize a decoding table using normalized frequencies.
+static void FSE_init_dtable(FSE_dtable *const dtable,
+ const i16 *const norm_freqs, const int num_symbs,
+ const int accuracy_log);
+
+/// Decode an FSE header as defined in the Zstandard format specification and
+/// use the decoded frequencies to initialize a decoding table.
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
+ const int max_accuracy_log);
+
+/// Initialize an FSE table that will always return the same symbol and consume
+/// 0 bits per symbol, to be used for RLE mode in sequence commands
+static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
+
+/// Free the malloc'ed parts of a decoding table
+static void FSE_free_dtable(FSE_dtable *const dtable);
+
+/// Deep copy a decoding table, so that it can be used and free'd without
+/// impacting the source table.
+static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src);
+/*** END FSE PRIMITIVES ***************/
+
+/******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
+
+/******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
+
+/// A small structure that can be reused in various places that need to access
+/// frame header information
+typedef struct {
+ // The size of window that we need to be able to contiguously store for
+ // references
+ size_t window_size;
+ // The total output size of this compressed frame
+ size_t frame_content_size;
+
+ // The dictionary id if this frame uses one
+ u32 dictionary_id;
+
+ // Whether or not the content of this frame has a checksum
+ int content_checksum_flag;
+ // Whether or not the output for this frame is in a single segment
+ int single_segment_flag;
+} frame_header_t;
+
+/// The context needed to decode blocks in a frame
+typedef struct {
+ frame_header_t header;
+
+ // The total amount of data available for backreferences, to determine if an
+ // offset too large to be correct
+ size_t current_total_output;
+
+ const u8 *dict_content;
+ size_t dict_content_len;
+
+ // Entropy encoding tables so they can be repeated by future blocks instead
+ // of retransmitting
+ HUF_dtable literals_dtable;
+ FSE_dtable ll_dtable;
+ FSE_dtable ml_dtable;
+ FSE_dtable of_dtable;
+
+ // The last 3 offsets for the special "repeat offsets".
+ u64 previous_offsets[3];
+} frame_context_t;
+
+/// The decoded contents of a dictionary so that it doesn't have to be repeated
+/// for each frame that uses it
+struct dictionary_s {
+ // Entropy tables
+ HUF_dtable literals_dtable;
+ FSE_dtable ll_dtable;
+ FSE_dtable ml_dtable;
+ FSE_dtable of_dtable;
+
+ // Raw content for backreferences
+ u8 *content;
+ size_t content_size;
+
+ // Offset history to prepopulate the frame's history
+ u64 previous_offsets[3];
+
+ u32 dictionary_id;
+};
+
+/// A tuple containing the parts necessary to decode and execute a ZSTD sequence
+/// command
+typedef struct {
+ u32 literal_length;
+ u32 match_length;
+ u32 offset;
+} sequence_command_t;
+
+/// The decoder works top-down, starting at the high level like Zstd frames, and
+/// working down to lower more technical levels such as blocks, literals, and
+/// sequences. The high-level functions roughly follow the outline of the
+/// format specification:
+/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
+
+/// Before the implementation of each high-level function declared here, the
+/// prototypes for their helper functions are defined and explained
+
+/// Decode a single Zstd frame, or error if the input is not a valid frame.
+/// Accepts a dict argument, which may be NULL indicating no dictionary.
+/// See
+/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
+static void decode_frame(ostream_t *const out, istream_t *const in,
+ const dictionary_t *const dict);
+
+// Decode data in a compressed block
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in);
+
+// Decode the literals section of a block
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+ u8 **const literals);
+
+// Decode the sequences part of a block
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
+ sequence_command_t **const sequences);
+
+// Execute the decoded sequences on the literals block
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+ const u8 *const literals,
+ const size_t literals_len,
+ const sequence_command_t *const sequences,
+ const size_t num_sequences);
+
+// Copies literals and returns the total literal length that was copied
+static u32 copy_literals(const size_t seq, istream_t *litstream,
+ ostream_t *const out);
+
+// Given an offset code from a sequence command (either an actual offset value
+// or an index for previous offset), computes the correct offset and udpates
+// the offset history
+static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
+
+// Given an offset, match length, and total output, as well as the frame
+// context for the dictionary, determines if the dictionary is used and
+// executes the copy operation
+static void execute_match_copy(frame_context_t *const ctx, size_t offset,
+ size_t match_length, size_t total_output,
+ ostream_t *const out);
+
+/******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
+
+size_t ZSTD_decompress(void *const dst, const size_t dst_len,
+ const void *const src, const size_t src_len) {
+ dictionary_t* uninit_dict = create_dictionary();
+ size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
+ src_len, uninit_dict);
+ free_dictionary(uninit_dict);
+ return decomp_size;
+}
+
+size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
+ const void *const src, const size_t src_len,
+ dictionary_t* parsed_dict) {
+
+ istream_t in = IO_make_istream(src, src_len);
+ ostream_t out = IO_make_ostream(dst, dst_len);
+
+ // "A content compressed by Zstandard is transformed into a Zstandard frame.
+ // Multiple frames can be appended into a single file or stream. A frame is
+ // totally independent, has a defined beginning and end, and a set of
+ // parameters which tells the decoder how to decompress it."
+
+ /* this decoder assumes decompression of a single frame */
+ decode_frame(&out, &in, parsed_dict);
+
+ return out.ptr - (u8 *)dst;
+}
+
+/******* FRAME DECODING ******************************************************/
+
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
+ const dictionary_t *const dict);
+static void init_frame_context(frame_context_t *const context,
+ istream_t *const in,
+ const dictionary_t *const dict);
+static void free_frame_context(frame_context_t *const context);
+static void parse_frame_header(frame_header_t *const header,
+ istream_t *const in);
+static void frame_context_apply_dict(frame_context_t *const ctx,
+ const dictionary_t *const dict);
+
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in);
+
+static void decode_frame(ostream_t *const out, istream_t *const in,
+ const dictionary_t *const dict) {
+ const u32 magic_number = IO_read_bits(in, 32);
+ // Zstandard frame
+ //
+ // "Magic_Number
+ //
+ // 4 Bytes, little-endian format. Value : 0xFD2FB528"
+ if (magic_number == 0xFD2FB528U) {
+ // ZSTD frame
+ decode_data_frame(out, in, dict);
+
+ return;
+ }
+
+ // not a real frame or a skippable frame
+ ERROR("Tried to decode non-ZSTD frame");
+}
+
+/// Decode a frame that contains compressed data. Not all frames do as there
+/// are skippable frames.
+/// See
+/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
+static void decode_data_frame(ostream_t *const out, istream_t *const in,
+ const dictionary_t *const dict) {
+ frame_context_t ctx;
+
+ // Initialize the context that needs to be carried from block to block
+ init_frame_context(&ctx, in, dict);
+
+ if (ctx.header.frame_content_size != 0 &&
+ ctx.header.frame_content_size > out->len) {
+ OUT_SIZE();
+ }
+
+ decompress_data(&ctx, out, in);
+
+ free_frame_context(&ctx);
+}
+
+/// Takes the information provided in the header and dictionary, and initializes
+/// the context for this frame
+static void init_frame_context(frame_context_t *const context,
+ istream_t *const in,
+ const dictionary_t *const dict) {
+ // Most fields in context are correct when initialized to 0
+ memset(context, 0, sizeof(frame_context_t));
+
+ // Parse data from the frame header
+ parse_frame_header(&context->header, in);
+
+ // Set up the offset history for the repeat offset commands
+ context->previous_offsets[0] = 1;
+ context->previous_offsets[1] = 4;
+ context->previous_offsets[2] = 8;
+
+ // Apply details from the dict if it exists
+ frame_context_apply_dict(context, dict);
+}
+
+static void free_frame_context(frame_context_t *const context) {
+ HUF_free_dtable(&context->literals_dtable);
+
+ FSE_free_dtable(&context->ll_dtable);
+ FSE_free_dtable(&context->ml_dtable);
+ FSE_free_dtable(&context->of_dtable);
+
+ memset(context, 0, sizeof(frame_context_t));
+}
+
+static void parse_frame_header(frame_header_t *const header,
+ istream_t *const in) {
+ // "The first header's byte is called the Frame_Header_Descriptor. It tells
+ // which other fields are present. Decoding this byte is enough to tell the
+ // size of Frame_Header.
+ //
+ // Bit number Field name
+ // 7-6 Frame_Content_Size_flag
+ // 5 Single_Segment_flag
+ // 4 Unused_bit
+ // 3 Reserved_bit
+ // 2 Content_Checksum_flag
+ // 1-0 Dictionary_ID_flag"
+ const u8 descriptor = IO_read_bits(in, 8);
+
+ // decode frame header descriptor into flags
+ const u8 frame_content_size_flag = descriptor >> 6;
+ const u8 single_segment_flag = (descriptor >> 5) & 1;
+ const u8 reserved_bit = (descriptor >> 3) & 1;
+ const u8 content_checksum_flag = (descriptor >> 2) & 1;
+ const u8 dictionary_id_flag = descriptor & 3;
+
+ if (reserved_bit != 0) {
+ CORRUPTION();
+ }
+
+ header->single_segment_flag = single_segment_flag;
+ header->content_checksum_flag = content_checksum_flag;
+
+ // decode window size
+ if (!single_segment_flag) {
+ // "Provides guarantees on maximum back-reference distance that will be
+ // used within compressed data. This information is important for
+ // decoders to allocate enough memory.
+ //
+ // Bit numbers 7-3 2-0
+ // Field name Exponent Mantissa"
+ u8 window_descriptor = IO_read_bits(in, 8);
+ u8 exponent = window_descriptor >> 3;
+ u8 mantissa = window_descriptor & 7;
+
+ // Use the algorithm from the specification to compute window size
+ // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
+ size_t window_base = (size_t)1 << (10 + exponent);
+ size_t window_add = (window_base / 8) * mantissa;
+ header->window_size = window_base + window_add;
+ }
+
+ // decode dictionary id if it exists
+ if (dictionary_id_flag) {
+ // "This is a variable size field, which contains the ID of the
+ // dictionary required to properly decode the frame. Note that this
+ // field is optional. When it's not present, it's up to the caller to
+ // make sure it uses the correct dictionary. Format is little-endian."
+ const int bytes_array[] = {0, 1, 2, 4};
+ const int bytes = bytes_array[dictionary_id_flag];
+
+ header->dictionary_id = IO_read_bits(in, bytes * 8);
+ } else {
+ header->dictionary_id = 0;
+ }
+
+ // decode frame content size if it exists
+ if (single_segment_flag || frame_content_size_flag) {
+ // "This is the original (uncompressed) size. This information is
+ // optional. The Field_Size is provided according to value of
+ // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
+ // present), 1, 2, 4 or 8 bytes. Format is little-endian."
+ //
+ // if frame_content_size_flag == 0 but single_segment_flag is set, we
+ // still have a 1 byte field
+ const int bytes_array[] = {1, 2, 4, 8};
+ const int bytes = bytes_array[frame_content_size_flag];
+
+ header->frame_content_size = IO_read_bits(in, bytes * 8);
+ if (bytes == 2) {
+ // "When Field_Size is 2, the offset of 256 is added."
+ header->frame_content_size += 256;
+ }
+ } else {
+ header->frame_content_size = 0;
+ }
+
+ if (single_segment_flag) {
+ // "The Window_Descriptor byte is optional. It is absent when
+ // Single_Segment_flag is set. In this case, the maximum back-reference
+ // distance is the content size itself, which can be any value from 1 to
+ // 2^64-1 bytes (16 EB)."
+ header->window_size = header->frame_content_size;
+ }
+}
+
+/// A dictionary acts as initializing values for the frame context before
+/// decompression, so we implement it by applying it's predetermined
+/// tables and content to the context before beginning decompression
+static void frame_context_apply_dict(frame_context_t *const ctx,
+ const dictionary_t *const dict) {
+ // If the content pointer is NULL then it must be an empty dict
+ if (!dict || !dict->content)
+ return;
+
+ // If the requested dictionary_id is non-zero, the correct dictionary must
+ // be present
+ if (ctx->header.dictionary_id != 0 &&
+ ctx->header.dictionary_id != dict->dictionary_id) {
+ ERROR("Wrong dictionary provided");
+ }
+
+ // Copy the dict content to the context for references during sequence
+ // execution
+ ctx->dict_content = dict->content;
+ ctx->dict_content_len = dict->content_size;
+
+ // If it's a formatted dict copy the precomputed tables in so they can
+ // be used in the table repeat modes
+ if (dict->dictionary_id != 0) {
+ // Deep copy the entropy tables so they can be freed independently of
+ // the dictionary struct
+ HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
+ FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
+ FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
+ FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
+
+ // Copy the repeated offsets
+ memcpy(ctx->previous_offsets, dict->previous_offsets,
+ sizeof(ctx->previous_offsets));
+ }
+}
+
+/// Decompress the data from a frame block by block
+static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in) {
+ // "A frame encapsulates one or multiple blocks. Each block can be
+ // compressed or not, and has a guaranteed maximum content size, which
+ // depends on frame parameters. Unlike frames, each block depends on
+ // previous blocks for proper decoding. However, each block can be
+ // decompressed without waiting for its successor, allowing streaming
+ // operations."
+ int last_block = 0;
+ do {
+ // "Last_Block
+ //
+ // The lowest bit signals if this block is the last one. Frame ends
+ // right after this block.
+ //
+ // Block_Type and Block_Size
+ //
+ // The next 2 bits represent the Block_Type, while the remaining 21 bits
+ // represent the Block_Size. Format is little-endian."
+ last_block = IO_read_bits(in, 1);
+ const int block_type = IO_read_bits(in, 2);
+ const size_t block_len = IO_read_bits(in, 21);
+
+ switch (block_type) {
+ case 0: {
+ // "Raw_Block - this is an uncompressed block. Block_Size is the
+ // number of bytes to read and copy."
+ const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
+ u8 *const write_ptr = IO_get_write_ptr(out, block_len);
+
+ // Copy the raw data into the output
+ memcpy(write_ptr, read_ptr, block_len);
+
+ ctx->current_total_output += block_len;
+ break;
+ }
+ case 1: {
+ // "RLE_Block - this is a single byte, repeated N times. In which
+ // case, Block_Size is the size to regenerate, while the
+ // "compressed" block is just 1 byte (the byte to repeat)."
+ const u8 *const read_ptr = IO_get_read_ptr(in, 1);
+ u8 *const write_ptr = IO_get_write_ptr(out, block_len);
+
+ // Copy `block_len` copies of `read_ptr[0]` to the output
+ memset(write_ptr, read_ptr[0], block_len);
+
+ ctx->current_total_output += block_len;
+ break;
+ }
+ case 2: {
+ // "Compressed_Block - this is a Zstandard compressed block,
+ // detailed in another section of this specification. Block_Size is
+ // the compressed size.
+
+ // Create a sub-stream for the block
+ istream_t block_stream = IO_make_sub_istream(in, block_len);
+ decompress_block(ctx, out, &block_stream);
+ break;
+ }
+ case 3:
+ // "Reserved - this is not a block. This value cannot be used with
+ // current version of this specification."
+ CORRUPTION();
+ break;
+ default:
+ IMPOSSIBLE();
+ }
+ } while (!last_block);
+
+ if (ctx->header.content_checksum_flag) {
+ // This program does not support checking the checksum, so skip over it
+ // if it's present
+ IO_advance_input(in, 4);
+ }
+}
+/******* END FRAME DECODING ***************************************************/
+
+/******* BLOCK DECOMPRESSION **************************************************/
+static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
+ istream_t *const in) {
+ // "A compressed block consists of 2 sections :
+ //
+ // Literals_Section
+ // Sequences_Section"
+
+
+ // Part 1: decode the literals block
+ u8 *literals = NULL;
+ const size_t literals_size = decode_literals(ctx, in, &literals);
+
+ // Part 2: decode the sequences block
+ sequence_command_t *sequences = NULL;
+ const size_t num_sequences =
+ decode_sequences(ctx, in, &sequences);
+
+ // Part 3: combine literals and sequence commands to generate output
+ execute_sequences(ctx, out, literals, literals_size, sequences,
+ num_sequences);
+ free(literals);
+ free(sequences);
+}
+/******* END BLOCK DECOMPRESSION **********************************************/
+
+/******* LITERALS DECODING ****************************************************/
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+ const int block_type,
+ const int size_format);
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+ istream_t *const in,
+ u8 **const literals,
+ const int block_type,
+ const int size_format);
+static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+ int *const num_symbs);
+
+static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
+ u8 **const literals) {
+ // "Literals can be stored uncompressed or compressed using Huffman prefix
+ // codes. When compressed, an optional tree description can be present,
+ // followed by 1 or 4 streams."
+ //
+ // "Literals_Section_Header
+ //
+ // Header is in charge of describing how literals are packed. It's a
+ // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
+ // little-endian convention."
+ //
+ // "Literals_Block_Type
+ //
+ // This field uses 2 lowest bits of first byte, describing 4 different block
+ // types"
+ //
+ // size_format takes between 1 and 2 bits
+ int block_type = IO_read_bits(in, 2);
+ int size_format = IO_read_bits(in, 2);
+
+ if (block_type <= 1) {
+ // Raw or RLE literals block
+ return decode_literals_simple(in, literals, block_type,
+ size_format);
+ } else {
+ // Huffman compressed literals
+ return decode_literals_compressed(ctx, in, literals, block_type,
+ size_format);
+ }
+}
+
+/// Decodes literals blocks in raw or RLE form
+static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
+ const int block_type,
+ const int size_format) {
+ size_t size;
+ switch (size_format) {
+ // These cases are in the form ?0
+ // In this case, the ? bit is actually part of the size field
+ case 0:
+ case 2:
+ // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
+ IO_rewind_bits(in, 1);
+ size = IO_read_bits(in, 5);
+ break;
+ case 1:
+ // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
+ size = IO_read_bits(in, 12);
+ break;
+ case 3:
+ // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
+ size = IO_read_bits(in, 20);
+ break;
+ default:
+ // Size format is in range 0-3
+ IMPOSSIBLE();
+ }
+
+ if (size > MAX_LITERALS_SIZE) {
+ CORRUPTION();
+ }
+
+ *literals = malloc(size);
+ if (!*literals) {
+ BAD_ALLOC();
+ }
+
+ switch (block_type) {
+ case 0: {
+ // "Raw_Literals_Block - Literals are stored uncompressed."
+ const u8 *const read_ptr = IO_get_read_ptr(in, size);
+ memcpy(*literals, read_ptr, size);
+ break;
+ }
+ case 1: {
+ // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
+ const u8 *const read_ptr = IO_get_read_ptr(in, 1);
+ memset(*literals, read_ptr[0], size);
+ break;
+ }
+ default:
+ IMPOSSIBLE();
+ }
+
+ return size;
+}
+
+/// Decodes Huffman compressed literals
+static size_t decode_literals_compressed(frame_context_t *const ctx,
+ istream_t *const in,
+ u8 **const literals,
+ const int block_type,
+ const int size_format) {
+ size_t regenerated_size, compressed_size;
+ // Only size_format=0 has 1 stream, so default to 4
+ int num_streams = 4;
+ switch (size_format) {
+ case 0:
+ // "A single stream. Both Compressed_Size and Regenerated_Size use 10
+ // bits (0-1023)."
+ num_streams = 1;
+ // Fall through as it has the same size format
+ case 1:
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
+ // (0-1023)."
+ regenerated_size = IO_read_bits(in, 10);
+ compressed_size = IO_read_bits(in, 10);
+ break;
+ case 2:
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
+ // (0-16383)."
+ regenerated_size = IO_read_bits(in, 14);
+ compressed_size = IO_read_bits(in, 14);
+ break;
+ case 3:
+ // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
+ // (0-262143)."
+ regenerated_size = IO_read_bits(in, 18);
+ compressed_size = IO_read_bits(in, 18);
+ break;
+ default:
+ // Impossible
+ IMPOSSIBLE();
+ }
+ if (regenerated_size > MAX_LITERALS_SIZE ||
+ compressed_size >= regenerated_size) {
+ CORRUPTION();
+ }
+
+ *literals = malloc(regenerated_size);
+ if (!*literals) {
+ BAD_ALLOC();
+ }
+
+ ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
+ istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
+
+ if (block_type == 2) {
+ // Decode the provided Huffman table
+ // "This section is only present when Literals_Block_Type type is
+ // Compressed_Literals_Block (2)."
+
+ HUF_free_dtable(&ctx->literals_dtable);
+ decode_huf_table(&ctx->literals_dtable, &huf_stream);
+ } else {
+ // If the previous Huffman table is being repeated, ensure it exists
+ if (!ctx->literals_dtable.symbols) {
+ CORRUPTION();
+ }
+ }
+
+ size_t symbols_decoded;
+ if (num_streams == 1) {
+ symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
+ } else {
+ symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
+ }
+
+ if (symbols_decoded != regenerated_size) {
+ CORRUPTION();
+ }
+
+ return regenerated_size;
+}
+
+// Decode the Huffman table description
+static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
+ // "All literal values from zero (included) to last present one (excluded)
+ // are represented by Weight with values from 0 to Max_Number_of_Bits."
+
+ // "This is a single byte value (0-255), which describes how to decode the list of weights."
+ const u8 header = IO_read_bits(in, 8);
+
+ u8 weights[HUF_MAX_SYMBS];
+ memset(weights, 0, sizeof(weights));
+
+ int num_symbs;
+
+ if (header >= 128) {
+ // "This is a direct representation, where each Weight is written
+ // directly as a 4 bits field (0-15). The full representation occupies
+ // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
+ // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
+ // 127"
+ num_symbs = header - 127;
+ const size_t bytes = (num_symbs + 1) / 2;
+
+ const u8 *const weight_src = IO_get_read_ptr(in, bytes);
+
+ for (int i = 0; i < num_symbs; i++) {
+ // "They are encoded forward, 2
+ // weights to a byte with the first weight taking the top four bits
+ // and the second taking the bottom four (e.g. the following
+ // operations could be used to read the weights: Weight[0] =
+ // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
+ if (i % 2 == 0) {
+ weights[i] = weight_src[i / 2] >> 4;
+ } else {
+ weights[i] = weight_src[i / 2] & 0xf;
+ }
+ }
+ } else {
+ // The weights are FSE encoded, decode them before we can construct the
+ // table
+ istream_t fse_stream = IO_make_sub_istream(in, header);
+ ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
+ fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
+ }
+
+ // Construct the table using the decoded weights
+ HUF_init_dtable_usingweights(dtable, weights, num_symbs);
+}
+
+static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
+ int *const num_symbs) {
+ const int MAX_ACCURACY_LOG = 7;
+
+ FSE_dtable dtable;
+
+ // "An FSE bitstream starts by a header, describing probabilities
+ // distribution. It will create a Decoding Table. For a list of Huffman
+ // weights, maximum accuracy is 7 bits."
+ FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
+
+ // Decode the weights
+ *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
+
+ FSE_free_dtable(&dtable);
+}
+/******* END LITERALS DECODING ************************************************/
+
+/******* SEQUENCE DECODING ****************************************************/
+/// The combination of FSE states needed to decode sequences
+typedef struct {
+ FSE_dtable ll_table;
+ FSE_dtable of_table;
+ FSE_dtable ml_table;
+
+ u16 ll_state;
+ u16 of_state;
+ u16 ml_state;
+} sequence_states_t;
+
+/// Different modes to signal to decode_seq_tables what to do
+typedef enum {
+ seq_literal_length = 0,
+ seq_offset = 1,
+ seq_match_length = 2,
+} seq_part_t;
+
+typedef enum {
+ seq_predefined = 0,
+ seq_rle = 1,
+ seq_fse = 2,
+ seq_repeat = 3,
+} seq_mode_t;
+
+/// The predefined FSE distribution tables for `seq_predefined` mode
+static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
+ 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
+static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
+ 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
+static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
+ 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
+
+/// The sequence decoding baseline and number of additional bits to read/add
+/// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
+static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40,
+ 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65538};
+static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
+ 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
+static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
+ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+ 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+ 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83,
+ 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
+static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
+ 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
+/// Offset decoding is simpler so we just need a maximum code value
+static const u8 SEQ_MAX_CODES[3] = {35, -1, 52};
+
+static void decompress_sequences(frame_context_t *const ctx,
+ istream_t *const in,
+ sequence_command_t *const sequences,
+ const size_t num_sequences);
+static sequence_command_t decode_sequence(sequence_states_t *const state,
+ const u8 *const src,
+ i64 *const offset);
+static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
+ const seq_part_t type, const seq_mode_t mode);
+
+static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
+ sequence_command_t **const sequences) {
+ // "A compressed block is a succession of sequences . A sequence is a
+ // literal copy command, followed by a match copy command. A literal copy
+ // command specifies a length. It is the number of bytes to be copied (or
+ // extracted) from the literal section. A match copy command specifies an
+ // offset and a length. The offset gives the position to copy from, which
+ // can be within a previous block."
+
+ size_t num_sequences;
+
+ // "Number_of_Sequences
+ //
+ // This is a variable size field using between 1 and 3 bytes. Let's call its
+ // first byte byte0."
+ u8 header = IO_read_bits(in, 8);
+ if (header == 0) {
+ // "There are no sequences. The sequence section stops there.
+ // Regenerated content is defined entirely by literals section."
+ *sequences = NULL;
+ return 0;
+ } else if (header < 128) {
+ // "Number_of_Sequences = byte0 . Uses 1 byte."
+ num_sequences = header;
+ } else if (header < 255) {
+ // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
+ num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
+ } else {
+ // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
+ num_sequences = IO_read_bits(in, 16) + 0x7F00;
+ }
+
+ *sequences = malloc(num_sequences * sizeof(sequence_command_t));
+ if (!*sequences) {
+ BAD_ALLOC();
+ }
+
+ decompress_sequences(ctx, in, *sequences, num_sequences);
+ return num_sequences;
+}
+
+/// Decompress the FSE encoded sequence commands
+static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
+ sequence_command_t *const sequences,
+ const size_t num_sequences) {
+ // "The Sequences_Section regroup all symbols required to decode commands.
+ // There are 3 symbol types : literals lengths, offsets and match lengths.
+ // They are encoded together, interleaved, in a single bitstream."
+
+ // "Symbol compression modes
+ //
+ // This is a single byte, defining the compression mode of each symbol
+ // type."
+ //
+ // Bit number : Field name
+ // 7-6 : Literals_Lengths_Mode
+ // 5-4 : Offsets_Mode
+ // 3-2 : Match_Lengths_Mode
+ // 1-0 : Reserved
+ u8 compression_modes = IO_read_bits(in, 8);
+
+ if ((compression_modes & 3) != 0) {
+ // Reserved bits set
+ CORRUPTION();
+ }
+
+ // "Following the header, up to 3 distribution tables can be described. When
+ // present, they are in this order :
+ //
+ // Literals lengths
+ // Offsets
+ // Match Lengths"
+ // Update the tables we have stored in the context
+ decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
+ (compression_modes >> 6) & 3);
+
+ decode_seq_table(&ctx->of_dtable, in, seq_offset,
+ (compression_modes >> 4) & 3);
+
+ decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
+ (compression_modes >> 2) & 3);
+
+
+ sequence_states_t states;
+
+ // Initialize the decoding tables
+ {
+ states.ll_table = ctx->ll_dtable;
+ states.of_table = ctx->of_dtable;
+ states.ml_table = ctx->ml_dtable;
+ }
+
+ const size_t len = IO_istream_len(in);
+ const u8 *const src = IO_get_read_ptr(in, len);
+
+ // "After writing the last bit containing information, the compressor writes
+ // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
+ const int padding = 8 - highest_set_bit(src[len - 1]);
+ // The offset starts at the end because FSE streams are read backwards
+ i64 bit_offset = len * 8 - padding;
+
+ // "The bitstream starts with initial state values, each using the required
+ // number of bits in their respective accuracy, decoded previously from
+ // their normalized distribution.
+ //
+ // It starts by Literals_Length_State, followed by Offset_State, and finally
+ // Match_Length_State."
+ FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
+ FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
+ FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
+
+ for (size_t i = 0; i < num_sequences; i++) {
+ // Decode sequences one by one
+ sequences[i] = decode_sequence(&states, src, &bit_offset);
+ }
+
+ if (bit_offset != 0) {
+ CORRUPTION();
+ }
+}
+
+// Decode a single sequence and update the state
+static sequence_command_t decode_sequence(sequence_states_t *const states,
+ const u8 *const src,
+ i64 *const offset) {
+ // "Each symbol is a code in its own context, which specifies Baseline and
+ // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
+ // additional bits in the same bitstream."
+
+ // Decode symbols, but don't update states
+ const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
+ const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
+ const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
+
+ // Offset doesn't need a max value as it's not decoded using a table
+ if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
+ ml_code > SEQ_MAX_CODES[seq_match_length]) {
+ CORRUPTION();
+ }
+
+ // Read the interleaved bits
+ sequence_command_t seq;
+ // "Decoding starts by reading the Number_of_Bits required to decode Offset.
+ // It then does the same for Match_Length, and then for Literals_Length."
+ seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
+
+ seq.match_length =
+ SEQ_MATCH_LENGTH_BASELINES[ml_code] +
+ STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
+
+ seq.literal_length =
+ SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
+ STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
+
+ // "If it is not the last sequence in the block, the next operation is to
+ // update states. Using the rules pre-calculated in the decoding tables,
+ // Literals_Length_State is updated, followed by Match_Length_State, and
+ // then Offset_State."
+ // If the stream is complete don't read bits to update state
+ if (*offset != 0) {
+ FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
+ FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
+ FSE_update_state(&states->of_table, &states->of_state, src, offset);
+ }
+
+ return seq;
+}
+
+/// Given a sequence part and table mode, decode the FSE distribution
+/// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
+static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
+ const seq_part_t type, const seq_mode_t mode) {
+ // Constant arrays indexed by seq_part_t
+ const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
+ SEQ_OFFSET_DEFAULT_DIST,
+ SEQ_MATCH_LENGTH_DEFAULT_DIST};
+ const size_t default_distribution_lengths[] = {36, 29, 53};
+ const size_t default_distribution_accuracies[] = {6, 5, 6};
+
+ const size_t max_accuracies[] = {9, 8, 9};
+
+ if (mode != seq_repeat) {
+ // Free old one before overwriting
+ FSE_free_dtable(table);
+ }
+
+ switch (mode) {
+ case seq_predefined: {
+ // "Predefined_Mode : uses a predefined distribution table."
+ const i16 *distribution = default_distributions[type];
+ const size_t symbs = default_distribution_lengths[type];
+ const size_t accuracy_log = default_distribution_accuracies[type];
+
+ FSE_init_dtable(table, distribution, symbs, accuracy_log);
+ break;
+ }
+ case seq_rle: {
+ // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
+ const u8 symb = IO_get_read_ptr(in, 1)[0];
+ FSE_init_dtable_rle(table, symb);
+ break;
+ }
+ case seq_fse: {
+ // "FSE_Compressed_Mode : standard FSE compression. A distribution table
+ // will be present "
+ FSE_decode_header(table, in, max_accuracies[type]);
+ break;
+ }
+ case seq_repeat:
+ // "Repeat_Mode : re-use distribution table from previous compressed
+ // block."
+ // Nothing to do here, table will be unchanged
+ if (!table->symbols) {
+ // This mode is invalid if we don't already have a table
+ CORRUPTION();
+ }
+ break;
+ default:
+ // Impossible, as mode is from 0-3
+ IMPOSSIBLE();
+ break;
+ }
+
+}
+/******* END SEQUENCE DECODING ************************************************/
+
+/******* SEQUENCE EXECUTION ***************************************************/
+static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
+ const u8 *const literals,
+ const size_t literals_len,
+ const sequence_command_t *const sequences,
+ const size_t num_sequences) {
+ istream_t litstream = IO_make_istream(literals, literals_len);
+
+ u64 *const offset_hist = ctx->previous_offsets;
+ size_t total_output = ctx->current_total_output;
+
+ for (size_t i = 0; i < num_sequences; i++) {
+ const sequence_command_t seq = sequences[i];
+ {
+ const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
+ total_output += literals_size;
+ }
+
+ size_t const offset = compute_offset(seq, offset_hist);
+
+ size_t const match_length = seq.match_length;
+
+ execute_match_copy(ctx, offset, match_length, total_output, out);
+
+ total_output += match_length;
+ }
+
+ // Copy any leftover literals
+ {
+ size_t len = IO_istream_len(&litstream);
+ copy_literals(len, &litstream, out);
+ total_output += len;
+ }
+
+ ctx->current_total_output = total_output;
+}
+
+static u32 copy_literals(const size_t literal_length, istream_t *litstream,
+ ostream_t *const out) {
+ // If the sequence asks for more literals than are left, the
+ // sequence must be corrupted
+ if (literal_length > IO_istream_len(litstream)) {
+ CORRUPTION();
+ }
+
+ u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
+ const u8 *const read_ptr =
+ IO_get_read_ptr(litstream, literal_length);
+ // Copy literals to output
+ memcpy(write_ptr, read_ptr, literal_length);
+
+ return literal_length;
+}
+
+static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
+ size_t offset;
+ // Offsets are special, we need to handle the repeat offsets
+ if (seq.offset <= 3) {
+ // "The first 3 values define a repeated offset and we will call
+ // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
+ // They are sorted in recency order, with Repeated_Offset1 meaning
+ // 'most recent one'".
+
+ // Use 0 indexing for the array
+ u32 idx = seq.offset - 1;
+ if (seq.literal_length == 0) {
+ // "There is an exception though, when current sequence's
+ // literals length is 0. In this case, repeated offsets are
+ // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
+ // Repeated_Offset2 becomes Repeated_Offset3, and
+ // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
+ idx++;
+ }
+
+ if (idx == 0) {
+ offset = offset_hist[0];
+ } else {
+ // If idx == 3 then literal length was 0 and the offset was 3,
+ // as per the exception listed above
+ offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
+
+ // If idx == 1 we don't need to modify offset_hist[2], since
+ // we're using the second-most recent code
+ if (idx > 1) {
+ offset_hist[2] = offset_hist[1];
+ }
+ offset_hist[1] = offset_hist[0];
+ offset_hist[0] = offset;
+ }
+ } else {
+ // When it's not a repeat offset:
+ // "if (Offset_Value > 3) offset = Offset_Value - 3;"
+ offset = seq.offset - 3;
+
+ // Shift back history
+ offset_hist[2] = offset_hist[1];
+ offset_hist[1] = offset_hist[0];
+ offset_hist[0] = offset;
+ }
+ return offset;
+}
+
+static void execute_match_copy(frame_context_t *const ctx, size_t offset,
+ size_t match_length, size_t total_output,
+ ostream_t *const out) {
+ u8 *write_ptr = IO_get_write_ptr(out, match_length);
+ if (total_output <= ctx->header.window_size) {
+ // In this case offset might go back into the dictionary
+ if (offset > total_output + ctx->dict_content_len) {
+ // The offset goes beyond even the dictionary
+ CORRUPTION();
+ }
+
+ if (offset > total_output) {
+ // "The rest of the dictionary is its content. The content act
+ // as a "past" in front of data to compress or decompress, so it
+ // can be referenced in sequence commands."
+ const size_t dict_copy =
+ MIN(offset - total_output, match_length);
+ const size_t dict_offset =
+ ctx->dict_content_len - (offset - total_output);
+
+ memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
+ write_ptr += dict_copy;
+ match_length -= dict_copy;
+ }
+ } else if (offset > ctx->header.window_size) {
+ CORRUPTION();
+ }
+
+ // We must copy byte by byte because the match length might be larger
+ // than the offset
+ // ex: if the output so far was "abc", a command with offset=3 and
+ // match_length=6 would produce "abcabcabc" as the new output
+ for (size_t j = 0; j < match_length; j++) {
+ *write_ptr = *(write_ptr - offset);
+ write_ptr++;
+ }
+}
+/******* END SEQUENCE EXECUTION ***********************************************/
+
+/******* OUTPUT SIZE COUNTING *************************************************/
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance.
+/// This implementation assumes `src` points to a single ZSTD-compressed frame
+size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
+ istream_t in = IO_make_istream(src, src_len);
+
+ // get decompressed size from ZSTD frame header
+ {
+ const u32 magic_number = IO_read_bits(&in, 32);
+
+ if (magic_number == 0xFD2FB528U) {
+ // ZSTD frame
+ frame_header_t header;
+ parse_frame_header(&header, &in);
+
+ if (header.frame_content_size == 0 && !header.single_segment_flag) {
+ // Content size not provided, we can't tell
+ return -1;
+ }
+
+ return header.frame_content_size;
+ } else {
+ // not a real frame or skippable frame
+ ERROR("ZSTD frame magic number did not match");
+ }
+ }
+}
+/******* END OUTPUT SIZE COUNTING *********************************************/
+
+/******* DICTIONARY PARSING ***************************************************/
+#define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
+#define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
+
+dictionary_t* create_dictionary() {
+ dictionary_t* dict = calloc(1, sizeof(dictionary_t));
+ if (!dict) {
+ BAD_ALLOC();
+ }
+ return dict;
+}
+
+static void init_dictionary_content(dictionary_t *const dict,
+ istream_t *const in);
+
+void parse_dictionary(dictionary_t *const dict, const void *src,
+ size_t src_len) {
+ const u8 *byte_src = (const u8 *)src;
+ memset(dict, 0, sizeof(dictionary_t));
+ if (src == NULL) { /* cannot initialize dictionary with null src */
+ NULL_SRC();
+ }
+ if (src_len < 8) {
+ DICT_SIZE_ERROR();
+ }
+
+ istream_t in = IO_make_istream(byte_src, src_len);
+
+ const u32 magic_number = IO_read_bits(&in, 32);
+ if (magic_number != 0xEC30A437) {
+ // raw content dict
+ IO_rewind_bits(&in, 32);
+ init_dictionary_content(dict, &in);
+ return;
+ }
+
+ dict->dictionary_id = IO_read_bits(&in, 32);
+
+ // "Entropy_Tables : following the same format as the tables in compressed
+ // blocks. They are stored in following order : Huffman tables for literals,
+ // FSE table for offsets, FSE table for match lengths, and FSE table for
+ // literals lengths. It's finally followed by 3 offset values, populating
+ // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
+ // little-endian each, for a total of 12 bytes. Each recent offset must have
+ // a value < dictionary size."
+ decode_huf_table(&dict->literals_dtable, &in);
+ decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
+ decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
+ decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
+
+ // Read in the previous offset history
+ dict->previous_offsets[0] = IO_read_bits(&in, 32);
+ dict->previous_offsets[1] = IO_read_bits(&in, 32);
+ dict->previous_offsets[2] = IO_read_bits(&in, 32);
+
+ // Ensure the provided offsets aren't too large
+ // "Each recent offset must have a value < dictionary size."
+ for (int i = 0; i < 3; i++) {
+ if (dict->previous_offsets[i] > src_len) {
+ ERROR("Dictionary corrupted");
+ }
+ }
+
+ // "Content : The rest of the dictionary is its content. The content act as
+ // a "past" in front of data to compress or decompress, so it can be
+ // referenced in sequence commands."
+ init_dictionary_content(dict, &in);
+}
+
+static void init_dictionary_content(dictionary_t *const dict,
+ istream_t *const in) {
+ // Copy in the content
+ dict->content_size = IO_istream_len(in);
+ dict->content = malloc(dict->content_size);
+ if (!dict->content) {
+ BAD_ALLOC();
+ }
+
+ const u8 *const content = IO_get_read_ptr(in, dict->content_size);
+
+ memcpy(dict->content, content, dict->content_size);
+}
+
+/// Free an allocated dictionary
+void free_dictionary(dictionary_t *const dict) {
+ HUF_free_dtable(&dict->literals_dtable);
+ FSE_free_dtable(&dict->ll_dtable);
+ FSE_free_dtable(&dict->of_dtable);
+ FSE_free_dtable(&dict->ml_dtable);
+
+ free(dict->content);
+
+ memset(dict, 0, sizeof(dictionary_t));
+
+ free(dict);
+}
+/******* END DICTIONARY PARSING ***********************************************/
+
+/******* IO STREAM OPERATIONS *************************************************/
+#define UNALIGNED() ERROR("Attempting to operate on a non-byte aligned stream")
+/// Reads `num` bits from a bitstream, and updates the internal offset
+static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
+ if (num_bits > 64 || num_bits <= 0) {
+ ERROR("Attempt to read an invalid number of bits");
+ }
+
+ const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
+ const size_t full_bytes = (num_bits + in->bit_offset) / 8;
+ if (bytes > in->len) {
+ INP_SIZE();
+ }
+
+ const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
+
+ in->bit_offset = (num_bits + in->bit_offset) % 8;
+ in->ptr += full_bytes;
+ in->len -= full_bytes;
+
+ return result;
+}
+
+/// If a non-zero number of bits have been read from the current byte, advance
+/// the offset to the next byte
+static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
+ if (num_bits < 0) {
+ ERROR("Attempting to rewind stream by a negative number of bits");
+ }
+
+ // move the offset back by `num_bits` bits
+ const int new_offset = in->bit_offset - num_bits;
+ // determine the number of whole bytes we have to rewind, rounding up to an
+ // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
+ const i64 bytes = -(new_offset - 7) / 8;
+
+ in->ptr -= bytes;
+ in->len += bytes;
+ // make sure the resulting `bit_offset` is positive, as mod in C does not
+ // convert numbers from negative to positive (e.g. -22 % 8 == -6)
+ in->bit_offset = ((new_offset % 8) + 8) % 8;
+}
+
+/// If the remaining bits in a byte will be unused, advance to the end of the
+/// byte
+static inline void IO_align_stream(istream_t *const in) {
+ if (in->bit_offset != 0) {
+ if (in->len == 0) {
+ INP_SIZE();
+ }
+ in->ptr++;
+ in->len--;
+ in->bit_offset = 0;
+ }
+}
+
+/// Write the given byte into the output stream
+static inline void IO_write_byte(ostream_t *const out, u8 symb) {
+ if (out->len == 0) {
+ OUT_SIZE();
+ }
+
+ out->ptr[0] = symb;
+ out->ptr++;
+ out->len--;
+}
+
+/// Returns the number of bytes left to be read in this stream. The stream must
+/// be byte aligned.
+static inline size_t IO_istream_len(const istream_t *const in) {
+ return in->len;
+}
+
+/// Returns a pointer where `len` bytes can be read, and advances the internal
+/// state. The stream must be byte aligned.
+static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
+ if (len > in->len) {
+ INP_SIZE();
+ }
+ if (in->bit_offset != 0) {
+ UNALIGNED();
+ }
+ const u8 *const ptr = in->ptr;
+ in->ptr += len;
+ in->len -= len;
+
+ return ptr;
+}
+/// Returns a pointer to write `len` bytes to, and advances the internal state
+static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
+ if (len > out->len) {
+ OUT_SIZE();
+ }
+ u8 *const ptr = out->ptr;
+ out->ptr += len;
+ out->len -= len;
+
+ return ptr;
+}
+
+/// Advance the inner state by `len` bytes
+static inline void IO_advance_input(istream_t *const in, size_t len) {
+ if (len > in->len) {
+ INP_SIZE();
+ }
+ if (in->bit_offset != 0) {
+ UNALIGNED();
+ }
+
+ in->ptr += len;
+ in->len -= len;
+}
+
+/// Returns an `ostream_t` constructed from the given pointer and length
+static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
+ return (ostream_t) { out, len };
+}
+
+/// Returns an `istream_t` constructed from the given pointer and length
+static inline istream_t IO_make_istream(const u8 *in, size_t len) {
+ return (istream_t) { in, len, 0 };
+}
+
+/// Returns an `istream_t` with the same base as `in`, and length `len`
+/// Then, advance `in` to account for the consumed bytes
+/// `in` must be byte aligned
+static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
+ // Consume `len` bytes of the parent stream
+ const u8 *const ptr = IO_get_read_ptr(in, len);
+
+ // Make a substream using the pointer to those `len` bytes
+ return IO_make_istream(ptr, len);
+}
+/******* END IO STREAM OPERATIONS *********************************************/
+
+/******* BITSTREAM OPERATIONS *************************************************/
+/// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
+static inline u64 read_bits_LE(const u8 *src, const int num_bits,
+ const size_t offset) {
+ if (num_bits > 64) {
+ ERROR("Attempt to read an invalid number of bits");
+ }
+
+ // Skip over bytes that aren't in range
+ src += offset / 8;
+ size_t bit_offset = offset % 8;
+ u64 res = 0;
+
+ int shift = 0;
+ int left = num_bits;
+ while (left > 0) {
+ u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
+ // Read the next byte, shift it to account for the offset, and then mask
+ // out the top part if we don't need all the bits
+ res += (((u64)*src++ >> bit_offset) & mask) << shift;
+ shift += 8 - bit_offset;
+ left -= 8 - bit_offset;
+ bit_offset = 0;
+ }
+
+ return res;
+}
+
+/// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
+/// it updates `offset` to `offset - bits`, and then reads `bits` bits from
+/// `src + offset`. If the offset becomes negative, the extra bits at the
+/// bottom are filled in with `0` bits instead of reading from before `src`.
+static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
+ i64 *const offset) {
+ *offset = *offset - bits;
+ size_t actual_off = *offset;
+ size_t actual_bits = bits;
+ // Don't actually read bits from before the start of src, so if `*offset <
+ // 0` fix actual_off and actual_bits to reflect the quantity to read
+ if (*offset < 0) {
+ actual_bits += *offset;
+ actual_off = 0;
+ }
+ u64 res = read_bits_LE(src, actual_bits, actual_off);
+
+ if (*offset < 0) {
+ // Fill in the bottom "overflowed" bits with 0's
+ res = -*offset >= 64 ? 0 : (res << -*offset);
+ }
+ return res;
+}
+/******* END BITSTREAM OPERATIONS *********************************************/
+
+/******* BIT COUNTING OPERATIONS **********************************************/
+/// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
+/// `num`, or `-1` if `num == 0`.
+static inline int highest_set_bit(const u64 num) {
+ for (int i = 63; i >= 0; i--) {
+ if (((u64)1 << i) <= num) {
+ return i;
+ }
+ }
+ return -1;
+}
+/******* END BIT COUNTING OPERATIONS ******************************************/
+
+/******* HUFFMAN PRIMITIVES ***************************************************/
+static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset) {
+ // Look up the symbol and number of bits to read
+ const u8 symb = dtable->symbols[*state];
+ const u8 bits = dtable->num_bits[*state];
+ const u16 rest = STREAM_read_bits(src, bits, offset);
+ // Shift `bits` bits out of the state, keeping the low order bits that
+ // weren't necessary to determine this symbol. Then add in the new bits
+ // read from the stream.
+ *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
+
+ return symb;
+}
+
+static inline void HUF_init_state(const HUF_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset) {
+ // Read in a full `dtable->max_bits` bits to initialize the state
+ const u8 bits = dtable->max_bits;
+ *state = STREAM_read_bits(src, bits, offset);
+}
+
+static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
+ ostream_t *const out,
+ istream_t *const in) {
+ const size_t len = IO_istream_len(in);
+ if (len == 0) {
+ INP_SIZE();
+ }
+ const u8 *const src = IO_get_read_ptr(in, len);
+
+ // "Each bitstream must be read backward, that is starting from the end down
+ // to the beginning. Therefore it's necessary to know the size of each
+ // bitstream.
+ //
+ // It's also necessary to know exactly which bit is the latest. This is
+ // detected by a final bit flag : the highest bit of latest byte is a
+ // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
+ // final-bit-flag itself is not part of the useful bitstream. Hence, the
+ // last byte contains between 0 and 7 useful bits."
+ const int padding = 8 - highest_set_bit(src[len - 1]);
+
+ // Offset starts at the end because HUF streams are read backwards
+ i64 bit_offset = len * 8 - padding;
+ u16 state;
+
+ HUF_init_state(dtable, &state, src, &bit_offset);
+
+ size_t symbols_written = 0;
+ while (bit_offset > -dtable->max_bits) {
+ // Iterate over the stream, decoding one symbol at a time
+ IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
+ symbols_written++;
+ }
+ // "The process continues up to reading the required number of symbols per
+ // stream. If a bitstream is not entirely and exactly consumed, hence
+ // reaching exactly its beginning position with all bits consumed, the
+ // decoding process is considered faulty."
+
+ // When all symbols have been decoded, the final state value shouldn't have
+ // any data from the stream, so it should have "read" dtable->max_bits from
+ // before the start of `src`
+ // Therefore `offset`, the edge to start reading new bits at, should be
+ // dtable->max_bits before the start of the stream
+ if (bit_offset != -dtable->max_bits) {
+ CORRUPTION();
+ }
+
+ return symbols_written;
+}
+
+static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
+ ostream_t *const out, istream_t *const in) {
+ // "Compressed size is provided explicitly : in the 4-streams variant,
+ // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
+ // value represents the compressed size of one stream, in order. The last
+ // stream size is deducted from total compressed size and from previously
+ // decoded stream sizes"
+ const size_t csize1 = IO_read_bits(in, 16);
+ const size_t csize2 = IO_read_bits(in, 16);
+ const size_t csize3 = IO_read_bits(in, 16);
+
+ istream_t in1 = IO_make_sub_istream(in, csize1);
+ istream_t in2 = IO_make_sub_istream(in, csize2);
+ istream_t in3 = IO_make_sub_istream(in, csize3);
+ istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
+
+ size_t total_output = 0;
+ // Decode each stream independently for simplicity
+ // If we wanted to we could decode all 4 at the same time for speed,
+ // utilizing more execution units
+ total_output += HUF_decompress_1stream(dtable, out, &in1);
+ total_output += HUF_decompress_1stream(dtable, out, &in2);
+ total_output += HUF_decompress_1stream(dtable, out, &in3);
+ total_output += HUF_decompress_1stream(dtable, out, &in4);
+
+ return total_output;
+}
+
+/// Initializes a Huffman table using canonical Huffman codes
+/// For more explanation on canonical Huffman codes see
+/// http://www.cs.uofs.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
+/// Codes within a level are allocated in symbol order (i.e. smaller symbols get
+/// earlier codes)
+static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
+ const int num_symbs) {
+ memset(table, 0, sizeof(HUF_dtable));
+ if (num_symbs > HUF_MAX_SYMBS) {
+ ERROR("Too many symbols for Huffman");
+ }
+
+ u8 max_bits = 0;
+ u16 rank_count[HUF_MAX_BITS + 1];
+ memset(rank_count, 0, sizeof(rank_count));
+
+ // Count the number of symbols for each number of bits, and determine the
+ // depth of the tree
+ for (int i = 0; i < num_symbs; i++) {
+ if (bits[i] > HUF_MAX_BITS) {
+ ERROR("Huffman table depth too large");
+ }
+ max_bits = MAX(max_bits, bits[i]);
+ rank_count[bits[i]]++;
+ }
+
+ const size_t table_size = 1 << max_bits;
+ table->max_bits = max_bits;
+ table->symbols = malloc(table_size);
+ table->num_bits = malloc(table_size);
+
+ if (!table->symbols || !table->num_bits) {
+ free(table->symbols);
+ free(table->num_bits);
+ BAD_ALLOC();
+ }
+
+ // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
+ // order. Symbols with a Weight of zero are removed. Then, starting from
+ // lowest weight, prefix codes are distributed in order."
+
+ u32 rank_idx[HUF_MAX_BITS + 1];
+ // Initialize the starting codes for each rank (number of bits)
+ rank_idx[max_bits] = 0;
+ for (int i = max_bits; i >= 1; i--) {
+ rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
+ // The entire range takes the same number of bits so we can memset it
+ memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
+ }
+
+ if (rank_idx[0] != table_size) {
+ CORRUPTION();
+ }
+
+ // Allocate codes and fill in the table
+ for (int i = 0; i < num_symbs; i++) {
+ if (bits[i] != 0) {
+ // Allocate a code for this symbol and set its range in the table
+ const u16 code = rank_idx[bits[i]];
+ // Since the code doesn't care about the bottom `max_bits - bits[i]`
+ // bits of state, it gets a range that spans all possible values of
+ // the lower bits
+ const u16 len = 1 << (max_bits - bits[i]);
+ memset(&table->symbols[code], i, len);
+ rank_idx[bits[i]] += len;
+ }
+ }
+}
+
+static void HUF_init_dtable_usingweights(HUF_dtable *const table,
+ const u8 *const weights,
+ const int num_symbs) {
+ // +1 because the last weight is not transmitted in the header
+ if (num_symbs + 1 > HUF_MAX_SYMBS) {
+ ERROR("Too many symbols for Huffman");
+ }
+
+ u8 bits[HUF_MAX_SYMBS];
+
+ u64 weight_sum = 0;
+ for (int i = 0; i < num_symbs; i++) {
+ // Weights are in the same range as bit count
+ if (weights[i] > HUF_MAX_BITS) {
+ CORRUPTION();
+ }
+ weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
+ }
+
+ // Find the first power of 2 larger than the sum
+ const int max_bits = highest_set_bit(weight_sum) + 1;
+ const u64 left_over = ((u64)1 << max_bits) - weight_sum;
+ // If the left over isn't a power of 2, the weights are invalid
+ if (left_over & (left_over - 1)) {
+ CORRUPTION();
+ }
+
+ // left_over is used to find the last weight as it's not transmitted
+ // by inverting 2^(weight - 1) we can determine the value of last_weight
+ const int last_weight = highest_set_bit(left_over) + 1;
+
+ for (int i = 0; i < num_symbs; i++) {
+ // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
+ bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
+ }
+ bits[num_symbs] =
+ max_bits + 1 - last_weight; // Last weight is always non-zero
+
+ HUF_init_dtable(table, bits, num_symbs + 1);
+}
+
+static void HUF_free_dtable(HUF_dtable *const dtable) {
+ free(dtable->symbols);
+ free(dtable->num_bits);
+ memset(dtable, 0, sizeof(HUF_dtable));
+}
+
+static void HUF_copy_dtable(HUF_dtable *const dst,
+ const HUF_dtable *const src) {
+ if (src->max_bits == 0) {
+ memset(dst, 0, sizeof(HUF_dtable));
+ return;
+ }
+
+ const size_t size = (size_t)1 << src->max_bits;
+ dst->max_bits = src->max_bits;
+
+ dst->symbols = malloc(size);
+ dst->num_bits = malloc(size);
+ if (!dst->symbols || !dst->num_bits) {
+ BAD_ALLOC();
+ }
+
+ memcpy(dst->symbols, src->symbols, size);
+ memcpy(dst->num_bits, src->num_bits, size);
+}
+/******* END HUFFMAN PRIMITIVES ***********************************************/
+
+/******* FSE PRIMITIVES *******************************************************/
+/// For more description of FSE see
+/// https://github.com/Cyan4973/FiniteStateEntropy/
+
+/// Allow a symbol to be decoded without updating state
+static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
+ const u16 state) {
+ return dtable->symbols[state];
+}
+
+/// Consumes bits from the input and uses the current state to determine the
+/// next state
+static inline void FSE_update_state(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset) {
+ const u8 bits = dtable->num_bits[*state];
+ const u16 rest = STREAM_read_bits(src, bits, offset);
+ *state = dtable->new_state_base[*state] + rest;
+}
+
+/// Decodes a single FSE symbol and updates the offset
+static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset) {
+ const u8 symb = FSE_peek_symbol(dtable, *state);
+ FSE_update_state(dtable, state, src, offset);
+ return symb;
+}
+
+static inline void FSE_init_state(const FSE_dtable *const dtable,
+ u16 *const state, const u8 *const src,
+ i64 *const offset) {
+ // Read in a full `accuracy_log` bits to initialize the state
+ const u8 bits = dtable->accuracy_log;
+ *state = STREAM_read_bits(src, bits, offset);
+}
+
+static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
+ ostream_t *const out,
+ istream_t *const in) {
+ const size_t len = IO_istream_len(in);
+ if (len == 0) {
+ INP_SIZE();
+ }
+ const u8 *const src = IO_get_read_ptr(in, len);
+
+ // "Each bitstream must be read backward, that is starting from the end down
+ // to the beginning. Therefore it's necessary to know the size of each
+ // bitstream.
+ //
+ // It's also necessary to know exactly which bit is the latest. This is
+ // detected by a final bit flag : the highest bit of latest byte is a
+ // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
+ // final-bit-flag itself is not part of the useful bitstream. Hence, the
+ // last byte contains between 0 and 7 useful bits."
+ const int padding = 8 - highest_set_bit(src[len - 1]);
+ i64 offset = len * 8 - padding;
+
+ u16 state1, state2;
+ // "The first state (State1) encodes the even indexed symbols, and the
+ // second (State2) encodes the odd indexes. State1 is initialized first, and
+ // then State2, and they take turns decoding a single symbol and updating
+ // their state."
+ FSE_init_state(dtable, &state1, src, &offset);
+ FSE_init_state(dtable, &state2, src, &offset);
+
+ // Decode until we overflow the stream
+ // Since we decode in reverse order, overflowing the stream is offset going
+ // negative
+ size_t symbols_written = 0;
+ while (1) {
+ // "The number of symbols to decode is determined by tracking bitStream
+ // overflow condition: If updating state after decoding a symbol would
+ // require more bits than remain in the stream, it is assumed the extra
+ // bits are 0. Then, the symbols for each of the final states are
+ // decoded and the process is complete."
+ IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
+ symbols_written++;
+ if (offset < 0) {
+ // There's still a symbol to decode in state2
+ IO_write_byte(out, FSE_peek_symbol(dtable, state2));
+ symbols_written++;
+ break;
+ }
+
+ IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
+ symbols_written++;
+ if (offset < 0) {
+ // There's still a symbol to decode in state1
+ IO_write_byte(out, FSE_peek_symbol(dtable, state1));
+ symbols_written++;
+ break;
+ }
+ }
+
+ return symbols_written;
+}
+
+static void FSE_init_dtable(FSE_dtable *const dtable,
+ const i16 *const norm_freqs, const int num_symbs,
+ const int accuracy_log) {
+ if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
+ ERROR("FSE accuracy too large");
+ }
+ if (num_symbs > FSE_MAX_SYMBS) {
+ ERROR("Too many symbols for FSE");
+ }
+
+ dtable->accuracy_log = accuracy_log;
+
+ const size_t size = (size_t)1 << accuracy_log;
+ dtable->symbols = malloc(size * sizeof(u8));
+ dtable->num_bits = malloc(size * sizeof(u8));
+ dtable->new_state_base = malloc(size * sizeof(u16));
+
+ if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
+ BAD_ALLOC();
+ }
+
+ // Used to determine how many bits need to be read for each state,
+ // and where the destination range should start
+ // Needs to be u16 because max value is 2 * max number of symbols,
+ // which can be larger than a byte can store
+ u16 state_desc[FSE_MAX_SYMBS];
+
+ // "Symbols are scanned in their natural order for "less than 1"
+ // probabilities. Symbols with this probability are being attributed a
+ // single cell, starting from the end of the table. These symbols define a
+ // full state reset, reading Accuracy_Log bits."
+ int high_threshold = size;
+ for (int s = 0; s < num_symbs; s++) {
+ // Scan for low probability symbols to put at the top
+ if (norm_freqs[s] == -1) {
+ dtable->symbols[--high_threshold] = s;
+ state_desc[s] = 1;
+ }
+ }
+
+ // "All remaining symbols are sorted in their natural order. Starting from
+ // symbol 0 and table position 0, each symbol gets attributed as many cells
+ // as its probability. Cell allocation is spreaded, not linear."
+ // Place the rest in the table
+ const u16 step = (size >> 1) + (size >> 3) + 3;
+ const u16 mask = size - 1;
+ u16 pos = 0;
+ for (int s = 0; s < num_symbs; s++) {
+ if (norm_freqs[s] <= 0) {
+ continue;
+ }
+
+ state_desc[s] = norm_freqs[s];
+
+ for (int i = 0; i < norm_freqs[s]; i++) {
+ // Give `norm_freqs[s]` states to symbol s
+ dtable->symbols[pos] = s;
+ // "A position is skipped if already occupied, typically by a "less
+ // than 1" probability symbol."
+ do {
+ pos = (pos + step) & mask;
+ } while (pos >=
+ high_threshold);
+ // Note: no other collision checking is necessary as `step` is
+ // coprime to `size`, so the cycle will visit each position exactly
+ // once
+ }
+ }
+ if (pos != 0) {
+ CORRUPTION();
+ }
+
+ // Now we can fill baseline and num bits
+ for (size_t i = 0; i < size; i++) {
+ u8 symbol = dtable->symbols[i];
+ u16 next_state_desc = state_desc[symbol]++;
+ // Fills in the table appropriately, next_state_desc increases by symbol
+ // over time, decreasing number of bits
+ dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
+ // Baseline increases until the bit threshold is passed, at which point
+ // it resets to 0
+ dtable->new_state_base[i] =
+ ((u16)next_state_desc << dtable->num_bits[i]) - size;
+ }
+}
+
+/// Decode an FSE header as defined in the Zstandard format specification and
+/// use the decoded frequencies to initialize a decoding table.
+static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
+ const int max_accuracy_log) {
+ // "An FSE distribution table describes the probabilities of all symbols
+ // from 0 to the last present one (included) on a normalized scale of 1 <<
+ // Accuracy_Log .
+ //
+ // It's a bitstream which is read forward, in little-endian fashion. It's
+ // not necessary to know its exact size, since it will be discovered and
+ // reported by the decoding process.
+ if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
+ ERROR("FSE accuracy too large");
+ }
+
+ // The bitstream starts by reporting on which scale it operates.
+ // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
+ // and match lengths is 9, and for offsets is 8. Higher values are
+ // considered errors."
+ const int accuracy_log = 5 + IO_read_bits(in, 4);
+ if (accuracy_log > max_accuracy_log) {
+ ERROR("FSE accuracy too large");
+ }
+
+ // "Then follows each symbol value, from 0 to last present one. The number
+ // of bits used by each field is variable. It depends on :
+ //
+ // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
+ // and presuming 100 probabilities points have already been distributed, the
+ // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
+ // Therefore, it must read log2sup(156) == 8 bits.
+ //
+ // Value decoded : small values use 1 less bit : example : Presuming values
+ // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
+ // in an 8-bits field. They are used this way : first 99 values (hence from
+ // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
+
+ i32 remaining = 1 << accuracy_log;
+ i16 frequencies[FSE_MAX_SYMBS];
+
+ int symb = 0;
+ while (remaining > 0 && symb < FSE_MAX_SYMBS) {
+ // Log of the number of possible values we could read
+ int bits = highest_set_bit(remaining + 1) + 1;
+
+ u16 val = IO_read_bits(in, bits);
+
+ // Try to mask out the lower bits to see if it qualifies for the "small
+ // value" threshold
+ const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
+ const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
+
+ if ((val & lower_mask) < threshold) {
+ IO_rewind_bits(in, 1);
+ val = val & lower_mask;
+ } else if (val > lower_mask) {
+ val = val - threshold;
+ }
+
+ // "Probability is obtained from Value decoded by following formula :
+ // Proba = value - 1"
+ const i16 proba = (i16)val - 1;
+
+ // "It means value 0 becomes negative probability -1. -1 is a special
+ // probability, which means "less than 1". Its effect on distribution
+ // table is described in next paragraph. For the purpose of calculating
+ // cumulated distribution, it counts as one."
+ remaining -= proba < 0 ? -proba : proba;
+
+ frequencies[symb] = proba;
+ symb++;
+
+ // "When a symbol has a probability of zero, it is followed by a 2-bits
+ // repeat flag. This repeat flag tells how many probabilities of zeroes
+ // follow the current one. It provides a number ranging from 0 to 3. If
+ // it is a 3, another 2-bits repeat flag follows, and so on."
+ if (proba == 0) {
+ // Read the next two bits to see how many more 0s
+ int repeat = IO_read_bits(in, 2);
+
+ while (1) {
+ for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
+ frequencies[symb++] = 0;
+ }
+ if (repeat == 3) {
+ repeat = IO_read_bits(in, 2);
+ } else {
+ break;
+ }
+ }
+ }
+ }
+ IO_align_stream(in);
+
+ // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
+ // is complete. If the last symbol makes cumulated total go above 1 <<
+ // Accuracy_Log, distribution is considered corrupted."
+ if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
+ CORRUPTION();
+ }
+
+ // Initialize the decoding table using the determined weights
+ FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
+}
+
+static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
+ dtable->symbols = malloc(sizeof(u8));
+ dtable->num_bits = malloc(sizeof(u8));
+ dtable->new_state_base = malloc(sizeof(u16));
+
+ if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
+ BAD_ALLOC();
+ }
+
+ // This setup will always have a state of 0, always return symbol `symb`,
+ // and never consume any bits
+ dtable->symbols[0] = symb;
+ dtable->num_bits[0] = 0;
+ dtable->new_state_base[0] = 0;
+ dtable->accuracy_log = 0;
+}
+
+static void FSE_free_dtable(FSE_dtable *const dtable) {
+ free(dtable->symbols);
+ free(dtable->num_bits);
+ free(dtable->new_state_base);
+ memset(dtable, 0, sizeof(FSE_dtable));
+}
+
+static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
+ if (src->accuracy_log == 0) {
+ memset(dst, 0, sizeof(FSE_dtable));
+ return;
+ }
+
+ size_t size = (size_t)1 << src->accuracy_log;
+ dst->accuracy_log = src->accuracy_log;
+
+ dst->symbols = malloc(size);
+ dst->num_bits = malloc(size);
+ dst->new_state_base = malloc(size * sizeof(u16));
+ if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
+ BAD_ALLOC();
+ }
+
+ memcpy(dst->symbols, src->symbols, size);
+ memcpy(dst->num_bits, src->num_bits, size);
+ memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
+}
+/******* END FSE PRIMITIVES ***************************************************/
diff --git a/src/zstd/doc/educational_decoder/zstd_decompress.h b/src/zstd/doc/educational_decoder/zstd_decompress.h
new file mode 100644
index 00000000..a01fde33
--- /dev/null
+++ b/src/zstd/doc/educational_decoder/zstd_decompress.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+
+/******* EXPOSED TYPES ********************************************************/
+/*
+* Contains the parsed contents of a dictionary
+* This includes Huffman and FSE tables used for decoding and data on offsets
+*/
+typedef struct dictionary_s dictionary_t;
+/******* END EXPOSED TYPES ****************************************************/
+
+/******* DECOMPRESSION FUNCTIONS **********************************************/
+/// Zstandard decompression functions.
+/// `dst` must point to a space at least as large as the reconstructed output.
+size_t ZSTD_decompress(void *const dst, const size_t dst_len,
+ const void *const src, const size_t src_len);
+
+/// If `dict != NULL` and `dict_len >= 8`, does the same thing as
+/// `ZSTD_decompress` but uses the provided dict
+size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
+ const void *const src, const size_t src_len,
+ dictionary_t* parsed_dict);
+
+/// Get the decompressed size of an input stream so memory can be allocated in
+/// advance
+/// Returns -1 if the size can't be determined
+/// Assumes decompression of a single frame
+size_t ZSTD_get_decompressed_size(const void *const src, const size_t src_len);
+/******* END DECOMPRESSION FUNCTIONS ******************************************/
+
+/******* DICTIONARY MANAGEMENT ***********************************************/
+/*
+ * Return a valid dictionary_t pointer for use with dictionary initialization
+ * or decompression
+ */
+dictionary_t* create_dictionary();
+
+/*
+ * Parse a provided dictionary blob for use in decompression
+ * `src` -- must point to memory space representing the dictionary
+ * `src_len` -- must provide the dictionary size
+ * `dict` -- will contain the parsed contents of the dictionary and
+ * can be used for decompression
+ */
+void parse_dictionary(dictionary_t *const dict, const void *src,
+ size_t src_len);
+
+/*
+ * Free internal Huffman tables, FSE tables, and dictionary content
+ */
+void free_dictionary(dictionary_t *const dict);
+/******* END DICTIONARY MANAGEMENT *******************************************/