Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc')
-rw-r--r--  third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc  3860
1 files changed, 3860 insertions, 0 deletions
diff --git a/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc
new file mode 100644
index 0000000000..286990ee8a
--- /dev/null
+++ b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc
@@ -0,0 +1,3860 @@
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#ifndef FJXL_SELF_INCLUDE
+
+#include "lib/jxl/enc_fast_lossless.h"
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+
+#if defined(_MSC_VER) && !defined(__clang__)
+#include <intrin.h>  // for _BitScanReverse / _BitScanForward
+#endif
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <memory>
+#include <vector>
+
+// Enable NEON and AVX2/AVX512 if not explicitly disabled and the compiler
+// supports them.
+#if defined(__aarch64__) || defined(_M_ARM64)
+#include <arm_neon.h>
+
+#ifndef FJXL_ENABLE_NEON
+#define FJXL_ENABLE_NEON 1
+#endif
+
+#elif (defined(__x86_64__) || defined(_M_X64)) && !defined(_MSC_VER)
+#include <immintrin.h>
+
+// manually add _mm512_cvtsi512_si32 definition if missing
+// (e.g. with Xcode on macOS Mojave)
+// copied from gcc 11.1.0 include/avx512fintrin.h line 14367-14373
+#if defined(__clang__) && \
+ ((!defined(__apple_build_version__) && __clang_major__ < 10) || \
+ (defined(__apple_build_version__) && __apple_build_version__ < 12000032))
+inline int __attribute__((__gnu_inline__, __always_inline__, __artificial__))
+_mm512_cvtsi512_si32(__m512i __A) {
+ __v16si __B = (__v16si)__A;
+ return __B[0];
+}
+#endif
+
+// TODO(veluca): MSVC support for dynamic dispatch.
+#if defined(__clang__) || defined(__GNUC__)
+
+#ifndef FJXL_ENABLE_AVX2
+#define FJXL_ENABLE_AVX2 1
+#endif
+
+#ifndef FJXL_ENABLE_AVX512
+// On clang-7 or earlier, and gcc-10 or earlier, AVX512 seems broken.
+#if (defined(__clang__) && \
+ (!defined(__apple_build_version__) && __clang_major__ > 7) || \
+ (defined(__apple_build_version__) && \
+ __apple_build_version__ > 10010046)) || \
+ (defined(__GNUC__) && __GNUC__ > 10)
+#define FJXL_ENABLE_AVX512 1
+#endif
+#endif
+
+#endif
+
+#endif
+
+#ifndef FJXL_ENABLE_NEON
+#define FJXL_ENABLE_NEON 0
+#endif
+
+#ifndef FJXL_ENABLE_AVX2
+#define FJXL_ENABLE_AVX2 0
+#endif
+
+#ifndef FJXL_ENABLE_AVX512
+#define FJXL_ENABLE_AVX512 0
+#endif
+
+namespace {
+#if defined(_MSC_VER) && !defined(__clang__)
+#define FJXL_INLINE __forceinline
+FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
+ unsigned long index;
+ _BitScanReverse(&index, v);
+ return index;
+}
+FJXL_INLINE uint32_t CtzNonZero(uint64_t v) {
+ unsigned long index;
+ _BitScanForward(&index, v);
+ return index;
+}
+#else
+#define FJXL_INLINE inline __attribute__((always_inline))
+FJXL_INLINE uint32_t FloorLog2(uint32_t v) {
+ return v ? 31 - __builtin_clz(v) : 0;
+}
+FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { return __builtin_ctzll(v); }
+#endif
+
+// Compiles to a memcpy on little-endian systems.
+FJXL_INLINE void StoreLE64(uint8_t* tgt, uint64_t data) {
+#if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__))
+ for (int i = 0; i < 8; i++) {
+ tgt[i] = (data >> (i * 8)) & 0xFF;
+ }
+#else
+ memcpy(tgt, &data, 8);
+#endif
+}
+
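+// Example: with bits_in_buffer == 5, adding count == 12 gives 17 buffered
+// bits; the return value 2 commits the two completed bytes (StoreLE64 always
+// writes 8 scratch bytes), and a single pending bit stays in bit_buffer.
+// Callers keep count + bits_in_buffer within 64 bits (see the <= 56-bit note
+// in BitWriter::WriteMultiple below).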
+FJXL_INLINE size_t AddBits(uint32_t count, uint64_t bits, uint8_t* data_buf,
+ size_t& bits_in_buffer, uint64_t& bit_buffer) {
+ bit_buffer |= bits << bits_in_buffer;
+ bits_in_buffer += count;
+ StoreLE64(data_buf, bit_buffer);
+ size_t bytes_in_buffer = bits_in_buffer / 8;
+ bits_in_buffer -= bytes_in_buffer * 8;
+ bit_buffer >>= bytes_in_buffer * 8;
+ return bytes_in_buffer;
+}
+
+struct BitWriter {
+ void Allocate(size_t maximum_bit_size) {
+ assert(data == nullptr);
+ // Leave some padding.
+ data.reset(static_cast<uint8_t*>(malloc(maximum_bit_size / 8 + 64)));
+ }
+
+ void Write(uint32_t count, uint64_t bits) {
+ bytes_written += AddBits(count, bits, data.get() + bytes_written,
+ bits_in_buffer, buffer);
+ }
+
+ void ZeroPadToByte() {
+ if (bits_in_buffer != 0) {
+ Write(8 - bits_in_buffer, 0);
+ }
+ }
+
+ FJXL_INLINE void WriteMultiple(const uint64_t* nbits, const uint64_t* bits,
+ size_t n) {
+ // Necessary because Write() is only guaranteed to work with <=56 bits.
+ // Trying to SIMD-fy this code results in lower speed (and definitely less
+ // clarity).
+ {
+ for (size_t i = 0; i < n; i++) {
+ this->buffer |= bits[i] << this->bits_in_buffer;
+ memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
+ uint64_t shift = 64 - this->bits_in_buffer;
+ this->bits_in_buffer += nbits[i];
+ // This `if` seems to be faster than using ternaries.
+ if (this->bits_in_buffer >= 64) {
+ uint64_t next_buffer = bits[i] >> shift;
+ this->buffer = next_buffer;
+ this->bits_in_buffer -= 64;
+ this->bytes_written += 8;
+ }
+ }
+ memcpy(this->data.get() + this->bytes_written, &this->buffer, 8);
+ size_t bytes_in_buffer = this->bits_in_buffer / 8;
+ this->bits_in_buffer -= bytes_in_buffer * 8;
+ this->buffer >>= bytes_in_buffer * 8;
+ this->bytes_written += bytes_in_buffer;
+ }
+ }
+
+ std::unique_ptr<uint8_t[], void (*)(void*)> data = {nullptr, free};
+ size_t bytes_written = 0;
+ size_t bits_in_buffer = 0;
+ uint64_t buffer = 0;
+};
+
+} // namespace
+
+extern "C" {
+
+struct JxlFastLosslessFrameState {
+ size_t width;
+ size_t height;
+ size_t nb_chans;
+ size_t bitdepth;
+ BitWriter header;
+ std::vector<std::array<BitWriter, 4>> group_data;
+ size_t current_bit_writer = 0;
+ size_t bit_writer_byte_pos = 0;
+ size_t bits_in_buffer = 0;
+ uint64_t bit_buffer = 0;
+};
+
+size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame) {
+ size_t total_size_groups = 0;
+ for (size_t i = 0; i < frame->group_data.size(); i++) {
+ size_t sz = 0;
+ for (size_t j = 0; j < frame->nb_chans; j++) {
+ const auto& writer = frame->group_data[i][j];
+ sz += writer.bytes_written * 8 + writer.bits_in_buffer;
+ }
+ sz = (sz + 7) / 8;
+ total_size_groups += sz;
+ }
+ return frame->header.bytes_written + total_size_groups;
+}
+
+size_t JxlFastLosslessMaxRequiredOutput(
+ const JxlFastLosslessFrameState* frame) {
+ return JxlFastLosslessOutputSize(frame) + 32;
+}
+
+void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame,
+ int add_image_header, int is_last) {
+ BitWriter* output = &frame->header;
+ output->Allocate(1000 + frame->group_data.size() * 32);
+
+ std::vector<size_t> group_sizes(frame->group_data.size());
+ for (size_t i = 0; i < frame->group_data.size(); i++) {
+ size_t sz = 0;
+ for (size_t j = 0; j < frame->nb_chans; j++) {
+ const auto& writer = frame->group_data[i][j];
+ sz += writer.bytes_written * 8 + writer.bits_in_buffer;
+ }
+ sz = (sz + 7) / 8;
+ group_sizes[i] = sz;
+ }
+
+ bool have_alpha = (frame->nb_chans == 2 || frame->nb_chans == 4);
+
+ if (add_image_header) {
+ // Signature
+ output->Write(16, 0x0AFF);
+
+ // Size header, hand-crafted.
+ // Not small
+ output->Write(1, 0);
+
+ auto wsz = [output](size_t size) {
+ if (size - 1 < (1 << 9)) {
+ output->Write(2, 0b00);
+ output->Write(9, size - 1);
+ } else if (size - 1 < (1 << 13)) {
+ output->Write(2, 0b01);
+ output->Write(13, size - 1);
+ } else if (size - 1 < (1 << 18)) {
+ output->Write(2, 0b10);
+ output->Write(18, size - 1);
+ } else {
+ output->Write(2, 0b11);
+ output->Write(30, size - 1);
+ }
+    };
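+    // E.g. a dimension of 1080 stores 1079: it is >= 512 and < 8192, so the
+    // selector 0b01 is written followed by 1079 in 13 bits; a dimension of
+    // 256 would use selector 0b00 and 9 bits.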
+
+ wsz(frame->height);
+
+ // No special ratio.
+ output->Write(3, 0);
+
+ wsz(frame->width);
+
+ // Hand-crafted ImageMetadata.
+ output->Write(1, 0); // all_default
+ output->Write(1, 0); // extra_fields
+ output->Write(1, 0); // bit_depth.floating_point_sample
+ if (frame->bitdepth == 8) {
+ output->Write(2, 0b00); // bit_depth.bits_per_sample = 8
+ } else if (frame->bitdepth == 10) {
+ output->Write(2, 0b01); // bit_depth.bits_per_sample = 10
+ } else if (frame->bitdepth == 12) {
+ output->Write(2, 0b10); // bit_depth.bits_per_sample = 12
+ } else {
+ output->Write(2, 0b11); // 1 + u(6)
+ output->Write(6, frame->bitdepth - 1);
+ }
+ if (frame->bitdepth <= 14) {
+ output->Write(1, 1); // 16-bit-buffer sufficient
+ } else {
+ output->Write(1, 0); // 16-bit-buffer NOT sufficient
+ }
+ if (have_alpha) {
+ output->Write(2, 0b01); // One extra channel
+    output->Write(1, 1); // ... all_default (i.e. 8-bit alpha)
+ } else {
+ output->Write(2, 0b00); // No extra channel
+ }
+ output->Write(1, 0); // Not XYB
+ if (frame->nb_chans > 2) {
+ output->Write(1, 1); // color_encoding.all_default (sRGB)
+ } else {
+ output->Write(1, 0); // color_encoding.all_default false
+ output->Write(1, 0); // color_encoding.want_icc false
+ output->Write(2, 1); // grayscale
+ output->Write(2, 1); // D65
+ output->Write(1, 0); // no gamma transfer function
+ output->Write(2, 0b10); // tf: 2 + u(4)
+ output->Write(4, 11); // tf of sRGB
+ output->Write(2, 1); // relative rendering intent
+ }
+ output->Write(2, 0b00); // No extensions.
+
+ output->Write(1, 1); // all_default transform data
+
+    // No ICC, no preview. Frame should start at a byte boundary.
+ output->ZeroPadToByte();
+ }
+
+ // Handcrafted frame header.
+ output->Write(1, 0); // all_default
+ output->Write(2, 0b00); // regular frame
+ output->Write(1, 1); // modular
+ output->Write(2, 0b00); // default flags
+ output->Write(1, 0); // not YCbCr
+ output->Write(2, 0b00); // no upsampling
+ if (have_alpha) {
+ output->Write(2, 0b00); // no alpha upsampling
+ }
+ output->Write(2, 0b01); // default group size
+ output->Write(2, 0b00); // exactly one pass
+ output->Write(1, 0); // no custom size or origin
+ output->Write(2, 0b00); // kReplace blending mode
+ if (have_alpha) {
+ output->Write(2, 0b00); // kReplace blending mode for alpha channel
+ }
+ output->Write(1, is_last); // is_last
+ output->Write(2, 0b00); // a frame has no name
+ output->Write(1, 0); // loop filter is not all_default
+ output->Write(1, 0); // no gaborish
+ output->Write(2, 0); // 0 EPF iters
+ output->Write(2, 0b00); // No LF extensions
+ output->Write(2, 0b00); // No FH extensions
+
+ output->Write(1, 0); // No TOC permutation
+ output->ZeroPadToByte(); // TOC is byte-aligned.
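+  // Each TOC entry is a 2-bit selector plus an offset-biased size; the
+  // selector boundaries are 1024, 17408 (= 1024 + 2^14) and
+  // 4211712 (= 17408 + 2^22). E.g. a 5000-byte group writes selector 0b01
+  // followed by 5000 - 1024 = 3976 in 14 bits.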
+ for (size_t i = 0; i < frame->group_data.size(); i++) {
+ size_t sz = group_sizes[i];
+ if (sz < (1 << 10)) {
+ output->Write(2, 0b00);
+ output->Write(10, sz);
+ } else if (sz - 1024 < (1 << 14)) {
+ output->Write(2, 0b01);
+ output->Write(14, sz - 1024);
+ } else if (sz - 17408 < (1 << 22)) {
+ output->Write(2, 0b10);
+ output->Write(22, sz - 17408);
+ } else {
+ output->Write(2, 0b11);
+ output->Write(30, sz - 4211712);
+ }
+ }
+ output->ZeroPadToByte(); // Groups are byte-aligned.
+}
+
+#if FJXL_ENABLE_AVX512
+__attribute__((target("avx512vbmi2"))) static size_t AppendBytesWithBitOffset(
+ const uint8_t* data, size_t n, size_t bit_buffer_nbits,
+ unsigned char* output, uint64_t& bit_buffer) {
+ if (n < 128) {
+ return 0;
+ }
+
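+  // The VBMI2 funnel shift below computes, per 64-bit lane,
+  // (current << bit_buffer_nbits) | (previous >> (64 - bit_buffer_nbits)),
+  // i.e. it prepends the pending bits to the byte stream 64 bytes at a time.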
+ size_t i = 0;
+ __m512i shift = _mm512_set1_epi64(64 - bit_buffer_nbits);
+ __m512i carry = _mm512_set1_epi64(bit_buffer << (64 - bit_buffer_nbits));
+
+ for (; i + 64 <= n; i += 64) {
+ __m512i current = _mm512_loadu_si512(data + i);
+ __m512i previous_u64 = _mm512_alignr_epi64(current, carry, 7);
+ carry = current;
+ __m512i out = _mm512_shrdv_epi64(previous_u64, current, shift);
+ _mm512_storeu_si512(output + i, out);
+ }
+
+ bit_buffer = data[i - 1] >> (8 - bit_buffer_nbits);
+
+ return i;
+}
+#endif
+
+size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame,
+ unsigned char* output, size_t output_size) {
+ assert(output_size >= 32);
+ unsigned char* initial_output = output;
+ size_t (*append_bytes_with_bit_offset)(const uint8_t*, size_t, size_t,
+ unsigned char*, uint64_t&) = nullptr;
+
+#if FJXL_ENABLE_AVX512
+ if (__builtin_cpu_supports("avx512vbmi2")) {
+ append_bytes_with_bit_offset = AppendBytesWithBitOffset;
+ }
+#endif
+
+ while (true) {
+ size_t& cur = frame->current_bit_writer;
+ size_t& bw_pos = frame->bit_writer_byte_pos;
+ if (cur >= 1 + frame->group_data.size() * frame->nb_chans) {
+ return output - initial_output;
+ }
+ if (output_size <= 8) {
+ return output - initial_output;
+ }
+ size_t nbc = frame->nb_chans;
+ const BitWriter& writer =
+ cur == 0 ? frame->header
+ : frame->group_data[(cur - 1) / nbc][(cur - 1) % nbc];
+ size_t full_byte_count =
+ std::min(output_size - 8, writer.bytes_written - bw_pos);
+ if (frame->bits_in_buffer == 0) {
+ memcpy(output, writer.data.get() + bw_pos, full_byte_count);
+ } else {
+ size_t i = 0;
+ if (append_bytes_with_bit_offset) {
+ i += append_bytes_with_bit_offset(
+ writer.data.get() + bw_pos, full_byte_count, frame->bits_in_buffer,
+ output, frame->bit_buffer);
+ }
+#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
+ // Copy 8 bytes at a time until we reach the border.
+ for (; i + 8 < full_byte_count; i += 8) {
+ uint64_t chunk;
+ memcpy(&chunk, writer.data.get() + bw_pos + i, 8);
+ uint64_t out = frame->bit_buffer | (chunk << frame->bits_in_buffer);
+ memcpy(output + i, &out, 8);
+ frame->bit_buffer = chunk >> (64 - frame->bits_in_buffer);
+ }
+#endif
+ for (; i < full_byte_count; i++) {
+ AddBits(8, writer.data.get()[bw_pos + i], output + i,
+ frame->bits_in_buffer, frame->bit_buffer);
+ }
+ }
+ output += full_byte_count;
+ output_size -= full_byte_count;
+ bw_pos += full_byte_count;
+ if (bw_pos == writer.bytes_written) {
+ auto write = [&](size_t num, uint64_t bits) {
+ size_t n = AddBits(num, bits, output, frame->bits_in_buffer,
+ frame->bit_buffer);
+ output += n;
+ output_size -= n;
+ };
+ if (writer.bits_in_buffer) {
+ write(writer.bits_in_buffer, writer.buffer);
+ }
+ bw_pos = 0;
+ cur++;
+ if ((cur - 1) % nbc == 0 && frame->bits_in_buffer != 0) {
+ write(8 - frame->bits_in_buffer, 0);
+ }
+ }
+ }
+}
+
+void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame) {
+ delete frame;
+}
+
+} // extern "C"
+
+#endif
+
+#ifdef FJXL_SELF_INCLUDE
+
+namespace {
+
+constexpr size_t kNumRawSymbols = 19;
+constexpr size_t kNumLZ77 = 33;
+constexpr size_t kLZ77CacheSize = 32;
+
+constexpr size_t kLZ77Offset = 224;
+constexpr size_t kLZ77MinLength = 7;
+
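+// Token/extra-bit split for LZ77 lengths; the "400 config" below appears to
+// be the hybrid-uint configuration with split_exponent = 4 and
+// msb_in_token = lsb_in_token = 0: values below 16 are coded as-is, larger
+// values use token 16 + FloorLog2(value) - 4 plus FloorLog2(value) extra
+// bits. E.g. value = 100: FloorLog2 = 6, token = 18, nbits = 6, bits = 36.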
+void EncodeHybridUintLZ77(uint32_t value, uint32_t* token, uint32_t* nbits,
+ uint32_t* bits) {
+ // 400 config
+ uint32_t n = FloorLog2(value);
+ *token = value < 16 ? value : 16 + n - 4;
+ *nbits = value < 16 ? 0 : n;
+ *bits = value < 16 ? 0 : value - (1 << *nbits);
+}
+
+struct PrefixCode {
+ uint8_t raw_nbits[kNumRawSymbols] = {};
+ uint8_t raw_bits[kNumRawSymbols] = {};
+
+ alignas(64) uint8_t raw_nbits_simd[16] = {};
+ alignas(64) uint8_t raw_bits_simd[16] = {};
+
+ uint8_t lz77_nbits[kNumLZ77] = {};
+ uint16_t lz77_bits[kNumLZ77] = {};
+
+ uint64_t lz77_cache_bits[kLZ77CacheSize] = {};
+ uint8_t lz77_cache_nbits[kLZ77CacheSize] = {};
+
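+  // Reverses the low `nbits` bits via a nibble lookup: the full 16-bit value
+  // is reversed, then shifted back down. E.g. nbits = 3, bits = 0b011
+  // reverses to 0b110 (0x0003 -> 0xC000 -> >> 13).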
+ static uint16_t BitReverse(size_t nbits, uint16_t bits) {
+ constexpr uint16_t kNibbleLookup[16] = {
+ 0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110,
+ 0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111,
+ };
+ uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) |
+ (kNibbleLookup[(bits >> 4) & 0xF] << 8) |
+ (kNibbleLookup[(bits >> 8) & 0xF] << 4) |
+ (kNibbleLookup[bits >> 12]);
+ return rev16 >> (16 - nbits);
+ }
+
+ // Create the prefix codes given the code lengths.
+ // Supports the code lengths being split into two halves.
+ static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits,
+ uint8_t* first_chunk_bits,
+ size_t first_chunk_size,
+ const uint8_t* second_chunk_nbits,
+ uint16_t* second_chunk_bits,
+ size_t second_chunk_size) {
+ constexpr size_t kMaxCodeLength = 15;
+ uint8_t code_length_counts[kMaxCodeLength + 1] = {};
+ for (size_t i = 0; i < first_chunk_size; i++) {
+ code_length_counts[first_chunk_nbits[i]]++;
+ assert(first_chunk_nbits[i] <= kMaxCodeLength);
+ assert(first_chunk_nbits[i] <= 8);
+ assert(first_chunk_nbits[i] > 0);
+ }
+ for (size_t i = 0; i < second_chunk_size; i++) {
+ code_length_counts[second_chunk_nbits[i]]++;
+ assert(second_chunk_nbits[i] <= kMaxCodeLength);
+ }
+
+ uint16_t next_code[kMaxCodeLength + 1] = {};
+
+ uint16_t code = 0;
+ for (size_t i = 1; i < kMaxCodeLength + 1; i++) {
+ code = (code + code_length_counts[i - 1]) << 1;
+ next_code[i] = code;
+ }
+
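+    // next_code[len] now holds the first (MSB-first) code of each length, as
+    // in RFC 1951. E.g. lengths {2, 2, 3, 3} give next_code[2] = 0b00 and
+    // next_code[3] = 0b100, so the codes are 00, 01, 100, 101 before the bit
+    // reversal below (the bitstream is LSB-first).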
+ for (size_t i = 0; i < first_chunk_size; i++) {
+ first_chunk_bits[i] =
+ BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++);
+ }
+ for (size_t i = 0; i < second_chunk_size; i++) {
+ second_chunk_bits[i] =
+ BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++);
+ }
+ }
+
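+  // Dynamic program over the Kraft budget: d(sym, off) is the minimal
+  // sum(freqs * bits) over the first `sym` symbols whose code lengths use up
+  // `off` units of the budget, where a length-`bits` code costs
+  // 2^(precision - bits) units and the full budget 2^precision corresponds
+  // to a Kraft sum of exactly 1.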
+ template <typename T>
+ static void ComputeCodeLengthsNonZeroImpl(const uint64_t* freqs, size_t n,
+ size_t precision, T infty,
+ uint8_t* min_limit,
+ uint8_t* max_limit,
+ uint8_t* nbits) {
+ std::vector<T> dynp(((1U << precision) + 1) * (n + 1), infty);
+ auto d = [&](size_t sym, size_t off) -> T& {
+ return dynp[sym * ((1 << precision) + 1) + off];
+ };
+ d(0, 0) = 0;
+ for (size_t sym = 0; sym < n; sym++) {
+ for (T bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
+ size_t off_delta = 1U << (precision - bits);
+ for (size_t off = 0; off + off_delta <= (1U << precision); off++) {
+ d(sym + 1, off + off_delta) =
+ std::min(d(sym, off) + static_cast<T>(freqs[sym]) * bits,
+ d(sym + 1, off + off_delta));
+ }
+ }
+ }
+
+ size_t sym = n;
+ size_t off = 1U << precision;
+
+ assert(d(sym, off) != infty);
+
+ while (sym-- > 0) {
+ assert(off > 0);
+ for (size_t bits = min_limit[sym]; bits <= max_limit[sym]; bits++) {
+ size_t off_delta = 1U << (precision - bits);
+ if (off_delta <= off &&
+ d(sym + 1, off) == d(sym, off - off_delta) + freqs[sym] * bits) {
+ off -= off_delta;
+ nbits[sym] = bits;
+ break;
+ }
+ }
+ }
+ }
+
+  // Computes nbits[i] for i < n, subject to min_limit[i] <= nbits[i] <=
+  // max_limit[i] and sum(2**-nbits[i]) == 1, so as to minimize
+  // sum(nbits[i] * freqs[i]).
+ static void ComputeCodeLengthsNonZero(const uint64_t* freqs, size_t n,
+ uint8_t* min_limit, uint8_t* max_limit,
+ uint8_t* nbits) {
+ size_t precision = 0;
+ size_t shortest_length = 255;
+ uint64_t freqsum = 0;
+ for (size_t i = 0; i < n; i++) {
+ assert(freqs[i] != 0);
+ freqsum += freqs[i];
+ if (min_limit[i] < 1) min_limit[i] = 1;
+ assert(min_limit[i] <= max_limit[i]);
+ precision = std::max<size_t>(max_limit[i], precision);
+ shortest_length = std::min<size_t>(min_limit[i], shortest_length);
+ }
+ // If all the minimum limits are greater than 1, shift precision so that we
+ // behave as if the shortest was 1.
+ precision -= shortest_length - 1;
+ uint64_t infty = freqsum * precision;
+ if (infty < std::numeric_limits<uint32_t>::max() / 2) {
+ ComputeCodeLengthsNonZeroImpl(freqs, n, precision,
+ static_cast<uint32_t>(infty), min_limit,
+ max_limit, nbits);
+ } else {
+ ComputeCodeLengthsNonZeroImpl(freqs, n, precision, infty, min_limit,
+ max_limit, nbits);
+ }
+ }
+
+ static constexpr size_t kMaxNumSymbols =
+ kNumRawSymbols + 1 < kNumLZ77 ? kNumLZ77 : kNumRawSymbols + 1;
+ static void ComputeCodeLengths(const uint64_t* freqs, size_t n,
+ const uint8_t* min_limit_in,
+ const uint8_t* max_limit_in, uint8_t* nbits) {
+ assert(n <= kMaxNumSymbols);
+ uint64_t compact_freqs[kMaxNumSymbols];
+ uint8_t min_limit[kMaxNumSymbols];
+ uint8_t max_limit[kMaxNumSymbols];
+ size_t ni = 0;
+ for (size_t i = 0; i < n; i++) {
+ if (freqs[i]) {
+ compact_freqs[ni] = freqs[i];
+ min_limit[ni] = min_limit_in[i];
+ max_limit[ni] = max_limit_in[i];
+ ni++;
+ }
+ }
+ uint8_t num_bits[kMaxNumSymbols] = {};
+ ComputeCodeLengthsNonZero(compact_freqs, ni, min_limit, max_limit,
+ num_bits);
+ ni = 0;
+ for (size_t i = 0; i < n; i++) {
+ nbits[i] = 0;
+ if (freqs[i]) {
+ nbits[i] = num_bits[ni++];
+ }
+ }
+ }
+
+ // Invalid code, used to construct arrays.
+ PrefixCode() {}
+
+ template <typename BitDepth>
+ PrefixCode(BitDepth, uint64_t* raw_counts, uint64_t* lz77_counts) {
+ // "merge" together all the lz77 counts in a single symbol for the level 1
+ // table (containing just the raw symbols, up to length 7).
+ uint64_t level1_counts[kNumRawSymbols + 1];
+ memcpy(level1_counts, raw_counts, kNumRawSymbols * sizeof(uint64_t));
+ size_t numraw = kNumRawSymbols;
+ while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--;
+
+ level1_counts[numraw] = 0;
+ for (size_t i = 0; i < kNumLZ77; i++) {
+ level1_counts[numraw] += lz77_counts[i];
+ }
+ uint8_t level1_nbits[kNumRawSymbols + 1] = {};
+ ComputeCodeLengths(level1_counts, numraw + 1, BitDepth::kMinRawLength,
+ BitDepth::kMaxRawLength, level1_nbits);
+
+ uint8_t level2_nbits[kNumLZ77] = {};
+ uint8_t min_lengths[kNumLZ77] = {};
+ uint8_t l = 15 - level1_nbits[numraw];
+ uint8_t max_lengths[kNumLZ77];
+ for (size_t i = 0; i < kNumLZ77; i++) {
+ max_lengths[i] = l;
+ }
+ size_t num_lz77 = kNumLZ77;
+ while (num_lz77 > 0 && lz77_counts[num_lz77 - 1] == 0) num_lz77--;
+ ComputeCodeLengths(lz77_counts, num_lz77, min_lengths, max_lengths,
+ level2_nbits);
+ for (size_t i = 0; i < numraw; i++) {
+ raw_nbits[i] = level1_nbits[i];
+ }
+ for (size_t i = 0; i < num_lz77; i++) {
+ lz77_nbits[i] =
+ level2_nbits[i] ? level1_nbits[numraw] + level2_nbits[i] : 0;
+ }
+
+ ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits,
+ kNumLZ77);
+ BitDepth::PrepareForSimd(raw_nbits, raw_bits, numraw, raw_nbits_simd,
+ raw_bits_simd);
+
+ // Prepare lz77 cache
+ for (size_t count = 0; count < kLZ77CacheSize; count++) {
+      uint32_t token, nbits, bits;
+ EncodeHybridUintLZ77(count, &token, &nbits, &bits);
+ lz77_cache_nbits[count] = lz77_nbits[token] + nbits + raw_nbits[0];
+ lz77_cache_bits[count] =
+ (((bits << lz77_nbits[token]) | lz77_bits[token]) << raw_nbits[0]) |
+ raw_bits[0];
+ }
+ }
+
+ void WriteTo(BitWriter* writer) const {
+ uint64_t code_length_counts[18] = {};
+ code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1);
+ for (size_t i = 0; i < kNumRawSymbols; i++) {
+ code_length_counts[raw_nbits[i]]++;
+ }
+ for (size_t i = 0; i < kNumLZ77; i++) {
+ code_length_counts[lz77_nbits[i]]++;
+ }
+ uint8_t code_length_nbits[18] = {};
+ uint8_t code_length_nbits_min[18] = {};
+ uint8_t code_length_nbits_max[18] = {
+ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
+ };
+ ComputeCodeLengths(code_length_counts, 18, code_length_nbits_min,
+ code_length_nbits_max, code_length_nbits);
+ writer->Write(2, 0b00); // HSKIP = 0, i.e. don't skip code lengths.
+
+ // As per Brotli RFC.
+ uint8_t code_length_order[18] = {1, 2, 3, 4, 0, 5, 17, 6, 16,
+ 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4};
+ uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15};
+
+ // Encode lengths of code lengths.
+ size_t num_code_lengths = 18;
+ while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) {
+ num_code_lengths--;
+ }
+ for (size_t i = 0; i < num_code_lengths; i++) {
+ int symbol = code_length_nbits[code_length_order[i]];
+ writer->Write(code_length_length_nbits[symbol],
+ code_length_length_bits[symbol]);
+ }
+
+ // Compute the canonical codes for the codes that represent the lengths of
+ // the actual codes for data.
+ uint16_t code_length_bits[18] = {};
+ ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits,
+ code_length_bits, 18);
+ // Encode raw bit code lengths.
+ for (size_t i = 0; i < kNumRawSymbols; i++) {
+ writer->Write(code_length_nbits[raw_nbits[i]],
+ code_length_bits[raw_nbits[i]]);
+ }
+ size_t num_lz77 = kNumLZ77;
+ while (lz77_nbits[num_lz77 - 1] == 0) {
+ num_lz77--;
+ }
+ // Encode 0s until 224 (start of LZ77 symbols). This is in total 224-19 =
+ // 205.
+ static_assert(kLZ77Offset == 224, "");
+ static_assert(kNumRawSymbols == 19, "");
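+    // Symbol 17 is the repeat-zero code with 3 extra bits; a first occurrence
+    // emits 3 + extra zeros and each further 17 extends the run to
+    // (previous - 2) * 8 + 3 + extra, hence 5, 27 and finally 205 zeros.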
+ writer->Write(code_length_nbits[17], code_length_bits[17]);
+ writer->Write(3, 0b010); // 5
+ writer->Write(code_length_nbits[17], code_length_bits[17]);
+ writer->Write(3, 0b000); // (5-2)*8 + 3 = 27
+ writer->Write(code_length_nbits[17], code_length_bits[17]);
+ writer->Write(3, 0b010); // (27-2)*8 + 5 = 205
+ // Encode LZ77 symbols, with values 224+i.
+ for (size_t i = 0; i < num_lz77; i++) {
+ writer->Write(code_length_nbits[lz77_nbits[i]],
+ code_length_bits[lz77_nbits[i]]);
+ }
+ }
+};
+
+template <typename T>
+struct VecPair {
+ T low;
+ T hi;
+};
+
+#ifdef FJXL_GENERIC_SIMD
+#undef FJXL_GENERIC_SIMD
+#endif
+
+#ifdef FJXL_AVX512
+#define FJXL_GENERIC_SIMD
+struct SIMDVec32;
+struct Mask32 {
+ __mmask16 mask;
+ SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
+ size_t CountPrefix() const {
+ return CtzNonZero(~uint64_t{_cvtmask16_u32(mask)});
+ }
+};
+
+struct SIMDVec32 {
+ __m512i vec;
+
+ static constexpr size_t kLanes = 16;
+
+ FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
+ return SIMDVec32{_mm512_loadu_si512((__m512i*)data)};
+ }
+ FJXL_INLINE void Store(uint32_t* data) {
+ _mm512_storeu_si512((__m512i*)data, vec);
+ }
+ FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
+ return SIMDVec32{_mm512_set1_epi32(v)};
+ }
+ FJXL_INLINE SIMDVec32 ValToToken() const {
+ return SIMDVec32{
+ _mm512_sub_epi32(_mm512_set1_epi32(32), _mm512_lzcnt_epi32(vec))};
+ }
+ FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{_mm512_sub_epi32(_mm512_max_epu32(vec, to_subtract.vec),
+ to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{_mm512_sub_epi32(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
+ return SIMDVec32{_mm512_add_epi32(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
+ return SIMDVec32{_mm512_xor_epi32(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
+ return Mask32{_mm512_cmpeq_epi32_mask(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
+ return Mask32{_mm512_cmpgt_epi32_mask(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Pow2() const {
+ return SIMDVec32{_mm512_sllv_epi32(_mm512_set1_epi32(1), vec)};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec32 SignedShiftRight() const {
+ return SIMDVec32{_mm512_srai_epi32(vec, i)};
+ }
+};
+
+struct SIMDVec16;
+
+struct Mask16 {
+ __mmask32 mask;
+ SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
+ Mask16 And(const Mask16& oth) const {
+ return Mask16{_kand_mask32(mask, oth.mask)};
+ }
+ size_t CountPrefix() const {
+ return CtzNonZero(~uint64_t{_cvtmask32_u32(mask)});
+ }
+};
+
+struct SIMDVec16 {
+ __m512i vec;
+
+ static constexpr size_t kLanes = 32;
+
+ FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
+ return SIMDVec16{_mm512_loadu_si512((__m512i*)data)};
+ }
+ FJXL_INLINE void Store(uint16_t* data) {
+ _mm512_storeu_si512((__m512i*)data, vec);
+ }
+ FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
+ return SIMDVec16{_mm512_set1_epi16(v)};
+ }
+ FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
+ const SIMDVec32& hi) {
+ auto tmp = _mm512_packus_epi32(lo.vec, hi.vec);
+ alignas(64) uint64_t perm[8] = {0, 2, 4, 6, 1, 3, 5, 7};
+ return SIMDVec16{
+ _mm512_permutex2var_epi64(tmp, _mm512_load_si512((__m512i*)perm), tmp)};
+ }
+
+ FJXL_INLINE SIMDVec16 ValToToken() const {
+ auto c16 = _mm512_set1_epi32(16);
+ auto c32 = _mm512_set1_epi32(32);
+ auto low16bit = _mm512_set1_epi32(0x0000FFFF);
+ auto lzhi =
+ _mm512_sub_epi32(c16, _mm512_min_epu32(c16, _mm512_lzcnt_epi32(vec)));
+ auto lzlo = _mm512_sub_epi32(
+ c32, _mm512_lzcnt_epi32(_mm512_and_si512(low16bit, vec)));
+ return SIMDVec16{_mm512_or_si512(lzlo, _mm512_slli_epi32(lzhi, 16))};
+ }
+
+ FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{_mm512_subs_epu16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{_mm512_sub_epi16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_add_epi16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_min_epu16(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
+ return Mask16{_mm512_cmpeq_epi16_mask(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
+ return Mask16{_mm512_cmpgt_epi16_mask(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Pow2() const {
+ return SIMDVec16{_mm512_sllv_epi16(_mm512_set1_epi16(1), vec)};
+ }
+ FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_or_si512(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_xor_si512(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_and_si512(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm512_srai_epi16(_mm512_add_epi16(vec, oth.vec), 1)};
+ }
+ FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
+ return SIMDVec16{_mm512_or_si512(vec, _mm512_set1_epi16(0xFF00))};
+ }
+ FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
+ return SIMDVec16{_mm512_shuffle_epi8(
+ _mm512_broadcast_i32x4(_mm_loadu_si128((__m128i*)table)), vec)};
+ }
+ FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
+ auto lo = _mm512_unpacklo_epi16(low.vec, vec);
+ auto hi = _mm512_unpackhi_epi16(low.vec, vec);
+ alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
+ alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
+ return {SIMDVec16{_mm512_permutex2var_epi64(
+ lo, _mm512_load_si512((__m512i*)perm1), hi)},
+ SIMDVec16{_mm512_permutex2var_epi64(
+ lo, _mm512_load_si512((__m512i*)perm2), hi)}};
+ }
+ FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
+ auto lo = _mm512_unpacklo_epi16(vec, _mm512_setzero_si512());
+ auto hi = _mm512_unpackhi_epi16(vec, _mm512_setzero_si512());
+ alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11};
+ alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15};
+ return {SIMDVec32{_mm512_permutex2var_epi64(
+ lo, _mm512_load_si512((__m512i*)perm1), hi)},
+ SIMDVec32{_mm512_permutex2var_epi64(
+ lo, _mm512_load_si512((__m512i*)perm2), hi)}};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec16 SignedShiftRight() const {
+ return SIMDVec16{_mm512_srai_epi16(vec, i)};
+ }
+
+ static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
+ __m256i bytes = _mm256_loadu_si256((__m256i*)data);
+ return {SIMDVec16{_mm512_cvtepu8_epi16(bytes)}};
+ }
+ static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
+ return {Load((const uint16_t*)data)};
+ }
+
+ static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
+ __m512i bytes = _mm512_loadu_si512((__m512i*)data);
+ __m512i gray = _mm512_and_si512(bytes, _mm512_set1_epi16(0xFF));
+ __m512i alpha = _mm512_srli_epi16(bytes, 8);
+ return {SIMDVec16{gray}, SIMDVec16{alpha}};
+ }
+ static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
+ __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
+ __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
+ __m512i g_mask = _mm512_set1_epi32(0xFFFF);
+ __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
+ __m512i g = _mm512_permutexvar_epi64(
+ permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, g_mask),
+ _mm512_and_si512(bytes2, g_mask)));
+ __m512i a = _mm512_permutexvar_epi64(
+ permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
+ _mm512_srli_epi32(bytes2, 16)));
+ return {SIMDVec16{g}, SIMDVec16{a}};
+ }
+
+ static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
+ __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
+ __m512i bytes1 =
+ _mm512_zextsi256_si512(_mm256_loadu_si256((__m256i*)(data + 64)));
+
+ // 0x7A = element of upper half of second vector = 0 after lookup; still in
+ // the upper half once we add 1 or 2.
+ uint8_t z = 0x7A;
+ __m512i ridx =
+ _mm512_set_epi8(z, 93, z, 90, z, 87, z, 84, z, 81, z, 78, z, 75, z, 72,
+ z, 69, z, 66, z, 63, z, 60, z, 57, z, 54, z, 51, z, 48,
+ z, 45, z, 42, z, 39, z, 36, z, 33, z, 30, z, 27, z, 24,
+ z, 21, z, 18, z, 15, z, 12, z, 9, z, 6, z, 3, z, 0);
+ __m512i gidx = _mm512_add_epi8(ridx, _mm512_set1_epi8(1));
+ __m512i bidx = _mm512_add_epi8(gidx, _mm512_set1_epi8(1));
+ __m512i r = _mm512_permutex2var_epi8(bytes0, ridx, bytes1);
+ __m512i g = _mm512_permutex2var_epi8(bytes0, gidx, bytes1);
+ __m512i b = _mm512_permutex2var_epi8(bytes0, bidx, bytes1);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
+ }
+ static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
+ __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
+ __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
+ __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
+
+ __m512i ridx_lo = _mm512_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63, 60, 57,
+ 54, 51, 48, 45, 42, 39, 36, 33, 30, 27,
+ 24, 21, 18, 15, 12, 9, 6, 3, 0);
+ // -1 is such that when adding 1 or 2, we get the correct index for
+ // green/blue.
+ __m512i ridx_hi =
+ _mm512_set_epi16(29, 26, 23, 20, 17, 14, 11, 8, 5, 2, -1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+ __m512i gidx_lo = _mm512_add_epi16(ridx_lo, _mm512_set1_epi16(1));
+ __m512i gidx_hi = _mm512_add_epi16(ridx_hi, _mm512_set1_epi16(1));
+ __m512i bidx_lo = _mm512_add_epi16(gidx_lo, _mm512_set1_epi16(1));
+ __m512i bidx_hi = _mm512_add_epi16(gidx_hi, _mm512_set1_epi16(1));
+
+ __mmask32 rmask = _cvtu32_mask32(0b11111111110000000000000000000000);
+ __mmask32 gbmask = _cvtu32_mask32(0b11111111111000000000000000000000);
+
+ __m512i rlo = _mm512_permutex2var_epi16(bytes0, ridx_lo, bytes1);
+ __m512i glo = _mm512_permutex2var_epi16(bytes0, gidx_lo, bytes1);
+ __m512i blo = _mm512_permutex2var_epi16(bytes0, bidx_lo, bytes1);
+ __m512i r = _mm512_mask_permutexvar_epi16(rlo, rmask, ridx_hi, bytes2);
+ __m512i g = _mm512_mask_permutexvar_epi16(glo, gbmask, gidx_hi, bytes2);
+ __m512i b = _mm512_mask_permutexvar_epi16(blo, gbmask, bidx_hi, bytes2);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}};
+ }
+
+ static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
+ __m512i bytes1 = _mm512_loadu_si512((__m512i*)data);
+ __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64));
+ __m512i rg_mask = _mm512_set1_epi32(0xFFFF);
+ __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
+ __m512i rg = _mm512_permutexvar_epi64(
+ permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, rg_mask),
+ _mm512_and_si512(bytes2, rg_mask)));
+ __m512i ba = _mm512_permutexvar_epi64(
+ permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16),
+ _mm512_srli_epi32(bytes2, 16)));
+ __m512i r = _mm512_and_si512(rg, _mm512_set1_epi16(0xFF));
+ __m512i g = _mm512_srli_epi16(rg, 8);
+ __m512i b = _mm512_and_si512(ba, _mm512_set1_epi16(0xFF));
+ __m512i a = _mm512_srli_epi16(ba, 8);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
+ }
+ static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
+ __m512i bytes0 = _mm512_loadu_si512((__m512i*)data);
+ __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64));
+ __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128));
+ __m512i bytes3 = _mm512_loadu_si512((__m512i*)(data + 192));
+
+ auto pack32 = [](__m512i a, __m512i b) {
+ __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
+ return _mm512_permutexvar_epi64(permuteidx, _mm512_packus_epi32(a, b));
+ };
+ auto packlow32 = [&pack32](__m512i a, __m512i b) {
+ __m512i mask = _mm512_set1_epi32(0xFFFF);
+ return pack32(_mm512_and_si512(a, mask), _mm512_and_si512(b, mask));
+ };
+ auto packhi32 = [&pack32](__m512i a, __m512i b) {
+ return pack32(_mm512_srli_epi32(a, 16), _mm512_srli_epi32(b, 16));
+ };
+
+ __m512i rb0 = packlow32(bytes0, bytes1);
+ __m512i rb1 = packlow32(bytes2, bytes3);
+ __m512i ga0 = packhi32(bytes0, bytes1);
+ __m512i ga1 = packhi32(bytes2, bytes3);
+
+ __m512i r = packlow32(rb0, rb1);
+ __m512i g = packlow32(ga0, ga1);
+ __m512i b = packhi32(rb0, rb1);
+ __m512i a = packhi32(ga0, ga1);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
+ }
+
+ void SwapEndian() {
+ auto indices = _mm512_broadcast_i32x4(
+ _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
+ vec = _mm512_shuffle_epi8(vec, indices);
+ }
+};
+
+SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
+ const SIMDVec16& if_false) {
+ return SIMDVec16{_mm512_mask_blend_epi16(mask, if_false.vec, if_true.vec)};
+}
+
+SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
+ const SIMDVec32& if_false) {
+ return SIMDVec32{_mm512_mask_blend_epi32(mask, if_false.vec, if_true.vec)};
+}
+
+struct Bits64 {
+ static constexpr size_t kLanes = 8;
+
+ __m512i nbits;
+ __m512i bits;
+
+ FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
+ _mm512_storeu_si512((__m512i*)nbits_out, nbits);
+ _mm512_storeu_si512((__m512i*)bits_out, bits);
+ }
+};
+
+struct Bits32 {
+ __m512i nbits;
+ __m512i bits;
+
+ static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
+ return Bits32{nbits.vec, bits.vec};
+ }
+
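+  // Packs adjacent (nbits, bits) pairs: each odd 32-bit lane is shifted left
+  // by the even lane's nbits and ORed in, so the earlier-written value stays
+  // in the low bits of the resulting 64-bit lane.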
+ Bits64 Merge() const {
+ auto nbits_hi32 = _mm512_srli_epi64(nbits, 32);
+ auto nbits_lo32 = _mm512_and_si512(nbits, _mm512_set1_epi64(0xFFFFFFFF));
+ auto bits_hi32 = _mm512_srli_epi64(bits, 32);
+ auto bits_lo32 = _mm512_and_si512(bits, _mm512_set1_epi64(0xFFFFFFFF));
+
+ auto nbits64 = _mm512_add_epi64(nbits_hi32, nbits_lo32);
+ auto bits64 =
+ _mm512_or_si512(_mm512_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
+ return Bits64{nbits64, bits64};
+ }
+
+ void Interleave(const Bits32& low) {
+ bits = _mm512_or_si512(_mm512_sllv_epi32(bits, low.nbits), low.bits);
+ nbits = _mm512_add_epi32(nbits, low.nbits);
+ }
+
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 16);
+ constexpr uint32_t kMask[32] = {
+ ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
+ ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ };
+ __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
+ nbits = _mm512_and_si512(mask, nbits);
+ bits = _mm512_and_si512(mask, bits);
+ }
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 16);
+ constexpr uint32_t kMask[32] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
+ ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
+ };
+ __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n));
+ nbits = _mm512_and_si512(mask, nbits);
+ bits = _mm512_and_si512(mask, bits);
+ }
+};
+
+struct Bits16 {
+ __m512i nbits;
+ __m512i bits;
+
+ static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
+ return Bits16{nbits.vec, bits.vec};
+ }
+
+ Bits32 Merge() const {
+ auto nbits_hi16 = _mm512_srli_epi32(nbits, 16);
+ auto nbits_lo16 = _mm512_and_si512(nbits, _mm512_set1_epi32(0xFFFF));
+ auto bits_hi16 = _mm512_srli_epi32(bits, 16);
+ auto bits_lo16 = _mm512_and_si512(bits, _mm512_set1_epi32(0xFFFF));
+
+ auto nbits32 = _mm512_add_epi32(nbits_hi16, nbits_lo16);
+ auto bits32 =
+ _mm512_or_si512(_mm512_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
+ return Bits32{nbits32, bits32};
+ }
+
+ void Interleave(const Bits16& low) {
+ bits = _mm512_or_si512(_mm512_sllv_epi16(bits, low.nbits), low.bits);
+ nbits = _mm512_add_epi16(nbits, low.nbits);
+ }
+
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 32);
+ constexpr uint16_t kMask[64] = {
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ };
+ __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
+ nbits = _mm512_and_si512(mask, nbits);
+ bits = _mm512_and_si512(mask, bits);
+ }
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 32);
+ constexpr uint16_t kMask[64] = {
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ };
+ __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n));
+ nbits = _mm512_and_si512(mask, nbits);
+ bits = _mm512_and_si512(mask, bits);
+ }
+};
+
+#endif
+
+#ifdef FJXL_AVX2
+#define FJXL_GENERIC_SIMD
+
+struct SIMDVec32;
+
+struct Mask32 {
+ __m256i mask;
+ SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
+ size_t CountPrefix() const {
+ return CtzNonZero(~static_cast<uint64_t>(
+ (uint8_t)_mm256_movemask_ps(_mm256_castsi256_ps(mask))));
+ }
+};
+
+struct SIMDVec32 {
+ __m256i vec;
+
+ static constexpr size_t kLanes = 8;
+
+ FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
+ return SIMDVec32{_mm256_loadu_si256((__m256i*)data)};
+ }
+ FJXL_INLINE void Store(uint32_t* data) {
+ _mm256_storeu_si256((__m256i*)data, vec);
+ }
+ FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
+ return SIMDVec32{_mm256_set1_epi32(v)};
+ }
+ FJXL_INLINE SIMDVec32 ValToToken() const {
+ // we know that each value has at most 20 bits, so we just need 5 nibbles
+ // and don't need to mask the fifth. However we do need to set the higher
+ // bytes to 0xFF, which will make table lookups return 0.
+ auto nibble0 =
+ _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi32(0xF)),
+ _mm256_set1_epi32(0xFFFFFF00));
+ auto nibble1 = _mm256_or_si256(
+ _mm256_and_si256(_mm256_srli_epi32(vec, 4), _mm256_set1_epi32(0xF)),
+ _mm256_set1_epi32(0xFFFFFF00));
+ auto nibble2 = _mm256_or_si256(
+ _mm256_and_si256(_mm256_srli_epi32(vec, 8), _mm256_set1_epi32(0xF)),
+ _mm256_set1_epi32(0xFFFFFF00));
+ auto nibble3 = _mm256_or_si256(
+ _mm256_and_si256(_mm256_srli_epi32(vec, 12), _mm256_set1_epi32(0xF)),
+ _mm256_set1_epi32(0xFFFFFF00));
+ auto nibble4 = _mm256_or_si256(_mm256_srli_epi32(vec, 16),
+ _mm256_set1_epi32(0xFFFFFF00));
+
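+    // Each lut below maps a nonzero nibble at position k to
+    // 4 * k + FloorLog2(nibble) + 1; taking the maximum over the five nibble
+    // positions therefore yields 32 - lzcnt(value), i.e. the bit width of the
+    // value (0 for 0).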
+ auto lut0 = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4));
+ auto lut1 = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8));
+ auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12));
+ auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16));
+ auto lut4 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 17, 18, 18, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20));
+
+ auto token0 = _mm256_shuffle_epi8(lut0, nibble0);
+ auto token1 = _mm256_shuffle_epi8(lut1, nibble1);
+ auto token2 = _mm256_shuffle_epi8(lut2, nibble2);
+ auto token3 = _mm256_shuffle_epi8(lut3, nibble3);
+ auto token4 = _mm256_shuffle_epi8(lut4, nibble4);
+
+ auto token =
+ _mm256_max_epi32(_mm256_max_epi32(_mm256_max_epi32(token0, token1),
+ _mm256_max_epi32(token2, token3)),
+ token4);
+ return SIMDVec32{token};
+ }
+ FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{_mm256_sub_epi32(_mm256_max_epu32(vec, to_subtract.vec),
+ to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{_mm256_sub_epi32(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
+ return SIMDVec32{_mm256_add_epi32(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
+ return SIMDVec32{_mm256_xor_si256(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Pow2() const {
+ return SIMDVec32{_mm256_sllv_epi32(_mm256_set1_epi32(1), vec)};
+ }
+ FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
+ return Mask32{_mm256_cmpeq_epi32(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
+ return Mask32{_mm256_cmpgt_epi32(vec, oth.vec)};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec32 SignedShiftRight() const {
+ return SIMDVec32{_mm256_srai_epi32(vec, i)};
+ }
+};
+
+struct SIMDVec16;
+
+struct Mask16 {
+ __m256i mask;
+ SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
+ Mask16 And(const Mask16& oth) const {
+ return Mask16{_mm256_and_si256(mask, oth.mask)};
+ }
+ size_t CountPrefix() const {
+ return CtzNonZero(
+ ~static_cast<uint64_t>((uint32_t)_mm256_movemask_epi8(mask))) /
+ 2;
+ }
+};
+
+struct SIMDVec16 {
+ __m256i vec;
+
+ static constexpr size_t kLanes = 16;
+
+ FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
+ return SIMDVec16{_mm256_loadu_si256((__m256i*)data)};
+ }
+ FJXL_INLINE void Store(uint16_t* data) {
+ _mm256_storeu_si256((__m256i*)data, vec);
+ }
+ FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
+ return SIMDVec16{_mm256_set1_epi16(v)};
+ }
+ FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
+ const SIMDVec32& hi) {
+ auto tmp = _mm256_packus_epi32(lo.vec, hi.vec);
+ return SIMDVec16{_mm256_permute4x64_epi64(tmp, 0b11011000)};
+ }
+
+ FJXL_INLINE SIMDVec16 ValToToken() const {
+ auto nibble0 =
+ _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi16(0xF)),
+ _mm256_set1_epi16(0xFF00));
+ auto nibble1 = _mm256_or_si256(
+ _mm256_and_si256(_mm256_srli_epi16(vec, 4), _mm256_set1_epi16(0xF)),
+ _mm256_set1_epi16(0xFF00));
+ auto nibble2 = _mm256_or_si256(
+ _mm256_and_si256(_mm256_srli_epi16(vec, 8), _mm256_set1_epi16(0xF)),
+ _mm256_set1_epi16(0xFF00));
+ auto nibble3 =
+ _mm256_or_si256(_mm256_srli_epi16(vec, 12), _mm256_set1_epi16(0xFF00));
+
+ auto lut0 = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4));
+ auto lut1 = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8));
+ auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12));
+ auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16));
+
+ auto token0 = _mm256_shuffle_epi8(lut0, nibble0);
+ auto token1 = _mm256_shuffle_epi8(lut1, nibble1);
+ auto token2 = _mm256_shuffle_epi8(lut2, nibble2);
+ auto token3 = _mm256_shuffle_epi8(lut3, nibble3);
+
+ auto token = _mm256_max_epi16(_mm256_max_epi16(token0, token1),
+ _mm256_max_epi16(token2, token3));
+ return SIMDVec16{token};
+ }
+
+ FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{_mm256_subs_epu16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{_mm256_sub_epi16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_add_epi16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_min_epu16(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
+ return Mask16{_mm256_cmpeq_epi16(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
+ return Mask16{_mm256_cmpgt_epi16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Pow2() const {
+ auto pow2_lo_lut = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
+ 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
+ auto pow2_hi_lut = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1 << 0, 1 << 1, 1 << 2, 1 << 3,
+ 1 << 4, 1 << 5, 1 << 6, 1u << 7));
+
+ auto masked = _mm256_or_si256(vec, _mm256_set1_epi16(0xFF00));
+
+ auto pow2_lo = _mm256_shuffle_epi8(pow2_lo_lut, masked);
+ auto pow2_hi = _mm256_shuffle_epi8(pow2_hi_lut, masked);
+
+ auto pow2 = _mm256_or_si256(_mm256_slli_epi16(pow2_hi, 8), pow2_lo);
+ return SIMDVec16{pow2};
+ }
+ FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_or_si256(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_xor_si256(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_and_si256(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
+ return SIMDVec16{_mm256_srai_epi16(_mm256_add_epi16(vec, oth.vec), 1)};
+ }
+ FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
+ return SIMDVec16{_mm256_or_si256(vec, _mm256_set1_epi16(0xFF00))};
+ }
+ FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
+ return SIMDVec16{_mm256_shuffle_epi8(
+ _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)table)), vec)};
+ }
+ FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
+ auto v02 = _mm256_unpacklo_epi16(low.vec, vec);
+ auto v13 = _mm256_unpackhi_epi16(low.vec, vec);
+ return {SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x20)},
+ SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x31)}};
+ }
+ FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
+ auto v02 = _mm256_unpacklo_epi16(vec, _mm256_setzero_si256());
+ auto v13 = _mm256_unpackhi_epi16(vec, _mm256_setzero_si256());
+ return {SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x20)},
+ SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x31)}};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec16 SignedShiftRight() const {
+ return SIMDVec16{_mm256_srai_epi16(vec, i)};
+ }
+
+ static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
+ __m128i bytes = _mm_loadu_si128((__m128i*)data);
+ return {SIMDVec16{_mm256_cvtepu8_epi16(bytes)}};
+ }
+ static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
+ return {Load((const uint16_t*)data)};
+ }
+
+ static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
+ __m256i bytes = _mm256_loadu_si256((__m256i*)data);
+ __m256i gray = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
+ __m256i alpha = _mm256_srli_epi16(bytes, 8);
+ return {SIMDVec16{gray}, SIMDVec16{alpha}};
+ }
+ static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
+ __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
+ __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
+ __m256i g_mask = _mm256_set1_epi32(0xFFFF);
+ __m256i g = _mm256_permute4x64_epi64(
+ _mm256_packus_epi32(_mm256_and_si256(bytes1, g_mask),
+ _mm256_and_si256(bytes2, g_mask)),
+ 0b11011000);
+ __m256i a = _mm256_permute4x64_epi64(
+ _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
+ _mm256_srli_epi32(bytes2, 16)),
+ 0b11011000);
+ return {SIMDVec16{g}, SIMDVec16{a}};
+ }
+
+ static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
+ __m128i bytes0 = _mm_loadu_si128((__m128i*)data);
+ __m128i bytes1 = _mm_loadu_si128((__m128i*)(data + 16));
+ __m128i bytes2 = _mm_loadu_si128((__m128i*)(data + 32));
+
+ __m128i idx =
+ _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13);
+
+ __m128i r6b5g5_0 = _mm_shuffle_epi8(bytes0, idx);
+ __m128i g6r5b5_1 = _mm_shuffle_epi8(bytes1, idx);
+ __m128i b6g5r5_2 = _mm_shuffle_epi8(bytes2, idx);
+
+ __m128i mask010 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF,
+ 0xFF, 0, 0, 0, 0, 0);
+ __m128i mask001 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF,
+ 0xFF, 0xFF, 0xFF);
+
+ __m128i b2g2b1 = _mm_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
+ __m128i b2b0b1 = _mm_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
+
+ __m128i r0r1b1 = _mm_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
+ __m128i r0r1r2 = _mm_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
+
+ __m128i g1r1g0 = _mm_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
+ __m128i g1g2g0 = _mm_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
+
+ __m128i g0g1g2 = _mm_alignr_epi8(g1g2g0, g1g2g0, 11);
+ __m128i b0b1b2 = _mm_alignr_epi8(b2b0b1, b2b0b1, 6);
+
+ return {SIMDVec16{_mm256_cvtepu8_epi16(r0r1r2)},
+ SIMDVec16{_mm256_cvtepu8_epi16(g0g1g2)},
+ SIMDVec16{_mm256_cvtepu8_epi16(b0b1b2)}};
+ }
+ static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
+ auto load_and_split_lohi = [](const unsigned char* data) {
+ // LHLHLH...
+ __m256i bytes = _mm256_loadu_si256((__m256i*)data);
+ // L0L0L0...
+ __m256i lo = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF));
+ // H0H0H0...
+ __m256i hi = _mm256_srli_epi16(bytes, 8);
+ // LLLLLLLLHHHHHHHHLLLLLLLLHHHHHHHH
+ __m256i packed = _mm256_packus_epi16(lo, hi);
+ return _mm256_permute4x64_epi64(packed, 0b11011000);
+ };
+ __m256i bytes0 = load_and_split_lohi(data);
+ __m256i bytes1 = load_and_split_lohi(data + 32);
+ __m256i bytes2 = load_and_split_lohi(data + 64);
+
+ __m256i idx = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13));
+
+ __m256i r6b5g5_0 = _mm256_shuffle_epi8(bytes0, idx);
+ __m256i g6r5b5_1 = _mm256_shuffle_epi8(bytes1, idx);
+ __m256i b6g5r5_2 = _mm256_shuffle_epi8(bytes2, idx);
+
+ __m256i mask010 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0));
+ __m256i mask001 = _mm256_broadcastsi128_si256(_mm_setr_epi8(
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF));
+
+ __m256i b2g2b1 = _mm256_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001);
+ __m256i b2b0b1 = _mm256_blendv_epi8(b2g2b1, r6b5g5_0, mask010);
+
+ __m256i r0r1b1 = _mm256_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010);
+ __m256i r0r1r2 = _mm256_blendv_epi8(r0r1b1, b6g5r5_2, mask001);
+
+ __m256i g1r1g0 = _mm256_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001);
+ __m256i g1g2g0 = _mm256_blendv_epi8(g1r1g0, b6g5r5_2, mask010);
+
+ __m256i g0g1g2 = _mm256_alignr_epi8(g1g2g0, g1g2g0, 11);
+ __m256i b0b1b2 = _mm256_alignr_epi8(b2b0b1, b2b0b1, 6);
+
+ // Now r0r1r2, g0g1g2, b0b1b2 have the low bytes of the RGB pixels in their
+ // lower half, and the high bytes in their upper half.
+
+ auto combine_low_hi = [](__m256i v) {
+ __m128i low = _mm256_extracti128_si256(v, 0);
+ __m128i hi = _mm256_extracti128_si256(v, 1);
+ __m256i low16 = _mm256_cvtepu8_epi16(low);
+ __m256i hi16 = _mm256_cvtepu8_epi16(hi);
+ return _mm256_or_si256(_mm256_slli_epi16(hi16, 8), low16);
+ };
+
+ return {SIMDVec16{combine_low_hi(r0r1r2)},
+ SIMDVec16{combine_low_hi(g0g1g2)},
+ SIMDVec16{combine_low_hi(b0b1b2)}};
+ }
+
+ static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
+ __m256i bytes1 = _mm256_loadu_si256((__m256i*)data);
+ __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32));
+ __m256i rg_mask = _mm256_set1_epi32(0xFFFF);
+ __m256i rg = _mm256_permute4x64_epi64(
+ _mm256_packus_epi32(_mm256_and_si256(bytes1, rg_mask),
+ _mm256_and_si256(bytes2, rg_mask)),
+ 0b11011000);
+ __m256i ba = _mm256_permute4x64_epi64(
+ _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16),
+ _mm256_srli_epi32(bytes2, 16)),
+ 0b11011000);
+ __m256i r = _mm256_and_si256(rg, _mm256_set1_epi16(0xFF));
+ __m256i g = _mm256_srli_epi16(rg, 8);
+ __m256i b = _mm256_and_si256(ba, _mm256_set1_epi16(0xFF));
+ __m256i a = _mm256_srli_epi16(ba, 8);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
+ }
+ static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
+ __m256i bytes0 = _mm256_loadu_si256((__m256i*)data);
+ __m256i bytes1 = _mm256_loadu_si256((__m256i*)(data + 32));
+ __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 64));
+ __m256i bytes3 = _mm256_loadu_si256((__m256i*)(data + 96));
+
+ auto pack32 = [](__m256i a, __m256i b) {
+ return _mm256_permute4x64_epi64(_mm256_packus_epi32(a, b), 0b11011000);
+ };
+ auto packlow32 = [&pack32](__m256i a, __m256i b) {
+ __m256i mask = _mm256_set1_epi32(0xFFFF);
+ return pack32(_mm256_and_si256(a, mask), _mm256_and_si256(b, mask));
+ };
+ auto packhi32 = [&pack32](__m256i a, __m256i b) {
+ return pack32(_mm256_srli_epi32(a, 16), _mm256_srli_epi32(b, 16));
+ };
+
+ __m256i rb0 = packlow32(bytes0, bytes1);
+ __m256i rb1 = packlow32(bytes2, bytes3);
+ __m256i ga0 = packhi32(bytes0, bytes1);
+ __m256i ga1 = packhi32(bytes2, bytes3);
+
+ __m256i r = packlow32(rb0, rb1);
+ __m256i g = packlow32(ga0, ga1);
+ __m256i b = packhi32(rb0, rb1);
+ __m256i a = packhi32(ga0, ga1);
+ return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}};
+ }
+
+ void SwapEndian() {
+ auto indices = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14));
+ vec = _mm256_shuffle_epi8(vec, indices);
+ }
+};
+
+SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
+ const SIMDVec16& if_false) {
+ return SIMDVec16{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
+}
+
+SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
+ const SIMDVec32& if_false) {
+ return SIMDVec32{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)};
+}
+
+struct Bits64 {
+ static constexpr size_t kLanes = 4;
+
+ __m256i nbits;
+ __m256i bits;
+
+ FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
+ _mm256_storeu_si256((__m256i*)nbits_out, nbits);
+ _mm256_storeu_si256((__m256i*)bits_out, bits);
+ }
+};
+
+struct Bits32 {
+ __m256i nbits;
+ __m256i bits;
+
+ static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
+ return Bits32{nbits.vec, bits.vec};
+ }
+
+ Bits64 Merge() const {
+ auto nbits_hi32 = _mm256_srli_epi64(nbits, 32);
+ auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF));
+ auto bits_hi32 = _mm256_srli_epi64(bits, 32);
+ auto bits_lo32 = _mm256_and_si256(bits, _mm256_set1_epi64x(0xFFFFFFFF));
+
+ auto nbits64 = _mm256_add_epi64(nbits_hi32, nbits_lo32);
+ auto bits64 =
+ _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32);
+ return Bits64{nbits64, bits64};
+ }
+
+ void Interleave(const Bits32& low) {
+ bits = _mm256_or_si256(_mm256_sllv_epi32(bits, low.nbits), low.bits);
+ nbits = _mm256_add_epi32(nbits, low.nbits);
+ }
+
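+ // Note (added): ClipTo keeps only the first n lanes; Skip zeroes the first
+ // n lanes. Both build their lane mask by loading an 8-lane window at offset
+ // 8 - n from a 16-entry constant table.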
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 8);
+ constexpr uint32_t kMask[16] = {
+ ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0,
+ };
+ __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
+ nbits = _mm256_and_si256(mask, nbits);
+ bits = _mm256_and_si256(mask, bits);
+ }
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 8);
+ constexpr uint32_t kMask[16] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u,
+ };
+ __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n));
+ nbits = _mm256_and_si256(mask, nbits);
+ bits = _mm256_and_si256(mask, bits);
+ }
+};
+
+struct Bits16 {
+ __m256i nbits;
+ __m256i bits;
+
+ static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
+ return Bits16{nbits.vec, bits.vec};
+ }
+
+ Bits32 Merge() const {
+ auto nbits_hi16 = _mm256_srli_epi32(nbits, 16);
+ auto nbits_lo16 = _mm256_and_si256(nbits, _mm256_set1_epi32(0xFFFF));
+ auto bits_hi16 = _mm256_srli_epi32(bits, 16);
+ auto bits_lo16 = _mm256_and_si256(bits, _mm256_set1_epi32(0xFFFF));
+
+ auto nbits32 = _mm256_add_epi32(nbits_hi16, nbits_lo16);
+ auto bits32 =
+ _mm256_or_si256(_mm256_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16);
+ return Bits32{nbits32, bits32};
+ }
+
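+ // Note (added): low.nbits is at most 7 here (Huffman lengths are capped
+ // below 8 for this path), so 1 << low.nbits can be fetched with a byte
+ // shuffle; OR-ing 0xFF00 into the index makes the shuffle return 0 for the
+ // high byte of each lane, and the 16-bit multiply below then acts as a
+ // shift left by low.nbits.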
+ void Interleave(const Bits16& low) {
+ auto pow2_lo_lut = _mm256_broadcastsi128_si256(
+ _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6,
+ 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0));
+ auto low_nbits_masked =
+ _mm256_or_si256(low.nbits, _mm256_set1_epi16(0xFF00));
+
+ auto bits_shifted = _mm256_mullo_epi16(
+ bits, _mm256_shuffle_epi8(pow2_lo_lut, low_nbits_masked));
+
+ nbits = _mm256_add_epi16(nbits, low.nbits);
+ bits = _mm256_or_si256(bits_shifted, low.bits);
+ }
+
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 16);
+ constexpr uint16_t kMask[32] = {
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ };
+ __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
+ nbits = _mm256_and_si256(mask, nbits);
+ bits = _mm256_and_si256(mask, bits);
+ }
+
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 16);
+ constexpr uint16_t kMask[32] = {
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ };
+ __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n));
+ nbits = _mm256_and_si256(mask, nbits);
+ bits = _mm256_and_si256(mask, bits);
+ }
+};
+
+#endif
+
+#ifdef FJXL_NEON
+#define FJXL_GENERIC_SIMD
+
+struct SIMDVec32;
+
+struct Mask32 {
+ uint32x4_t mask;
+ SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false);
+ Mask32 And(const Mask32& oth) const {
+ return Mask32{vandq_u32(mask, oth.mask)};
+ }
+ size_t CountPrefix() const {
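+ // Note (added): lanes where the mask is set map to 4 (the lane count),
+ // unset lanes map to their own index, so the horizontal minimum is the
+ // length of the leading run of set lanes.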
+ uint32_t val_unset[4] = {0, 1, 2, 3};
+ uint32_t val_set[4] = {4, 4, 4, 4};
+ uint32x4_t val = vbslq_u32(mask, vld1q_u32(val_set), vld1q_u32(val_unset));
+ return vminvq_u32(val);
+ }
+};
+
+struct SIMDVec32 {
+ uint32x4_t vec;
+
+ static constexpr size_t kLanes = 4;
+
+ FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) {
+ return SIMDVec32{vld1q_u32(data)};
+ }
+ FJXL_INLINE void Store(uint32_t* data) { vst1q_u32(data, vec); }
+ FJXL_INLINE static SIMDVec32 Val(uint32_t v) {
+ return SIMDVec32{vdupq_n_u32(v)};
+ }
+ FJXL_INLINE SIMDVec32 ValToToken() const {
+ return SIMDVec32{vsubq_u32(vdupq_n_u32(32), vclzq_u32(vec))};
+ }
+ FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{vqsubq_u32(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const {
+ return SIMDVec32{vsubq_u32(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const {
+ return SIMDVec32{vaddq_u32(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const {
+ return SIMDVec32{veorq_u32(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec32 Pow2() const {
+ return SIMDVec32{vshlq_u32(vdupq_n_u32(1), vreinterpretq_s32_u32(vec))};
+ }
+ FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const {
+ return Mask32{vceqq_u32(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const {
+ return Mask32{
+ vcgtq_s32(vreinterpretq_s32_u32(vec), vreinterpretq_s32_u32(oth.vec))};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec32 SignedShiftRight() const {
+ return SIMDVec32{
+ vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(vec), i))};
+ }
+};
+
+struct SIMDVec16;
+
+struct Mask16 {
+ uint16x8_t mask;
+ SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false);
+ Mask16 And(const Mask16& oth) const {
+ return Mask16{vandq_u16(mask, oth.mask)};
+ }
+ size_t CountPrefix() const {
+ uint16_t val_unset[8] = {0, 1, 2, 3, 4, 5, 6, 7};
+ uint16_t val_set[8] = {8, 8, 8, 8, 8, 8, 8, 8};
+ uint16x8_t val = vbslq_u16(mask, vld1q_u16(val_set), vld1q_u16(val_unset));
+ return vminvq_u16(val);
+ }
+};
+
+struct SIMDVec16 {
+ uint16x8_t vec;
+
+ static constexpr size_t kLanes = 8;
+
+ FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) {
+ return SIMDVec16{vld1q_u16(data)};
+ }
+ FJXL_INLINE void Store(uint16_t* data) { vst1q_u16(data, vec); }
+ FJXL_INLINE static SIMDVec16 Val(uint16_t v) {
+ return SIMDVec16{vdupq_n_u16(v)};
+ }
+ FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo,
+ const SIMDVec32& hi) {
+ return SIMDVec16{vmovn_high_u32(vmovn_u32(lo.vec), hi.vec)};
+ }
+
+ FJXL_INLINE SIMDVec16 ValToToken() const {
+ return SIMDVec16{vsubq_u16(vdupq_n_u16(16), vclzq_u16(vec))};
+ }
+ FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{vqsubq_u16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const {
+ return SIMDVec16{vsubq_u16(vec, to_subtract.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const {
+ return SIMDVec16{vaddq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const {
+ return SIMDVec16{vminq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const {
+ return Mask16{vceqq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const {
+ return Mask16{
+ vcgtq_s16(vreinterpretq_s16_u16(vec), vreinterpretq_s16_u16(oth.vec))};
+ }
+ FJXL_INLINE SIMDVec16 Pow2() const {
+ return SIMDVec16{vshlq_u16(vdupq_n_u16(1), vreinterpretq_s16_u16(vec))};
+ }
+ FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const {
+ return SIMDVec16{vorrq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const {
+ return SIMDVec16{veorq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const {
+ return SIMDVec16{vandq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const {
+ return SIMDVec16{vhaddq_u16(vec, oth.vec)};
+ }
+ FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const {
+ return SIMDVec16{vorrq_u16(vec, vdupq_n_u16(0xFF00))};
+ }
+ FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const {
+ uint8x16_t tbl = vld1q_u8(table);
+ uint8x16_t indices = vreinterpretq_u8_u16(vec);
+ return SIMDVec16{vreinterpretq_u16_u8(vqtbl1q_u8(tbl, indices))};
+ }
+ FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const {
+ return {SIMDVec16{vzip1q_u16(low.vec, vec)},
+ SIMDVec16{vzip2q_u16(low.vec, vec)}};
+ }
+ FJXL_INLINE VecPair<SIMDVec32> Upcast() const {
+ uint32x4_t lo = vmovl_u16(vget_low_u16(vec));
+ uint32x4_t hi = vmovl_high_u16(vec);
+ return {SIMDVec32{lo}, SIMDVec32{hi}};
+ }
+ template <size_t i>
+ FJXL_INLINE SIMDVec16 SignedShiftRight() const {
+ return SIMDVec16{
+ vreinterpretq_u16_s16(vshrq_n_s16(vreinterpretq_s16_u16(vec), i))};
+ }
+
+ static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) {
+ uint8x8_t v = vld1_u8(data);
+ return {SIMDVec16{vmovl_u8(v)}};
+ }
+ static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) {
+ return {Load((const uint16_t*)data)};
+ }
+
+ static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) {
+ uint8x8x2_t v = vld2_u8(data);
+ return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}};
+ }
+ static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) {
+ uint16x8x2_t v = vld2q_u16((const uint16_t*)data);
+ return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}};
+ }
+
+ static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) {
+ uint8x8x3_t v = vld3_u8(data);
+ return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
+ SIMDVec16{vmovl_u8(v.val[2])}};
+ }
+ static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) {
+ uint16x8x3_t v = vld3q_u16((const uint16_t*)data);
+ return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}};
+ }
+
+ static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) {
+ uint8x8x4_t v = vld4_u8(data);
+ return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])},
+ SIMDVec16{vmovl_u8(v.val[2])}, SIMDVec16{vmovl_u8(v.val[3])}};
+ }
+ static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) {
+ uint16x8x4_t v = vld4q_u16((const uint16_t*)data);
+ return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]},
+ SIMDVec16{v.val[3]}};
+ }
+
+ void SwapEndian() {
+ vec = vreinterpretq_u16_u8(vrev16q_u8(vreinterpretq_u8_u16(vec)));
+ }
+};
+
+SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true,
+ const SIMDVec16& if_false) {
+ return SIMDVec16{vbslq_u16(mask, if_true.vec, if_false.vec)};
+}
+
+SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true,
+ const SIMDVec32& if_false) {
+ return SIMDVec32{vbslq_u32(mask, if_true.vec, if_false.vec)};
+}
+
+struct Bits64 {
+ static constexpr size_t kLanes = 2;
+
+ uint64x2_t nbits;
+ uint64x2_t bits;
+
+ FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) {
+ vst1q_u64(nbits_out, nbits);
+ vst1q_u64(bits_out, bits);
+ }
+};
+
+struct Bits32 {
+ uint32x4_t nbits;
+ uint32x4_t bits;
+
+ static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) {
+ return Bits32{nbits.vec, bits.vec};
+ }
+
+ Bits64 Merge() const {
+ // TODO(veluca): can probably be optimized.
+ uint64x2_t nbits_lo32 =
+ vandq_u64(vreinterpretq_u64_u32(nbits), vdupq_n_u64(0xFFFFFFFF));
+ uint64x2_t bits_hi32 =
+ vshlq_u64(vshrq_n_u64(vreinterpretq_u64_u32(bits), 32),
+ vreinterpretq_s64_u64(nbits_lo32));
+ uint64x2_t bits_lo32 =
+ vandq_u64(vreinterpretq_u64_u32(bits), vdupq_n_u64(0xFFFFFFFF));
+ uint64x2_t nbits64 =
+ vsraq_n_u64(nbits_lo32, vreinterpretq_u64_u32(nbits), 32);
+ uint64x2_t bits64 = vorrq_u64(bits_hi32, bits_lo32);
+ return Bits64{nbits64, bits64};
+ }
+
+ void Interleave(const Bits32& low) {
+ bits =
+ vorrq_u32(vshlq_u32(bits, vreinterpretq_s32_u32(low.nbits)), low.bits);
+ nbits = vaddq_u32(nbits, low.nbits);
+ }
+
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 4);
+ constexpr uint32_t kMask[8] = {
+ ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0,
+ };
+ uint32x4_t mask = vld1q_u32(kMask + 4 - n);
+ nbits = vandq_u32(mask, nbits);
+ bits = vandq_u32(mask, bits);
+ }
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 4);
+ constexpr uint32_t kMask[8] = {
+ 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u,
+ };
+ uint32x4_t mask = vld1q_u32(kMask + 4 - n);
+ nbits = vandq_u32(mask, nbits);
+ bits = vandq_u32(mask, bits);
+ }
+};
+
+struct Bits16 {
+ uint16x8_t nbits;
+ uint16x8_t bits;
+
+ static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) {
+ return Bits16{nbits.vec, bits.vec};
+ }
+
+ Bits32 Merge() const {
+ // TODO(veluca): can probably be optimized.
+ uint32x4_t nbits_lo16 =
+ vandq_u32(vreinterpretq_u32_u16(nbits), vdupq_n_u32(0xFFFF));
+ uint32x4_t bits_hi16 =
+ vshlq_u32(vshrq_n_u32(vreinterpretq_u32_u16(bits), 16),
+ vreinterpretq_s32_u32(nbits_lo16));
+ uint32x4_t bits_lo16 =
+ vandq_u32(vreinterpretq_u32_u16(bits), vdupq_n_u32(0xFFFF));
+ uint32x4_t nbits32 =
+ vsraq_n_u32(nbits_lo16, vreinterpretq_u32_u16(nbits), 16);
+ uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16);
+ return Bits32{nbits32, bits32};
+ }
+
+ void Interleave(const Bits16& low) {
+ bits =
+ vorrq_u16(vshlq_u16(bits, vreinterpretq_s16_u16(low.nbits)), low.bits);
+ nbits = vaddq_u16(nbits, low.nbits);
+ }
+
+ void ClipTo(size_t n) {
+ n = std::min<size_t>(n, 8);
+ constexpr uint16_t kMask[16] = {
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ };
+ uint16x8_t mask = vld1q_u16(kMask + 8 - n);
+ nbits = vandq_u16(mask, nbits);
+ bits = vandq_u16(mask, bits);
+ }
+ void Skip(size_t n) {
+ n = std::min<size_t>(n, 8);
+ constexpr uint16_t kMask[16] = {
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
+ };
+ uint16x8_t mask = vld1q_u16(kMask + 8 - n);
+ nbits = vandq_u16(mask, nbits);
+ bits = vandq_u16(mask, bits);
+ }
+};
+
+#endif
+
+#ifdef FJXL_GENERIC_SIMD
+constexpr size_t SIMDVec32::kLanes;
+constexpr size_t SIMDVec16::kLanes;
+
+// Each of these functions will process SIMDVec16::kLanes worth of values.
+
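+// Note (added): vectorized hybrid-uint tokenization with configuration
+// (0, 0, 0): token = 1 + floor(log2(res)) for res > 0 (0 for res == 0),
+// nbits = token - 1, and bits = res minus its leading power of two. The
+// saturating subtractions keep res == 0 mapped to all-zero outputs.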
+FJXL_INLINE void TokenizeSIMD(const uint16_t* residuals, uint16_t* token_out,
+ uint16_t* nbits_out, uint16_t* bits_out) {
+ SIMDVec16 res = SIMDVec16::Load(residuals);
+ SIMDVec16 token = res.ValToToken();
+ SIMDVec16 nbits = token.SatSubU(SIMDVec16::Val(1));
+ SIMDVec16 bits = res.SatSubU(nbits.Pow2());
+ token.Store(token_out);
+ nbits.Store(nbits_out);
+ bits.Store(bits_out);
+}
+
+FJXL_INLINE void TokenizeSIMD(const uint32_t* residuals, uint16_t* token_out,
+ uint32_t* nbits_out, uint32_t* bits_out) {
+ static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
+ SIMDVec32 res_lo = SIMDVec32::Load(residuals);
+ SIMDVec32 res_hi = SIMDVec32::Load(residuals + SIMDVec32::kLanes);
+ SIMDVec32 token_lo = res_lo.ValToToken();
+ SIMDVec32 token_hi = res_hi.ValToToken();
+ SIMDVec32 nbits_lo = token_lo.SatSubU(SIMDVec32::Val(1));
+ SIMDVec32 nbits_hi = token_hi.SatSubU(SIMDVec32::Val(1));
+ SIMDVec32 bits_lo = res_lo.SatSubU(nbits_lo.Pow2());
+ SIMDVec32 bits_hi = res_hi.SatSubU(nbits_hi.Pow2());
+ SIMDVec16 token = SIMDVec16::FromTwo32(token_lo, token_hi);
+ token.Store(token_out);
+ nbits_lo.Store(nbits_out);
+ nbits_hi.Store(nbits_out + SIMDVec32::kLanes);
+ bits_lo.Store(bits_out);
+ bits_hi.Store(bits_out + SIMDVec32::kLanes);
+}
+
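+// Note (added): for bit depths up to 13 there are at most 16 raw symbols, so
+// both nbits and bits come from a single 16-byte table lookup;
+// PrepareForU8Lookup sets the high byte of every 16-bit lane so that the byte
+// shuffle only reads the low (token) byte and yields zero for the high byte.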
+FJXL_INLINE void HuffmanSIMDUpTo13(const uint16_t* tokens,
+ const PrefixCode& code, uint16_t* nbits_out,
+ uint16_t* bits_out) {
+ SIMDVec16 tok = SIMDVec16::Load(tokens).PrepareForU8Lookup();
+ tok.U8Lookup(code.raw_nbits_simd).Store(nbits_out);
+ tok.U8Lookup(code.raw_bits_simd).Store(bits_out);
+}
+
+FJXL_INLINE void HuffmanSIMD14(const uint16_t* tokens, const PrefixCode& code,
+ uint16_t* nbits_out, uint16_t* bits_out) {
+ SIMDVec16 token_cap = SIMDVec16::Val(15);
+ SIMDVec16 tok = SIMDVec16::Load(tokens);
+ SIMDVec16 tok_index = tok.Min(token_cap).PrepareForU8Lookup();
+ SIMDVec16 huff_bits_pre = tok_index.U8Lookup(code.raw_bits_simd);
+ // Set the highest bit when token == 16; the Huffman code is constructed in
+ // such a way that the code for token 15 is the same as the code for 16,
+ // except for the highest bit.
+ Mask16 needs_high_bit = tok.Eq(SIMDVec16::Val(16));
+ SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
+ huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
+ huff_bits.Store(bits_out);
+ tok_index.U8Lookup(code.raw_nbits_simd).Store(nbits_out);
+}
+
+FJXL_INLINE void HuffmanSIMDAbove14(const uint16_t* tokens,
+ const PrefixCode& code, uint16_t* nbits_out,
+ uint16_t* bits_out) {
+ SIMDVec16 tok = SIMDVec16::Load(tokens);
+ // We assume `tok` fits in a *signed* 16-bit integer.
+ Mask16 above = tok.Gt(SIMDVec16::Val(12));
+ // 13, 14 -> 13
+ // 15, 16 -> 14
+ // 17, 18 -> 15
+ SIMDVec16 remap_tok = above.IfThenElse(tok.HAdd(SIMDVec16::Val(13)), tok);
+ SIMDVec16 tok_index = remap_tok.PrepareForU8Lookup();
+ SIMDVec16 huff_bits_pre = tok_index.U8Lookup(code.raw_bits_simd);
+ // Set the highest bit when token == 14, 16, 18.
+ Mask16 needs_high_bit = above.And(tok.Eq(tok.And(SIMDVec16::Val(0xFFFE))));
+ SIMDVec16 huff_bits = needs_high_bit.IfThenElse(
+ huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre);
+ huff_bits.Store(bits_out);
+ tok_index.U8Lookup(code.raw_nbits_simd).Store(nbits_out);
+}
+
+FJXL_INLINE void StoreSIMDUpTo8(const uint16_t* nbits_tok,
+ const uint16_t* bits_tok,
+ const uint16_t* nbits_huff,
+ const uint16_t* bits_huff, size_t n,
+ size_t skip, Bits32* bits_out) {
+ Bits16 bits =
+ Bits16::FromRaw(SIMDVec16::Load(nbits_tok), SIMDVec16::Load(bits_tok));
+ Bits16 huff_bits =
+ Bits16::FromRaw(SIMDVec16::Load(nbits_huff), SIMDVec16::Load(bits_huff));
+ bits.Interleave(huff_bits);
+ bits.ClipTo(n);
+ bits.Skip(skip);
+ bits_out[0] = bits.Merge();
+}
+
+// Huffman and raw bits don't necessarily fit in a single u16 here.
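+// Note (added): after interleaving, each residual occupies two adjacent
+// 16-bit lanes (one holding the Huffman bits, one the raw bits), so
+// ClipTo/Skip operate on 2 * n and 2 * skip lanes and two Bits32 outputs are
+// produced.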
+FJXL_INLINE void StoreSIMDUpTo14(const uint16_t* nbits_tok,
+ const uint16_t* bits_tok,
+ const uint16_t* nbits_huff,
+ const uint16_t* bits_huff, size_t n,
+ size_t skip, Bits32* bits_out) {
+ VecPair<SIMDVec16> bits =
+ SIMDVec16::Load(bits_tok).Interleave(SIMDVec16::Load(bits_huff));
+ VecPair<SIMDVec16> nbits =
+ SIMDVec16::Load(nbits_tok).Interleave(SIMDVec16::Load(nbits_huff));
+ Bits16 low = Bits16::FromRaw(nbits.low, bits.low);
+ Bits16 hi = Bits16::FromRaw(nbits.hi, bits.hi);
+ low.ClipTo(2 * n);
+ low.Skip(2 * skip);
+ hi.ClipTo(std::max(2 * n, SIMDVec16::kLanes) - SIMDVec16::kLanes);
+ hi.Skip(std::max(2 * skip, SIMDVec16::kLanes) - SIMDVec16::kLanes);
+
+ bits_out[0] = low.Merge();
+ bits_out[1] = hi.Merge();
+}
+
+FJXL_INLINE void StoreSIMDAbove14(const uint32_t* nbits_tok,
+ const uint32_t* bits_tok,
+ const uint16_t* nbits_huff,
+ const uint16_t* bits_huff, size_t n,
+ size_t skip, Bits32* bits_out) {
+ static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, "");
+ Bits32 bits_low =
+ Bits32::FromRaw(SIMDVec32::Load(nbits_tok), SIMDVec32::Load(bits_tok));
+ Bits32 bits_hi =
+ Bits32::FromRaw(SIMDVec32::Load(nbits_tok + SIMDVec32::kLanes),
+ SIMDVec32::Load(bits_tok + SIMDVec32::kLanes));
+
+ VecPair<SIMDVec32> huff_bits = SIMDVec16::Load(bits_huff).Upcast();
+ VecPair<SIMDVec32> huff_nbits = SIMDVec16::Load(nbits_huff).Upcast();
+
+ Bits32 huff_low = Bits32::FromRaw(huff_nbits.low, huff_bits.low);
+ Bits32 huff_hi = Bits32::FromRaw(huff_nbits.hi, huff_bits.hi);
+
+ bits_low.Interleave(huff_low);
+ bits_low.ClipTo(n);
+ bits_low.Skip(skip);
+ bits_out[0] = bits_low;
+ bits_hi.Interleave(huff_hi);
+ bits_hi.ClipTo(std::max(n, SIMDVec32::kLanes) - SIMDVec32::kLanes);
+ bits_hi.Skip(std::max(skip, SIMDVec32::kLanes) - SIMDVec32::kLanes);
+ bits_out[1] = bits_hi;
+}
+
+#ifdef FJXL_AVX512
+FJXL_INLINE void StoreToWriterAVX512(const Bits32& bits32, BitWriter& output) {
+ __m512i bits = bits32.bits;
+ __m512i nbits = bits32.nbits;
+
+ // Insert the leftover bits from the bit buffer at the bottom of the vector
+ // and extract the top of the vector.
+ uint64_t trail_bits =
+ _mm512_cvtsi512_si32(_mm512_alignr_epi32(bits, bits, 15));
+ uint64_t trail_nbits =
+ _mm512_cvtsi512_si32(_mm512_alignr_epi32(nbits, nbits, 15));
+ __m512i lead_bits = _mm512_set1_epi32(output.buffer);
+ __m512i lead_nbits = _mm512_set1_epi32(output.bits_in_buffer);
+ bits = _mm512_alignr_epi32(bits, lead_bits, 15);
+ nbits = _mm512_alignr_epi32(nbits, lead_nbits, 15);
+
+ // Merge 32 -> 64 bits.
+ Bits32 b{nbits, bits};
+ Bits64 b64 = b.Merge();
+ bits = b64.bits;
+ nbits = b64.nbits;
+
+ __m512i zero = _mm512_setzero_si512();
+
+ auto sh1 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 7); };
+ auto sh2 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 6); };
+ auto sh4 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 4); };
+
+ // Compute first-past-end-bit-position.
+ __m512i end_interm0 = _mm512_add_epi64(nbits, sh1(nbits));
+ __m512i end_interm1 = _mm512_add_epi64(end_interm0, sh2(end_interm0));
+ __m512i end = _mm512_add_epi64(end_interm1, sh4(end_interm1));
+
+ uint64_t simd_nbits = _mm512_cvtsi512_si32(_mm512_alignr_epi64(end, end, 7));
+
+ // Compute begin-bit-position.
+ __m512i begin = _mm512_sub_epi64(end, nbits);
+
+ // Index of the last bit in the chunk, or the end bit if nbits==0.
+ __m512i last = _mm512_mask_sub_epi64(
+ end, _mm512_cmpneq_epi64_mask(nbits, zero), end, _mm512_set1_epi64(1));
+
+ __m512i lane_offset_mask = _mm512_set1_epi64(63);
+
+ // Starting position of the chunk that each lane will ultimately belong to.
+ __m512i chunk_start = _mm512_andnot_si512(lane_offset_mask, last);
+
+ // For all lanes that contain bits belonging to two different 64-bit chunks,
+ // compute the number of bits that belong to the first chunk.
+ // The total number of bits fits in a u16, so a saturating u16 subtraction
+ // is safe here.
+ __m512i first_chunk_nbits = _mm512_subs_epu16(chunk_start, begin);
+
+ // Move all the previous-chunk-bits to the previous lane.
+ __m512i negnbits = _mm512_sub_epi64(_mm512_set1_epi64(64), first_chunk_nbits);
+ __m512i first_chunk_bits =
+ _mm512_srlv_epi64(_mm512_sllv_epi64(bits, negnbits), negnbits);
+ __m512i first_chunk_bits_down =
+ _mm512_alignr_epi32(zero, first_chunk_bits, 2);
+ bits = _mm512_srlv_epi64(bits, first_chunk_nbits);
+ nbits = _mm512_sub_epi64(nbits, first_chunk_nbits);
+ bits = _mm512_or_si512(bits, _mm512_sllv_epi64(first_chunk_bits_down, nbits));
+ begin = _mm512_add_epi64(begin, first_chunk_nbits);
+
+ // We now know that every lane should give bits to only one chunk. We can
+ // shift the bits and then horizontally-or-reduce them within the same chunk.
+ __m512i offset = _mm512_and_si512(begin, lane_offset_mask);
+ __m512i aligned_bits = _mm512_sllv_epi64(bits, offset);
+ // h-or-reduce within same chunk
+ __m512i red0 = _mm512_mask_or_epi64(
+ aligned_bits, _mm512_cmpeq_epi64_mask(sh1(chunk_start), chunk_start),
+ sh1(aligned_bits), aligned_bits);
+ __m512i red1 = _mm512_mask_or_epi64(
+ red0, _mm512_cmpeq_epi64_mask(sh2(chunk_start), chunk_start), sh2(red0),
+ red0);
+ __m512i reduced = _mm512_mask_or_epi64(
+ red1, _mm512_cmpeq_epi64_mask(sh4(chunk_start), chunk_start), sh4(red1),
+ red1);
+ // Extract the highest lane that belongs to each chunk (the lane that ends up
+ // with the OR-ed value of all the other lanes of that chunk).
+ __m512i next_chunk_start =
+ _mm512_alignr_epi32(_mm512_set1_epi64(~0), chunk_start, 2);
+ __m512i result = _mm512_maskz_compress_epi64(
+ _mm512_cmpneq_epi64_mask(chunk_start, next_chunk_start), reduced);
+
+ _mm512_storeu_si512((__m512i*)(output.data.get() + output.bytes_written),
+ result);
+
+ // Update the bit writer and add the last 32-bit lane.
+ // Note that since trail_nbits was at most 32 to begin with, operating on
+ // trail_bits does not risk overflowing.
+ output.bytes_written += simd_nbits / 8;
+ // Here we are implicitly relying on the fact that simd_nbits < 512 to know
+ // that the byte of bit writer data we access is initialized. This is
+ // guaranteed because the remaining bits in the bit writer's buffer are at
+ // most 7, so simd_nbits <= 505 always.
+ trail_bits = (trail_bits << (simd_nbits % 8)) +
+ output.data.get()[output.bytes_written];
+ trail_nbits += simd_nbits % 8;
+ StoreLE64(output.data.get() + output.bytes_written, trail_bits);
+ size_t trail_bytes = trail_nbits / 8;
+ output.bits_in_buffer = trail_nbits % 8;
+ output.buffer = trail_bits >> (trail_bytes * 8);
+ output.bytes_written += trail_bytes;
+}
+
+#endif
+
+template <size_t n>
+FJXL_INLINE void StoreToWriter(const Bits32* bits, BitWriter& output) {
+#ifdef FJXL_AVX512
+ static_assert(n <= 2, "");
+ StoreToWriterAVX512(bits[0], output);
+ if (n == 2) {
+ StoreToWriterAVX512(bits[1], output);
+ }
+ return;
+#endif
+ static_assert(n <= 4, "");
+ alignas(64) uint64_t nbits64[Bits64::kLanes * n];
+ alignas(64) uint64_t bits64[Bits64::kLanes * n];
+ bits[0].Merge().Store(nbits64, bits64);
+ if (n > 1) {
+ bits[1].Merge().Store(nbits64 + Bits64::kLanes, bits64 + Bits64::kLanes);
+ }
+ if (n > 2) {
+ bits[2].Merge().Store(nbits64 + 2 * Bits64::kLanes,
+ bits64 + 2 * Bits64::kLanes);
+ }
+ if (n > 3) {
+ bits[3].Merge().Store(nbits64 + 3 * Bits64::kLanes,
+ bits64 + 3 * Bits64::kLanes);
+ }
+ output.WriteMultiple(nbits64, bits64, Bits64::kLanes * n);
+}
+
+namespace detail {
+template <typename T>
+struct IntegerTypes;
+
+template <>
+struct IntegerTypes<SIMDVec16> {
+ using signed_ = int16_t;
+ using unsigned_ = uint16_t;
+};
+
+template <>
+struct IntegerTypes<SIMDVec32> {
+ using signed_ = int32_t;
+ using unsigned_ = uint32_t;
+};
+
+template <typename T>
+struct SIMDType;
+
+template <>
+struct SIMDType<int16_t> {
+ using type = SIMDVec16;
+};
+
+template <>
+struct SIMDType<int32_t> {
+ using type = SIMDVec32;
+};
+
+} // namespace detail
+
+template <typename T>
+using signed_t = typename detail::IntegerTypes<T>::signed_;
+
+template <typename T>
+using unsigned_t = typename detail::IntegerTypes<T>::unsigned_;
+
+template <typename T>
+using simd_t = typename detail::SIMDType<T>::type;
+
+// This function will process exactly one vector worth of pixels.
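+// Note (added): it applies the same gradient/select predictor as the scalar
+// fallback in ChannelRowProcessor::ProcessChunk, zig-zag maps each residual
+// to an unsigned value, and returns the length of the leading run of zero
+// residuals (used by the caller for RLE run detection).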
+
+template <typename T>
+size_t PredictPixels(const signed_t<T>* pixels, const signed_t<T>* pixels_left,
+ const signed_t<T>* pixels_top,
+ const signed_t<T>* pixels_topleft,
+ unsigned_t<T>* residuals) {
+ T px = T::Load((unsigned_t<T>*)pixels);
+ T left = T::Load((unsigned_t<T>*)pixels_left);
+ T top = T::Load((unsigned_t<T>*)pixels_top);
+ T topleft = T::Load((unsigned_t<T>*)pixels_topleft);
+ T ac = left.Sub(topleft);
+ T ab = left.Sub(top);
+ T bc = top.Sub(topleft);
+ T grad = ac.Add(top);
+ T d = ab.Xor(bc);
+ T zero = T::Val(0);
+ T clamp = zero.Gt(d).IfThenElse(top, left);
+ T s = ac.Xor(bc);
+ T pred = zero.Gt(s).IfThenElse(grad, clamp);
+ T res = px.Sub(pred);
+ T res_times_2 = res.Add(res);
+ res = zero.Gt(res).IfThenElse(T::Val(-1).Sub(res_times_2), res_times_2);
+ res.Store(residuals);
+ return res.Eq(T::Val(0)).CountPrefix();
+}
+
+#endif
+
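+// Note (added): hybrid uint encoding with configuration (0, 0, 0): value 0
+// maps to token 0; otherwise the token is 1 + FloorLog2(value) and the bits
+// of the value below its leading one are emitted as nbits extra bits.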
+void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits,
+ uint32_t* bits) {
+ uint32_t n = FloorLog2(value);
+ *token = value ? n + 1 : 0;
+ *nbits = value ? n : 0;
+ *bits = value ? value - (1 << n) : 0;
+}
+
+#ifdef FJXL_AVX512
+constexpr static size_t kLogChunkSize = 5;
+#elif defined(FJXL_AVX2) || defined(FJXL_NEON)
+// Even though NEON only has 128-bit vectors, it is still significantly
+// (~1.3x) faster to process two vectors at a time.
+constexpr static size_t kLogChunkSize = 4;
+#else
+constexpr static size_t kLogChunkSize = 3;
+#endif
+
+constexpr static size_t kChunkSize = 1 << kLogChunkSize;
+
+template <typename Residual>
+void GenericEncodeChunk(const Residual* residuals, size_t n, size_t skip,
+ const PrefixCode& code, BitWriter& output) {
+ for (size_t ix = skip; ix < n; ix++) {
+ unsigned token, nbits, bits;
+ EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
+ output.Write(code.raw_nbits[token] + nbits,
+ code.raw_bits[token] | bits << code.raw_nbits[token]);
+ }
+}
+
+struct UpTo8Bits {
+ size_t bitdepth;
+ explicit UpTo8Bits(size_t bitdepth) : bitdepth(bitdepth) {
+ assert(bitdepth <= 8);
+ }
+ // Here we can fit up to 9 extra bits + 7 Huffman bits in a u16; for all other
+ // symbols, we could actually go up to 8 Huffman bits, as we have at most 8
+ // extra bits; however, the SIMD bit merging logic for AVX2 assumes that no
+ // Huffman length is 8 or more, so we cap lengths at 7 anyway. The last symbol
+ // is used for LZ77 lengths and has no limitation other than allowing 32
+ // symbols in total to be represented.
+ static constexpr uint8_t kMinRawLength[12] = {};
+ static constexpr uint8_t kMaxRawLength[12] = {
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 10,
+ };
+ static size_t MaxEncodedBitsPerSample() { return 16; }
+ static constexpr size_t kInputBytes = 1;
+ using pixel_t = int16_t;
+ using upixel_t = uint16_t;
+
+ static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
+ size_t n, uint8_t* nbits_simd,
+ uint8_t* bits_simd) {
+ assert(n <= 16);
+ memcpy(nbits_simd, nbits, 16);
+ memcpy(bits_simd, bits, 16);
+ }
+
+ static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip,
+ const PrefixCode& code, BitWriter& output) {
+#ifdef FJXL_GENERIC_SIMD
+ Bits32 bits32[kChunkSize / SIMDVec16::kLanes];
+ alignas(64) uint16_t bits[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits[SIMDVec16::kLanes];
+ alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t token[SIMDVec16::kLanes];
+ for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
+ TokenizeSIMD(residuals + i, token, nbits, bits);
+ HuffmanSIMDUpTo13(token, code, nbits_huff, bits_huff);
+ StoreSIMDUpTo8(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
+ std::max(skip, i) - i, bits32 + i / SIMDVec16::kLanes);
+ }
+ StoreToWriter<kChunkSize / SIMDVec16::kLanes>(bits32, output);
+ return;
+#endif
+ GenericEncodeChunk(residuals, n, skip, code, output);
+ }
+
+ size_t NumSymbols(bool doing_ycocg) const {
+ // Values gain 1 bit for YCoCg and 1 bit for prediction.
+ // The maximum symbol is 1 + the effective bit depth of the residuals.
+ if (doing_ycocg) {
+ return bitdepth + 3;
+ } else {
+ return bitdepth + 2;
+ }
+ }
+};
+constexpr uint8_t UpTo8Bits::kMinRawLength[];
+constexpr uint8_t UpTo8Bits::kMaxRawLength[];
+
+struct From9To13Bits {
+ size_t bitdepth;
+ explicit From9To13Bits(size_t bitdepth) : bitdepth(bitdepth) {
+ assert(bitdepth <= 13 && bitdepth >= 9);
+ }
+ // The last symbol is used for LZ77 lengths and has no limitation other than
+ // allowing 32 symbols in total to be represented.
+ // We cannot fit all the bits in a u16, so we do not even try; instead we
+ // allow up to 8 bits per raw symbol.
+ // There are at most 16 raw symbols, so Huffman coding can be SIMDfied without
+ // any special tricks.
+ static constexpr uint8_t kMinRawLength[17] = {};
+ static constexpr uint8_t kMaxRawLength[17] = {
+ 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10,
+ };
+ static size_t MaxEncodedBitsPerSample() { return 21; }
+ static constexpr size_t kInputBytes = 2;
+ using pixel_t = int16_t;
+ using upixel_t = uint16_t;
+
+ static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
+ size_t n, uint8_t* nbits_simd,
+ uint8_t* bits_simd) {
+ assert(n <= 16);
+ memcpy(nbits_simd, nbits, 16);
+ memcpy(bits_simd, bits, 16);
+ }
+
+ static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip,
+ const PrefixCode& code, BitWriter& output) {
+#ifdef FJXL_GENERIC_SIMD
+ Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
+ alignas(64) uint16_t bits[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits[SIMDVec16::kLanes];
+ alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t token[SIMDVec16::kLanes];
+ for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
+ TokenizeSIMD(residuals + i, token, nbits, bits);
+ HuffmanSIMDUpTo13(token, code, nbits_huff, bits_huff);
+ StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
+ std::max(skip, i) - i,
+ bits32 + 2 * i / SIMDVec16::kLanes);
+ }
+ StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
+ return;
+#endif
+ GenericEncodeChunk(residuals, n, skip, code, output);
+ }
+
+ size_t NumSymbols(bool doing_ycocg) const {
+ // Values gain 1 bit for YCoCg and 1 bit for prediction.
+ // The maximum symbol is 1 + the effective bit depth of the residuals.
+ if (doing_ycocg) {
+ return bitdepth + 3;
+ } else {
+ return bitdepth + 2;
+ }
+ }
+};
+constexpr uint8_t From9To13Bits::kMinRawLength[];
+constexpr uint8_t From9To13Bits::kMaxRawLength[];
+
+void CheckHuffmanBitsSIMD(int bits1, int nbits1, int bits2, int nbits2) {
+ assert(nbits1 == 8);
+ assert(nbits2 == 8);
+ assert(bits2 == (bits1 | 128));
+}
+
+struct Exactly14Bits {
+ explicit Exactly14Bits(size_t bitdepth) { assert(bitdepth == 14); }
+ // Force LZ77 symbols to have at least 8 bits, and raw symbols 15 and 16 to
+ // have exactly 8, and no other symbol to have 8 or more. This ensures that
+ // the representation for 15 and 16 is identical up to one bit.
+ static constexpr uint8_t kMinRawLength[18] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 7,
+ };
+ static constexpr uint8_t kMaxRawLength[18] = {
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 10,
+ };
+ static constexpr size_t bitdepth = 14;
+ static size_t MaxEncodedBitsPerSample() { return 22; }
+ static constexpr size_t kInputBytes = 2;
+ using pixel_t = int16_t;
+ using upixel_t = uint16_t;
+
+ static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
+ size_t n, uint8_t* nbits_simd,
+ uint8_t* bits_simd) {
+ assert(n == 17);
+ CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
+ memcpy(nbits_simd, nbits, 16);
+ memcpy(bits_simd, bits, 16);
+ }
+
+ static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip,
+ const PrefixCode& code, BitWriter& output) {
+#ifdef FJXL_GENERIC_SIMD
+ Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
+ alignas(64) uint16_t bits[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits[SIMDVec16::kLanes];
+ alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t token[SIMDVec16::kLanes];
+ for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
+ TokenizeSIMD(residuals + i, token, nbits, bits);
+ HuffmanSIMD14(token, code, nbits_huff, bits_huff);
+ StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
+ std::max(skip, i) - i,
+ bits32 + 2 * i / SIMDVec16::kLanes);
+ }
+ StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
+ return;
+#endif
+ GenericEncodeChunk(residuals, n, skip, code, output);
+ }
+
+ size_t NumSymbols(bool) const { return 17; }
+};
+constexpr uint8_t Exactly14Bits::kMinRawLength[];
+constexpr uint8_t Exactly14Bits::kMaxRawLength[];
+
+struct MoreThan14Bits {
+ size_t bitdepth;
+ explicit MoreThan14Bits(size_t bitdepth) : bitdepth(bitdepth) {
+ assert(bitdepth > 14);
+ assert(bitdepth <= 16);
+ }
+ // Force LZ77 symbols to have at least 8 bits, and raw symbols 13 to 18 to
+ // have exactly 8, and no other symbol to have 8 or more. This ensures that
+ // the representation for (13, 14), (15, 16), (17, 18) is identical up to one
+ // bit.
+ static constexpr uint8_t kMinRawLength[20] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 7,
+ };
+ static constexpr uint8_t kMaxRawLength[20] = {
+ 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10,
+ };
+ static size_t MaxEncodedBitsPerSample() { return 24; }
+ static constexpr size_t kInputBytes = 2;
+ using pixel_t = int32_t;
+ using upixel_t = uint32_t;
+
+ static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits,
+ size_t n, uint8_t* nbits_simd,
+ uint8_t* bits_simd) {
+ assert(n == 19);
+ CheckHuffmanBitsSIMD(bits[13], nbits[13], bits[14], nbits[14]);
+ CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]);
+ CheckHuffmanBitsSIMD(bits[17], nbits[17], bits[18], nbits[18]);
+ for (size_t i = 0; i < 14; i++) {
+ nbits_simd[i] = nbits[i];
+ bits_simd[i] = bits[i];
+ }
+ nbits_simd[14] = nbits[15];
+ bits_simd[14] = bits[15];
+ nbits_simd[15] = nbits[17];
+ bits_simd[15] = bits[17];
+ }
+
+ static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip,
+ const PrefixCode& code, BitWriter& output) {
+#ifdef FJXL_GENERIC_SIMD
+ Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes];
+ alignas(64) uint32_t bits[SIMDVec16::kLanes];
+ alignas(64) uint32_t nbits[SIMDVec16::kLanes];
+ alignas(64) uint16_t bits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes];
+ alignas(64) uint16_t token[SIMDVec16::kLanes];
+ for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) {
+ TokenizeSIMD(residuals + i, token, nbits, bits);
+ HuffmanSIMDAbove14(token, code, nbits_huff, bits_huff);
+ StoreSIMDAbove14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i,
+ std::max(skip, i) - i,
+ bits32 + 2 * i / SIMDVec16::kLanes);
+ }
+ StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output);
+ return;
+#endif
+ GenericEncodeChunk(residuals, n, skip, code, output);
+ }
+ size_t NumSymbols(bool) const { return 19; }
+};
+constexpr uint8_t MoreThan14Bits::kMinRawLength[];
+constexpr uint8_t MoreThan14Bits::kMaxRawLength[];
+
+void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height,
+ const PrefixCode code[4], BitWriter* output) {
+ output->Allocate(100000 + (is_single_group ? width * height * 16 : 0));
+ // No patches, splines, or noise.
+ output->Write(1, 1); // default DC dequantization factors (?)
+ output->Write(1, 1); // use global tree / histograms
+ output->Write(1, 0); // no lz77 for the tree
+
+ output->Write(1, 1); // simple code for the tree's context map
+ output->Write(2, 0); // all contexts clustered together
+ output->Write(1, 1); // use prefix code for tree
+ output->Write(4, 0); // 000 hybrid uint
+ output->Write(6, 0b100011); // Alphabet size is 4 (var16)
+ output->Write(2, 1); // simple prefix code
+ output->Write(2, 3); // with 4 symbols
+ output->Write(2, 0);
+ output->Write(2, 1);
+ output->Write(2, 2);
+ output->Write(2, 3);
+ output->Write(1, 0); // First tree encoding option
+ // Huffman table + extra bits for the tree.
+ uint8_t symbol_bits[6] = {0b00, 0b10, 0b001, 0b101, 0b0011, 0b0111};
+ uint8_t symbol_nbits[6] = {2, 2, 3, 3, 4, 4};
+ // Write a tree with a leaf per channel, and gradient predictor for every
+ // leaf.
+ for (auto v : {1, 2, 1, 4, 1, 0, 0, 5, 0, 0, 0, 0, 5,
+ 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0}) {
+ output->Write(symbol_nbits[v], symbol_bits[v]);
+ }
+
+ output->Write(1, 1); // Enable lz77 for the main bitstream
+ output->Write(2, 0b00); // lz77 offset 224
+ static_assert(kLZ77Offset == 224, "");
+ output->Write(4, 0b1010); // lz77 min length 7
+ // 400 hybrid uint config for lz77
+ output->Write(4, 4);
+ output->Write(3, 0);
+ output->Write(3, 0);
+
+ output->Write(1, 1); // simple code for the context map
+ output->Write(2, 3); // 3 bits per entry
+ output->Write(3, 4); // channel 3
+ output->Write(3, 3); // channel 2
+ output->Write(3, 2); // channel 1
+ output->Write(3, 1); // channel 0
+ output->Write(3, 0); // distance histogram first
+
+ output->Write(1, 1); // use prefix codes
+ output->Write(4, 0); // 000 hybrid uint config for distances (only need 0)
+ for (size_t i = 0; i < 4; i++) {
+ output->Write(4, 0); // 000 hybrid uint config for symbols (only <= 10)
+ }
+
+ // Distance alphabet size:
+ output->Write(5, 0b00001); // 2: just need 1 for RLE (i.e. distance 1)
+ // Symbol + LZ77 alphabet size:
+ for (size_t i = 0; i < 4; i++) {
+ output->Write(1, 1); // > 1
+ output->Write(4, 8); // <= 512
+ output->Write(8, 256); // == 512
+ }
+
+ // Distance histogram:
+ output->Write(2, 1); // simple prefix code
+ output->Write(2, 0); // with one symbol
+ output->Write(1, 1); // 1
+
+ // Symbol + lz77 histogram:
+ for (size_t i = 0; i < 4; i++) {
+ code[i].WriteTo(output);
+ }
+
+ // Group header for global modular image.
+ output->Write(1, 1); // Global tree
+ output->Write(1, 1); // All default wp
+}
+
+void PrepareDCGlobal(bool is_single_group, size_t width, size_t height,
+ size_t nb_chans, const PrefixCode code[4],
+ BitWriter* output) {
+ PrepareDCGlobalCommon(is_single_group, width, height, code, output);
+ if (nb_chans > 2) {
+ output->Write(2, 0b01); // 1 transform
+ output->Write(2, 0b00); // RCT
+ output->Write(5, 0b00000); // Starting from ch 0
+ output->Write(2, 0b00); // YCoCg
+ } else {
+ output->Write(2, 0b00); // no transforms
+ }
+ if (!is_single_group) {
+ output->ZeroPadToByte();
+ }
+}
+
+template <typename BitDepth>
+struct ChunkEncoder {
+ FJXL_INLINE static void EncodeRle(size_t count, const PrefixCode& code,
+ BitWriter& output) {
+ if (count == 0) return;
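+ // Note (added): a nonzero count is always greater than kLZ77MinLength (see
+ // the run invariant in ChannelRowProcessor), so the subtraction below
+ // cannot underflow.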
+ count -= kLZ77MinLength + 1;
+ if (count < kLZ77CacheSize) {
+ output.Write(code.lz77_cache_nbits[count], code.lz77_cache_bits[count]);
+ } else {
+ unsigned token, nbits, bits;
+ EncodeHybridUintLZ77(count, &token, &nbits, &bits);
+ uint64_t wbits = bits;
+ wbits = (wbits << code.lz77_nbits[token]) | code.lz77_bits[token];
+ wbits = (wbits << code.raw_nbits[0]) | code.raw_bits[0];
+ output.Write(code.lz77_nbits[token] + nbits + code.raw_nbits[0], wbits);
+ }
+ }
+
+ FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
+ size_t skip, size_t n) {
+ EncodeRle(run, *code, *output);
+ BitDepth::EncodeChunk(residuals, n, skip, *code, *output);
+ }
+
+ inline void Finalize(size_t run) { EncodeRle(run, *code, *output); }
+
+ const PrefixCode* code;
+ BitWriter* output;
+};
+
+template <typename BitDepth>
+struct ChunkSampleCollector {
+ FJXL_INLINE void Rle(size_t count, uint64_t* lz77_counts) {
+ if (count == 0) return;
+ raw_counts[0] += 1;
+ count -= kLZ77MinLength + 1;
+ unsigned token, nbits, bits;
+ EncodeHybridUintLZ77(count, &token, &nbits, &bits);
+ lz77_counts[token]++;
+ }
+
+ FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals,
+ size_t skip, size_t n) {
+ // Run is broken. Count the run and the individual residuals.
+ Rle(run, lz77_counts);
+ for (size_t ix = skip; ix < n; ix++) {
+ unsigned token, nbits, bits;
+ EncodeHybridUint000(residuals[ix], &token, &nbits, &bits);
+ raw_counts[token]++;
+ }
+ }
+
+ // Don't count the final run, since we don't know how long it really is.
+ void Finalize(size_t run) {}
+
+ uint64_t* raw_counts;
+ uint64_t* lz77_counts;
+};
+
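+// Note (added): branchless zig-zag mapping: non-negative values map to
+// 2 * value, negative values map to -2 * value - 1.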
+constexpr uint32_t PackSigned(int32_t value) {
+ return (static_cast<uint32_t>(value) << 1) ^
+ ((static_cast<uint32_t>(~value) >> 31) - 1);
+}
+
+template <typename T, typename BitDepth>
+struct ChannelRowProcessor {
+ using upixel_t = typename BitDepth::upixel_t;
+ using pixel_t = typename BitDepth::pixel_t;
+ T* t;
+ void ProcessChunk(const pixel_t* row, const pixel_t* row_left,
+ const pixel_t* row_top, const pixel_t* row_topleft,
+ size_t n) {
+ alignas(64) upixel_t residuals[kChunkSize] = {};
+ size_t prefix_size = 0;
+ size_t required_prefix_size = 0;
+#ifdef FJXL_GENERIC_SIMD
+ constexpr size_t kNum =
+ sizeof(pixel_t) == 2 ? SIMDVec16::kLanes : SIMDVec32::kLanes;
+ for (size_t ix = 0; ix < kChunkSize; ix += kNum) {
+ size_t c =
+ PredictPixels<simd_t<pixel_t>>(row + ix, row_left + ix, row_top + ix,
+ row_topleft + ix, residuals + ix);
+ prefix_size =
+ prefix_size == required_prefix_size ? prefix_size + c : prefix_size;
+ required_prefix_size += kNum;
+ }
+#else
+ for (size_t ix = 0; ix < kChunkSize; ix++) {
+ pixel_t px = row[ix];
+ pixel_t left = row_left[ix];
+ pixel_t top = row_top[ix];
+ pixel_t topleft = row_topleft[ix];
+ pixel_t ac = left - topleft;
+ pixel_t ab = left - top;
+ pixel_t bc = top - topleft;
+ pixel_t grad = static_cast<pixel_t>(static_cast<upixel_t>(ac) +
+ static_cast<upixel_t>(top));
+ pixel_t d = ab ^ bc;
+ pixel_t clamp = d < 0 ? top : left;
+ pixel_t s = ac ^ bc;
+ pixel_t pred = s < 0 ? grad : clamp;
+ residuals[ix] = PackSigned(px - pred);
+ prefix_size = prefix_size == required_prefix_size
+ ? prefix_size + (residuals[ix] == 0)
+ : prefix_size;
+ required_prefix_size += 1;
+ }
+#endif
+ prefix_size = std::min(n, prefix_size);
+ if (prefix_size == n && (run > 0 || prefix_size > kLZ77MinLength)) {
+ // Run continues; nothing to emit yet.
+ run += prefix_size;
+ } else if (prefix_size + run > kLZ77MinLength) {
+ // Run is broken. Encode the run and encode the individual vector.
+ t->Chunk(run + prefix_size, residuals, prefix_size, n);
+ run = 0;
+ } else {
+ // There was no run to begin with.
+ t->Chunk(0, residuals, 0, n);
+ }
+ }
+
+ void ProcessRow(const pixel_t* row, const pixel_t* row_left,
+ const pixel_t* row_top, const pixel_t* row_topleft,
+ size_t xs) {
+ for (size_t x = 0; x < xs; x += kChunkSize) {
+ ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x,
+ std::min(kChunkSize, xs - x));
+ }
+ }
+
+ void Finalize() { t->Finalize(run); }
+ // Invariant: run == 0 or run > kLZ77MinLength.
+ size_t run = 0;
+};
+
+uint16_t LoadLE16(const unsigned char* ptr) {
+ return uint16_t{ptr[0]} | (uint16_t{ptr[1]} << 8);
+}
+
+uint16_t SwapEndian(uint16_t in) { return (in >> 8) | (in << 8); }
+
+#ifdef FJXL_GENERIC_SIMD
+void StorePixels(SIMDVec16 p, int16_t* dest) { p.Store((uint16_t*)dest); }
+
+void StorePixels(SIMDVec16 p, int32_t* dest) {
+ VecPair<SIMDVec32> p_up = p.Upcast();
+ p_up.low.Store((uint32_t*)dest);
+ p_up.hi.Store((uint32_t*)dest + SIMDVec32::kLanes);
+}
+#endif
+
+template <typename pixel_t>
+void FillRowG8(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadG8(rgba + x);
+ StorePixels(rgb[0], luma + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ luma[x] = rgba[x];
+ }
+}
+
+template <bool big_endian, typename pixel_t>
+void FillRowG16(const unsigned char* rgba, size_t oxs, pixel_t* luma) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadG16(rgba + 2 * x);
+ if (big_endian) {
+ rgb[0].SwapEndian();
+ }
+ StorePixels(rgb[0], luma + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t val = LoadLE16(rgba + 2 * x);
+ if (big_endian) {
+ val = SwapEndian(val);
+ }
+ luma[x] = val;
+ }
+}
+
+template <typename pixel_t>
+void FillRowGA8(const unsigned char* rgba, size_t oxs, pixel_t* luma,
+ pixel_t* alpha) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadGA8(rgba + 2 * x);
+ StorePixels(rgb[0], luma + x);
+ StorePixels(rgb[1], alpha + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ luma[x] = rgba[2 * x];
+ alpha[x] = rgba[2 * x + 1];
+ }
+}
+
+template <bool big_endian, typename pixel_t>
+void FillRowGA16(const unsigned char* rgba, size_t oxs, pixel_t* luma,
+ pixel_t* alpha) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadGA16(rgba + 4 * x);
+ if (big_endian) {
+ rgb[0].SwapEndian();
+ rgb[1].SwapEndian();
+ }
+ StorePixels(rgb[0], luma + x);
+ StorePixels(rgb[1], alpha + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t l = LoadLE16(rgba + 4 * x);
+ uint16_t a = LoadLE16(rgba + 4 * x + 2);
+ if (big_endian) {
+ l = SwapEndian(l);
+ a = SwapEndian(a);
+ }
+ luma[x] = l;
+ alpha[x] = a;
+ }
+}
+
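+// Note (added): forward lifting-based YCoCg transform: Co = R - B,
+// Cg = G - (B + (Co >> 1)), Y = (B + (Co >> 1)) + (Cg >> 1); it is exactly
+// invertible in integer arithmetic.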
+template <typename pixel_t>
+void StoreYCoCg(pixel_t r, pixel_t g, pixel_t b, pixel_t* y, pixel_t* co,
+ pixel_t* cg) {
+ *co = r - b;
+ pixel_t tmp = b + (*co >> 1);
+ *cg = g - tmp;
+ *y = tmp + (*cg >> 1);
+}
+
+#ifdef FJXL_GENERIC_SIMD
+void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int16_t* y, int16_t* co,
+ int16_t* cg) {
+ SIMDVec16 co_v = r.Sub(b);
+ SIMDVec16 tmp = b.Add(co_v.SignedShiftRight<1>());
+ SIMDVec16 cg_v = g.Sub(tmp);
+ SIMDVec16 y_v = tmp.Add(cg_v.SignedShiftRight<1>());
+ y_v.Store((uint16_t*)y);
+ co_v.Store((uint16_t*)co);
+ cg_v.Store((uint16_t*)cg);
+}
+
+void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int32_t* y, int32_t* co,
+ int32_t* cg) {
+ VecPair<SIMDVec32> r_up = r.Upcast();
+ VecPair<SIMDVec32> g_up = g.Upcast();
+ VecPair<SIMDVec32> b_up = b.Upcast();
+ SIMDVec32 co_lo_v = r_up.low.Sub(b_up.low);
+ SIMDVec32 tmp_lo = b_up.low.Add(co_lo_v.SignedShiftRight<1>());
+ SIMDVec32 cg_lo_v = g_up.low.Sub(tmp_lo);
+ SIMDVec32 y_lo_v = tmp_lo.Add(cg_lo_v.SignedShiftRight<1>());
+ SIMDVec32 co_hi_v = r_up.hi.Sub(b_up.hi);
+ SIMDVec32 tmp_hi = b_up.hi.Add(co_hi_v.SignedShiftRight<1>());
+ SIMDVec32 cg_hi_v = g_up.hi.Sub(tmp_hi);
+ SIMDVec32 y_hi_v = tmp_hi.Add(cg_hi_v.SignedShiftRight<1>());
+ y_lo_v.Store((uint32_t*)y);
+ co_lo_v.Store((uint32_t*)co);
+ cg_lo_v.Store((uint32_t*)cg);
+ y_hi_v.Store((uint32_t*)y + SIMDVec32::kLanes);
+ co_hi_v.Store((uint32_t*)co + SIMDVec32::kLanes);
+ cg_hi_v.Store((uint32_t*)cg + SIMDVec32::kLanes);
+}
+#endif
+
+template <typename pixel_t>
+void FillRowRGB8(const unsigned char* rgba, size_t oxs, pixel_t* y, pixel_t* co,
+ pixel_t* cg) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadRGB8(rgba + 3 * x);
+ StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t r = rgba[3 * x];
+ uint16_t g = rgba[3 * x + 1];
+ uint16_t b = rgba[3 * x + 2];
+ StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
+ }
+}
+
+template <bool big_endian, typename pixel_t>
+void FillRowRGB16(const unsigned char* rgba, size_t oxs, pixel_t* y,
+ pixel_t* co, pixel_t* cg) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadRGB16(rgba + 6 * x);
+ if (big_endian) {
+ rgb[0].SwapEndian();
+ rgb[1].SwapEndian();
+ rgb[2].SwapEndian();
+ }
+ StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t r = LoadLE16(rgba + 6 * x);
+ uint16_t g = LoadLE16(rgba + 6 * x + 2);
+ uint16_t b = LoadLE16(rgba + 6 * x + 4);
+ if (big_endian) {
+ r = SwapEndian(r);
+ g = SwapEndian(g);
+ b = SwapEndian(b);
+ }
+ StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
+ }
+}
+
+template <typename pixel_t>
+void FillRowRGBA8(const unsigned char* rgba, size_t oxs, pixel_t* y,
+ pixel_t* co, pixel_t* cg, pixel_t* alpha) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadRGBA8(rgba + 4 * x);
+ StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
+ StorePixels(rgb[3], alpha + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t r = rgba[4 * x];
+ uint16_t g = rgba[4 * x + 1];
+ uint16_t b = rgba[4 * x + 2];
+ uint16_t a = rgba[4 * x + 3];
+ StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
+ alpha[x] = a;
+ }
+}
+
+template <bool big_endian, typename pixel_t>
+void FillRowRGBA16(const unsigned char* rgba, size_t oxs, pixel_t* y,
+ pixel_t* co, pixel_t* cg, pixel_t* alpha) {
+ size_t x = 0;
+#ifdef FJXL_GENERIC_SIMD
+ for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) {
+ auto rgb = SIMDVec16::LoadRGBA16(rgba + 8 * x);
+ if (big_endian) {
+ rgb[0].SwapEndian();
+ rgb[1].SwapEndian();
+ rgb[2].SwapEndian();
+ rgb[3].SwapEndian();
+ }
+ StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x);
+ StorePixels(rgb[3], alpha + x);
+ }
+#endif
+ for (; x < oxs; x++) {
+ uint16_t r = LoadLE16(rgba + 8 * x);
+ uint16_t g = LoadLE16(rgba + 8 * x + 2);
+ uint16_t b = LoadLE16(rgba + 8 * x + 4);
+ uint16_t a = LoadLE16(rgba + 8 * x + 6);
+ if (big_endian) {
+ r = SwapEndian(r);
+ g = SwapEndian(g);
+ b = SwapEndian(b);
+ a = SwapEndian(a);
+ }
+ StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x);
+ alpha[x] = a;
+ }
+}
+
+template <typename Processor, typename BitDepth>
+void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0,
+ size_t xs, size_t yskip, size_t ys, size_t row_stride,
+ BitDepth bitdepth, size_t nb_chans, bool big_endian,
+ Processor* processors) {
+ constexpr size_t kPadding = 32;
+
+ using pixel_t = typename BitDepth::pixel_t;
+
+ constexpr size_t kAlign = 64;
+ constexpr size_t kAlignPixels = kAlign / sizeof(pixel_t);
+
+ auto align = [=](pixel_t* ptr) {
+ size_t offset = reinterpret_cast<uintptr_t>(ptr) % kAlign;
+ if (offset) {
+ ptr += offset / sizeof(pixel_t);
+ }
+ return ptr;
+ };
+
+ constexpr size_t kNumPx =
+ (256 + kPadding * 2 + kAlignPixels + kAlignPixels - 1) / kAlignPixels *
+ kAlignPixels;
+
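+ // Note (added): two rows of padded storage per channel: row y is written to
+ // group_data[c][y & 1] and the previous row is read back from
+ // group_data[c][(y - 1) & 1].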
+ std::vector<std::array<std::array<pixel_t, kNumPx>, 2>> group_data(nb_chans);
+
+ for (size_t y = 0; y < ys; y++) {
+ const auto rgba_row =
+ rgba + row_stride * (y0 + y) + x0 * nb_chans * BitDepth::kInputBytes;
+ pixel_t* crow[4] = {};
+ pixel_t* prow[4] = {};
+ for (size_t i = 0; i < nb_chans; i++) {
+ crow[i] = align(&group_data[i][y & 1][kPadding]);
+ prow[i] = align(&group_data[i][(y - 1) & 1][kPadding]);
+ }
+
+ // Pre-fill rows with the converted pixels (YCoCg for 3 or 4 channels,
+ // plain luma/alpha otherwise).
+ if (nb_chans == 1) {
+ if (BitDepth::kInputBytes == 1) {
+ FillRowG8(rgba_row, xs, crow[0]);
+ } else if (big_endian) {
+ FillRowG16</*big_endian=*/true>(rgba_row, xs, crow[0]);
+ } else {
+ FillRowG16</*big_endian=*/false>(rgba_row, xs, crow[0]);
+ }
+ } else if (nb_chans == 2) {
+ if (BitDepth::kInputBytes == 1) {
+ FillRowGA8(rgba_row, xs, crow[0], crow[1]);
+ } else if (big_endian) {
+ FillRowGA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1]);
+ } else {
+ FillRowGA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1]);
+ }
+ } else if (nb_chans == 3) {
+ if (BitDepth::kInputBytes == 1) {
+ FillRowRGB8(rgba_row, xs, crow[0], crow[1], crow[2]);
+ } else if (big_endian) {
+ FillRowRGB16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
+ crow[2]);
+ } else {
+ FillRowRGB16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
+ crow[2]);
+ }
+ } else {
+ if (BitDepth::kInputBytes == 1) {
+ FillRowRGBA8(rgba_row, xs, crow[0], crow[1], crow[2], crow[3]);
+ } else if (big_endian) {
+ FillRowRGBA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1],
+ crow[2], crow[3]);
+ } else {
+ FillRowRGBA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1],
+ crow[2], crow[3]);
+ }
+ }
+ // Deal with x == 0.
+ for (size_t c = 0; c < nb_chans; c++) {
+ *(crow[c] - 1) = y > 0 ? *(prow[c]) : 0;
+ // Fix topleft.
+ *(prow[c] - 1) = y > 0 ? *(prow[c]) : 0;
+ }
+ if (y < yskip) continue;
+ for (size_t c = 0; c < nb_chans; c++) {
+ // Get pointers to px/left/top/topleft data to speed up the loop.
+ const pixel_t* row = crow[c];
+ const pixel_t* row_left = crow[c] - 1;
+ const pixel_t* row_top = y == 0 ? row_left : prow[c];
+ const pixel_t* row_topleft = y == 0 ? row_left : prow[c] - 1;
+
+ processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs);
+ }
+ }
+ for (size_t c = 0; c < nb_chans; c++) {
+ processors[c].Finalize();
+ }
+}
+
+template <typename BitDepth>
+void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
+ size_t ys, size_t row_stride, bool is_single_group,
+ BitDepth bitdepth, size_t nb_chans, bool big_endian,
+ const PrefixCode code[4],
+ std::array<BitWriter, 4>& output) {
+ for (size_t i = 0; i < nb_chans; i++) {
+ if (is_single_group && i == 0) continue;
+ output[i].Allocate(xs * ys * bitdepth.MaxEncodedBitsPerSample() + 4);
+ }
+ if (!is_single_group) {
+ // Group header for modular image.
+ // When the image is single-group, the global modular image is the one
+ // that contains the pixel data, and there is no group header.
+ output[0].Write(1, 1); // Global tree
+ output[0].Write(1, 1); // All default wp
+ output[0].Write(2, 0b00); // 0 transforms
+ }
+
+ ChunkEncoder<BitDepth> encoders[4];
+ ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth> row_encoders[4];
+ for (size_t c = 0; c < nb_chans; c++) {
+ row_encoders[c].t = &encoders[c];
+ encoders[c].output = &output[c];
+ encoders[c].code = &code[c];
+ }
+ ProcessImageArea<ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth>>(
+ rgba, x0, y0, xs, 0, ys, row_stride, bitdepth, nb_chans, big_endian,
+ row_encoders);
+}
+
+constexpr int kHashExp = 16;
+constexpr uint32_t kHashSize = 1 << kHashExp;
+constexpr uint32_t kHashMultiplier = 2654435761;
+constexpr int kMaxColors = 512;
+
+// The hash can be any function that returns a value in 0 .. kHashSize-1,
+// as long as it maps 0 to 0.
+inline uint32_t pixel_hash(uint32_t p) {
+ return (p * kHashMultiplier) >> (32 - kHashExp);
+}
+
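+// Note (added): converts each pixel to its palette index: the pixel's
+// nb_chans bytes are packed into a 32-bit key and looked up through the
+// hash-based lookup table.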
+template <size_t nb_chans>
+void FillRowPalette(const unsigned char* inrow, size_t xs,
+ const int16_t* lookup, int16_t* out) {
+ for (size_t x = 0; x < xs; x++) {
+ uint32_t p = 0;
+ memcpy(&p, inrow + x * nb_chans, nb_chans);
+ out[x] = lookup[pixel_hash(p)];
+ }
+}
+
+template <typename Processor>
+void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0,
+ size_t xs, size_t yskip, size_t ys,
+ size_t row_stride, const int16_t* lookup,
+ size_t nb_chans, Processor* processors) {
+ constexpr size_t kPadding = 32;
+
+ std::vector<std::array<int16_t, 256 + kPadding * 2>> group_data(2);
+ Processor& row_encoder = processors[0];
+
+ for (size_t y = 0; y < ys; y++) {
+ // Pre-fill rows with palette converted pixels.
+ const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans;
+ int16_t* outrow = &group_data[y & 1][kPadding];
+ if (nb_chans == 1) {
+ FillRowPalette<1>(inrow, xs, lookup, outrow);
+ } else if (nb_chans == 2) {
+ FillRowPalette<2>(inrow, xs, lookup, outrow);
+ } else if (nb_chans == 3) {
+ FillRowPalette<3>(inrow, xs, lookup, outrow);
+ } else if (nb_chans == 4) {
+ FillRowPalette<4>(inrow, xs, lookup, outrow);
+ }
+ // Deal with x == 0.
+ group_data[y & 1][kPadding - 1] =
+ y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
+ // Fix topleft.
+ group_data[(y - 1) & 1][kPadding - 1] =
+ y > 0 ? group_data[(y - 1) & 1][kPadding] : 0;
+    // Get pointers to px/left/top/topleft data to speed up the loop.
+ const int16_t* row = &group_data[y & 1][kPadding];
+ const int16_t* row_left = &group_data[y & 1][kPadding - 1];
+ const int16_t* row_top =
+ y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding];
+ const int16_t* row_topleft =
+ y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1];
+
+ row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs);
+ }
+ row_encoder.Finalize();
+}
+
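+// Palette counterpart of WriteACSection: a single channel of palette indices
+// is encoded. In the single-group case, code[0] is tuned for the palette data
+// written in the global section, so the pixel indices use code[1].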
+void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0,
+ size_t xs, size_t ys, size_t row_stride,
+ bool is_single_group, const PrefixCode code[4],
+ const int16_t* lookup, size_t nb_chans,
+ BitWriter& output) {
+ if (!is_single_group) {
+ output.Allocate(16 * xs * ys + 4);
+ // Group header for modular image.
+ // When the image is single-group, the global modular image is the one
+ // that contains the pixel data, and there is no group header.
+ output.Write(1, 1); // Global tree
+ output.Write(1, 1); // All default wp
+ output.Write(2, 0b00); // 0 transforms
+ }
+
+ ChunkEncoder<UpTo8Bits> encoder;
+ ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
+
+ row_encoder.t = &encoder;
+ encoder.output = &output;
+ encoder.code = &code[is_single_group ? 1 : 0];
+ ProcessImageAreaPalette<
+ ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits>>(
+ rgba, x0, y0, xs, 0, ys, row_stride, lookup, nb_chans, &row_encoder);
+}
+
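+// Collects raw-symbol and LZ77 histograms from `row_count` rows of one group;
+// the resulting counts are later used to build the prefix codes.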
+template <typename BitDepth>
+void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
+ size_t row_stride, size_t row_count,
+ uint64_t raw_counts[4][kNumRawSymbols],
+ uint64_t lz77_counts[4][kNumLZ77], bool is_single_group,
+ bool palette, BitDepth bitdepth, size_t nb_chans,
+ bool big_endian, const int16_t* lookup) {
+ if (palette) {
+ ChunkSampleCollector<UpTo8Bits> sample_collectors[4];
+ ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>
+ row_sample_collectors[4];
+ for (size_t c = 0; c < nb_chans; c++) {
+ row_sample_collectors[c].t = &sample_collectors[c];
+ sample_collectors[c].raw_counts = raw_counts[is_single_group ? 1 : 0];
+ sample_collectors[c].lz77_counts = lz77_counts[is_single_group ? 1 : 0];
+ }
+ ProcessImageAreaPalette<
+ ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>>(
+ rgba, x0, y0, xs, 1, 1 + row_count, row_stride, lookup, nb_chans,
+ row_sample_collectors);
+ } else {
+ ChunkSampleCollector<BitDepth> sample_collectors[4];
+ ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>
+ row_sample_collectors[4];
+ for (size_t c = 0; c < nb_chans; c++) {
+ row_sample_collectors[c].t = &sample_collectors[c];
+ sample_collectors[c].raw_counts = raw_counts[c];
+ sample_collectors[c].lz77_counts = lz77_counts[c];
+ }
+ ProcessImageArea<
+ ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>>(
+ rgba, x0, y0, xs, 1, 1 + row_count, row_stride, bitdepth, nb_chans,
+ big_endian, row_sample_collectors);
+ }
+}
+
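+// Writes the DC global section for the palette path: signals the palette
+// transform and encodes the palette itself as a small 4 x pcolors modular
+// image, where each channel row uses the previous channel row as its "top"
+// prediction context.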
+void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height,
+ const PrefixCode code[4],
+ const std::vector<uint32_t>& palette,
+ size_t pcolors, BitWriter* output) {
+ PrepareDCGlobalCommon(is_single_group, width, height, code, output);
+ output->Write(2, 0b01); // 1 transform
+ output->Write(2, 0b01); // Palette
+ output->Write(5, 0b00000); // Starting from ch 0
+ output->Write(2, 0b10); // 4-channel palette (RGBA)
+ // pcolors <= kMaxColors + kChunkSize - 1
+ static_assert(kMaxColors + kChunkSize < 1281,
+ "add code to signal larger palette sizes");
+ if (pcolors < 256) {
+ output->Write(2, 0b00);
+ output->Write(8, pcolors);
+ } else {
+ output->Write(2, 0b01);
+ output->Write(10, pcolors - 256);
+ }
+
+ output->Write(2, 0b00); // nb_deltas == 0
+ output->Write(4, 0); // Zero predictor for delta palette
+ // Encode palette
+ ChunkEncoder<UpTo8Bits> encoder;
+ ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
+ row_encoder.t = &encoder;
+ encoder.output = output;
+ encoder.code = &code[0];
+ int16_t p[4][32 + 1024] = {};
+ uint8_t prgba[4];
+ size_t i = 0;
+ size_t have_zero = 0;
+ if (palette[pcolors - 1] == 0) have_zero = 1;
+ for (; i < pcolors; i++) {
+ memcpy(prgba, &palette[i], 4);
+ p[0][16 + i + have_zero] = prgba[0];
+ p[1][16 + i + have_zero] = prgba[1];
+ p[2][16 + i + have_zero] = prgba[2];
+ p[3][16 + i + have_zero] = prgba[3];
+ }
+ p[0][15] = 0;
+ row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors);
+ p[1][15] = p[0][16];
+ p[0][15] = p[0][16];
+ row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors);
+ p[2][15] = p[1][16];
+ p[1][15] = p[1][16];
+ row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors);
+ p[3][15] = p[2][16];
+ p[2][15] = p[2][16];
+ row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors);
+ row_encoder.Finalize();
+
+ if (!is_single_group) {
+ output->ZeroPadToByte();
+ }
+}
+
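+// Main encoding routine for a given bit depth: detects whether a palette can
+// be used, samples per-channel histograms, builds the prefix codes, prepares
+// the DC global section and then encodes all 256x256 groups (possibly in
+// parallel through `runner`).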
+template <typename BitDepth>
+JxlFastLosslessFrameState* LLEnc(const unsigned char* rgba, size_t width,
+ size_t stride, size_t height,
+ BitDepth bitdepth, size_t nb_chans,
+ bool big_endian, int effort,
+ void* runner_opaque,
+ FJxlParallelRunner runner) {
+ assert(width != 0);
+ assert(height != 0);
+ assert(stride >= nb_chans * BitDepth::kInputBytes * width);
+
+  // Count the distinct colors to decide whether a palette can be used.
+ std::vector<uint32_t> palette(kHashSize);
+ palette[0] = 1;
+ std::vector<int16_t> lookup(kHashSize);
+ lookup[0] = 0;
+ int pcolors = 0;
+ bool collided = effort < 2 || bitdepth.bitdepth != 8 ||
+                  nb_chans < 4;  // TODO: also support an RGB palette.
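+  // Each hash bucket stores at most one color; if a second, different color
+  // maps to an occupied bucket we give up on the palette ("collided").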
+ for (size_t y = 0; y < height && !collided; y++) {
+ const unsigned char* r = rgba + stride * y;
+ size_t x = 0;
+ if (nb_chans == 4) {
+ // this is just an unrolling of the next loop
+ for (; x + 7 < width; x += 8) {
+ uint32_t p[8], index[8];
+ memcpy(p, r + x * 4, 32);
+ for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]);
+ for (int i = 0; i < 8; i++) {
+ uint32_t init_entry = index[i] ? 0 : 1;
+ if (init_entry != palette[index[i]] && p[i] != palette[index[i]]) {
+ collided = true;
+ }
+ }
+ for (int i = 0; i < 8; i++) palette[index[i]] = p[i];
+ }
+ for (; x < width; x++) {
+ uint32_t p;
+ memcpy(&p, r + x * 4, 4);
+ uint32_t index = pixel_hash(p);
+ uint32_t init_entry = index ? 0 : 1;
+ if (init_entry != palette[index] && p != palette[index]) {
+ collided = true;
+ }
+ palette[index] = p;
+ }
+ } else {
+ for (; x < width; x++) {
+ uint32_t p = 0;
+ memcpy(&p, r + x * nb_chans, nb_chans);
+ uint32_t index = pixel_hash(p);
+ uint32_t init_entry = index ? 0 : 1;
+ if (init_entry != palette[index] && p != palette[index]) {
+ collided = true;
+ }
+ palette[index] = p;
+ }
+ }
+ }
+
+ int nb_entries = 0;
+ if (!collided) {
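+    // Slot 0 was initialized to the sentinel value 1, which no stored color
+    // can reproduce there (pixel_hash(1) != 0). If it now holds 0, the
+    // all-zero color is present and implicitly takes palette index 0
+    // (lookup[0] is already 0); otherwise reset the slot to 0 so it counts as
+    // empty.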
+ if (palette[0] == 0) pcolors = 1;
+ if (palette[0] == 1) palette[0] = 0;
+ bool have_color = false;
+ uint8_t minG = 255, maxG = 0;
+ for (uint32_t k = 0; k < kHashSize; k++) {
+ if (palette[k] == 0) continue;
+ uint8_t p[4];
+ memcpy(p, &palette[k], 4);
+ // move entries to front so sort has less work
+ palette[nb_entries] = palette[k];
+ if (p[0] != p[1] || p[0] != p[2]) have_color = true;
+ if (p[1] < minG) minG = p[1];
+ if (p[1] > maxG) maxG = p[1];
+ nb_entries++;
+ // don't do palette if too many colors are needed
+ if (nb_entries + pcolors > kMaxColors) {
+ collided = true;
+ break;
+ }
+ }
+ if (!have_color) {
+      // Don't use a palette for a grayscale image whose used gray values
+      // already form a fairly dense range (few holes): it would not help.
+ if (maxG - minG < nb_entries * 1.4f) collided = true;
+ }
+ }
+ if (!collided) {
+ std::sort(
+ palette.begin(), palette.begin() + nb_entries,
+ [](uint32_t ap, uint32_t bp) {
+ if (ap == 0) return false;
+ if (bp == 0) return true;
+ uint8_t a[4], b[4];
+ memcpy(a, &ap, 4);
+ memcpy(b, &bp, 4);
+ float ay, by;
+ ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3];
+ by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3];
+ return ay < by; // sort on alpha*luma
+ });
+ for (int k = 0; k < nb_entries; k++) {
+ if (palette[k] == 0) break;
+ lookup[pixel_hash(palette[k])] = pcolors++;
+ }
+ }
+
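+  // The frame is split into 256x256 encoding groups; the 2048x2048 DC-group
+  // counts are only needed for the section layout below.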
+ size_t num_groups_x = (width + 255) / 256;
+ size_t num_groups_y = (height + 255) / 256;
+ size_t num_dc_groups_x = (width + 2047) / 2048;
+ size_t num_dc_groups_y = (height + 2047) / 2048;
+
+ uint64_t raw_counts[4][kNumRawSymbols] = {};
+ uint64_t lz77_counts[4][kNumLZ77] = {};
+
+ bool onegroup = num_groups_x == 1 && num_groups_y == 1;
+
+ // sample the middle (effort * 2) rows of every group
+ for (size_t g = 0; g < num_groups_y * num_groups_x; g++) {
+ size_t xg = g % num_groups_x;
+ size_t yg = g / num_groups_x;
+ int y_offset = yg * 256;
+ int y_max = std::min<size_t>(height - yg * 256, 256);
+ int y_begin = y_offset + std::max<int>(0, y_max - 2 * effort) / 2;
+ int y_count =
+ std::min<int>(2 * effort * y_max / 256, y_offset + y_max - y_begin - 1);
+ int x_max =
+ std::min<size_t>(width - xg * 256, 256) / kChunkSize * kChunkSize;
+ CollectSamples(rgba, xg * 256, y_begin, x_max, stride, y_count, raw_counts,
+ lz77_counts, onegroup, !collided, bitdepth, nb_chans,
+ big_endian, lookup.data());
+ }
+
+ // TODO(veluca): can probably improve this and make it bitdepth-dependent.
+ uint64_t base_raw_counts[kNumRawSymbols] = {
+ 3843, 852, 1270, 1214, 1014, 727, 481, 300, 159, 51,
+ 5, 1, 1, 1, 1, 1, 1, 1, 1};
+
+ bool doing_ycocg = nb_chans > 2 && collided;
+ for (size_t i = bitdepth.NumSymbols(doing_ycocg); i < kNumRawSymbols; i++) {
+ base_raw_counts[i] = 0;
+ }
+
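+  // Blend the sampled counts with the baseline priors: measured counts are
+  // weighted by 256, so the priors only matter for rarely-seen symbols.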
+ for (size_t c = 0; c < 4; c++) {
+ for (size_t i = 0; i < kNumRawSymbols; i++) {
+ raw_counts[c][i] = (raw_counts[c][i] << 8) + base_raw_counts[i];
+ }
+ }
+
+ if (!collided) {
+ unsigned token, nbits, bits;
+ EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits);
+ // ensure all palette indices can actually be encoded
+ for (size_t i = 0; i < token + 1; i++)
+ raw_counts[0][i] = std::max<uint64_t>(raw_counts[0][i], 1);
+ // these tokens are only used for the palette itself so they can get a bad
+ // code
+ for (size_t i = token + 1; i < 10; i++) raw_counts[0][i] = 1;
+ }
+
+ uint64_t base_lz77_counts[kNumLZ77] = {
+ 29, 27, 25, 23, 21, 21, 19, 18, 21, 17, 16, 15, 15, 14,
+ 13, 13, 137, 98, 61, 34, 1, 1, 1, 1, 1, 1, 1, 1,
+ };
+
+ for (size_t c = 0; c < 4; c++) {
+ for (size_t i = 0; i < kNumLZ77; i++) {
+ lz77_counts[c][i] = (lz77_counts[c][i] << 8) + base_lz77_counts[i];
+ }
+ }
+
+ alignas(64) PrefixCode hcode[4];
+ for (size_t i = 0; i < 4; i++) {
+ hcode[i] = PrefixCode(bitdepth, raw_counts[i], lz77_counts[i]);
+ }
+
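+  // Section layout: DC global, one section per 2048x2048 DC group, the AC
+  // (HF) global section, and one section per 256x256 group; a single-group
+  // frame folds everything into a single section.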
+ size_t num_groups = onegroup ? 1
+ : (2 + num_dc_groups_x * num_dc_groups_y +
+ num_groups_x * num_groups_y);
+
+ JxlFastLosslessFrameState* frame_state = new JxlFastLosslessFrameState();
+
+ frame_state->width = width;
+ frame_state->height = height;
+ frame_state->nb_chans = nb_chans;
+ frame_state->bitdepth = bitdepth.bitdepth;
+
+ frame_state->group_data = std::vector<std::array<BitWriter, 4>>(num_groups);
+ if (collided) {
+ PrepareDCGlobal(onegroup, width, height, nb_chans, hcode,
+ &frame_state->group_data[0][0]);
+ } else {
+ PrepareDCGlobalPalette(onegroup, width, height, hcode, palette, pcolors,
+ &frame_state->group_data[0][0]);
+ }
+
+ auto run_one = [&](size_t g) {
+ size_t xg = g % num_groups_x;
+ size_t yg = g / num_groups_x;
+ size_t group_id =
+ onegroup ? 0 : (2 + num_dc_groups_x * num_dc_groups_y + g);
+ size_t xs = std::min<size_t>(width - xg * 256, 256);
+ size_t ys = std::min<size_t>(height - yg * 256, 256);
+ size_t x0 = xg * 256;
+ size_t y0 = yg * 256;
+ auto& gd = frame_state->group_data[group_id];
+ if (collided) {
+ WriteACSection(rgba, x0, y0, xs, ys, stride, onegroup, bitdepth, nb_chans,
+ big_endian, hcode, gd);
+
+ } else {
+ WriteACSectionPalette(rgba, x0, y0, xs, ys, stride, onegroup, hcode,
+ lookup.data(), nb_chans, gd[0]);
+ }
+ };
+
+ runner(
+ runner_opaque, &run_one,
+ +[](void* r, size_t i) { (*reinterpret_cast<decltype(&run_one)>(r))(i); },
+ num_groups_x * num_groups_y);
+
+ return frame_state;
+}
+
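+// Per-target entry point: picks the bit-depth trait class (UpTo8Bits,
+// From9To13Bits, Exactly14Bits or MoreThan14Bits) and forwards to LLEnc.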
+JxlFastLosslessFrameState* JxlFastLosslessEncodeImpl(
+ const unsigned char* rgba, size_t width, size_t stride, size_t height,
+ size_t nb_chans, size_t bitdepth, bool big_endian, int effort,
+ void* runner_opaque, FJxlParallelRunner runner) {
+ assert(bitdepth > 0);
+ assert(nb_chans <= 4);
+ assert(nb_chans != 0);
+ if (bitdepth <= 8) {
+ return LLEnc(rgba, width, stride, height, UpTo8Bits(bitdepth), nb_chans,
+ big_endian, effort, runner_opaque, runner);
+ }
+ if (bitdepth <= 13) {
+ return LLEnc(rgba, width, stride, height, From9To13Bits(bitdepth), nb_chans,
+ big_endian, effort, runner_opaque, runner);
+ }
+ if (bitdepth == 14) {
+ return LLEnc(rgba, width, stride, height, Exactly14Bits(bitdepth), nb_chans,
+ big_endian, effort, runner_opaque, runner);
+ }
+ return LLEnc(rgba, width, stride, height, MoreThan14Bits(bitdepth), nb_chans,
+ big_endian, effort, runner_opaque, runner);
+}
+
+} // namespace
+
+#endif // FJXL_SELF_INCLUDE
+
+#ifndef FJXL_SELF_INCLUDE
+
+#define FJXL_SELF_INCLUDE
+
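+// Dynamic-dispatch setup: with FJXL_SELF_INCLUDE defined, this file
+// re-includes itself once per SIMD target, so the encoder core is compiled
+// into a separate namespace for each target and selected at runtime.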
+// If we have NEON enabled, it is the default target.
+#if FJXL_ENABLE_NEON
+
+namespace default_implementation {
+#define FJXL_NEON
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_NEON
+} // namespace default_implementation
+
+#else // FJXL_ENABLE_NEON
+
+namespace default_implementation {
+#include "lib/jxl/enc_fast_lossless.cc"
+}
+
+#if FJXL_ENABLE_AVX2
+#ifdef __clang__
+#pragma clang attribute push(__attribute__((target("avx,avx2"))), \
+ apply_to = function)
+// Causes spurious warnings on clang5.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-braces"
+#elif defined(__GNUC__)
+#pragma GCC push_options
+// Seems to cause spurious errors on GCC8.
+#pragma GCC diagnostic ignored "-Wpsabi"
+#pragma GCC target "avx,avx2"
+#endif
+
+namespace AVX2 {
+#define FJXL_AVX2
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_AVX2
+} // namespace AVX2
+
+#ifdef __clang__
+#pragma clang attribute pop
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+#endif // FJXL_ENABLE_AVX2
+
+#if FJXL_ENABLE_AVX512
+#ifdef __clang__
+#pragma clang attribute push( \
+ __attribute__((target("avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"))), \
+ apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC target "avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"
+#endif
+
+namespace AVX512 {
+#define FJXL_AVX512
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_AVX512
+} // namespace AVX512
+
+#ifdef __clang__
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+#endif // FJXL_ENABLE_AVX512
+
+#endif
+
+extern "C" {
+
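+// One-shot convenience entry point: prepares the frame, writes a complete
+// codestream (including the image header) into a malloc()-allocated buffer
+// and returns its size; the caller owns *output and must free() it.
+// A minimal usage sketch (hypothetical `pixels`, `width`, `height`; the
+// effort value here is an arbitrary example):
+//
+//   unsigned char* encoded = nullptr;
+//   size_t size = JxlFastLosslessEncode(
+//       pixels, width, /*row_stride=*/width * 4, height, /*nb_chans=*/4,
+//       /*bitdepth=*/8, /*big_endian=*/0, /*effort=*/2, &encoded,
+//       /*runner_opaque=*/nullptr, /*runner=*/nullptr);
+//   // ... consume encoded[0, size) ...
+//   free(encoded);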
+size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width,
+ size_t row_stride, size_t height, size_t nb_chans,
+ size_t bitdepth, int big_endian, int effort,
+ unsigned char** output, void* runner_opaque,
+ FJxlParallelRunner runner) {
+ auto frame_state = JxlFastLosslessPrepareFrame(
+ rgba, width, row_stride, height, nb_chans, bitdepth, big_endian, effort,
+ runner_opaque, runner);
+ JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/1,
+ /*is_last=*/1);
+ size_t output_size = JxlFastLosslessMaxRequiredOutput(frame_state);
+ *output = (unsigned char*)malloc(output_size);
+ size_t written = 0;
+ size_t total = 0;
+ while ((written = JxlFastLosslessWriteOutput(frame_state, *output + total,
+ output_size - total)) != 0) {
+ total += written;
+ }
+ return total;
+}
+
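+// Encodes a frame without serializing the header; the actual SIMD
+// implementation is chosen at runtime via __builtin_cpu_supports, falling
+// back to the default (scalar or NEON) build.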
+JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame(
+ const unsigned char* rgba, size_t width, size_t row_stride, size_t height,
+ size_t nb_chans, size_t bitdepth, int big_endian, int effort,
+ void* runner_opaque, FJxlParallelRunner runner) {
+ auto trivial_runner =
+ +[](void*, void* opaque, void fun(void*, size_t), size_t count) {
+ for (size_t i = 0; i < count; i++) {
+ fun(opaque, i);
+ }
+ };
+
+ if (runner == nullptr) {
+ runner = trivial_runner;
+ }
+
+#if FJXL_ENABLE_AVX512
+ if (__builtin_cpu_supports("avx512cd") &&
+ __builtin_cpu_supports("avx512vbmi") &&
+ __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") &&
+ __builtin_cpu_supports("avx512vl")) {
+ return AVX512::JxlFastLosslessEncodeImpl(rgba, width, row_stride, height,
+ nb_chans, bitdepth, big_endian,
+ effort, runner_opaque, runner);
+ }
+#endif
+#if FJXL_ENABLE_AVX2
+ if (__builtin_cpu_supports("avx2")) {
+ return AVX2::JxlFastLosslessEncodeImpl(rgba, width, row_stride, height,
+ nb_chans, bitdepth, big_endian,
+ effort, runner_opaque, runner);
+ }
+#endif
+
+ return default_implementation::JxlFastLosslessEncodeImpl(
+ rgba, width, row_stride, height, nb_chans, bitdepth, big_endian, effort,
+ runner_opaque, runner);
+}
+
+} // extern "C"
+
+#endif // FJXL_SELF_INCLUDE