Diffstat (limited to 'third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc')
-rw-r--r-- | third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc | 3860
1 file changed, 3860 insertions(+), 0 deletions(-)
diff --git a/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc new file mode 100644 index 0000000000..286990ee8a --- /dev/null +++ b/third_party/jpeg-xl/lib/jxl/enc_fast_lossless.cc @@ -0,0 +1,3860 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#ifndef FJXL_SELF_INCLUDE + +#include "lib/jxl/enc_fast_lossless.h" + +#include <assert.h> +#include <stdint.h> +#include <stdio.h> +#include <string.h> + +#include <algorithm> +#include <array> +#include <limits> +#include <memory> +#include <vector> + +// Enable NEON and AVX2/AVX512 if not asked to do otherwise and the compilers +// support it. +#if defined(__aarch64__) || defined(_M_ARM64) +#include <arm_neon.h> + +#ifndef FJXL_ENABLE_NEON +#define FJXL_ENABLE_NEON 1 +#endif + +#elif (defined(__x86_64__) || defined(_M_X64)) && !defined(_MSC_VER) +#include <immintrin.h> + +// manually add _mm512_cvtsi512_si32 definition if missing +// (e.g. with Xcode on macOS Mojave) +// copied from gcc 11.1.0 include/avx512fintrin.h line 14367-14373 +#if defined(__clang__) && \ + ((!defined(__apple_build_version__) && __clang_major__ < 10) || \ + (defined(__apple_build_version__) && __apple_build_version__ < 12000032)) +inline int __attribute__((__gnu_inline__, __always_inline__, __artificial__)) +_mm512_cvtsi512_si32(__m512i __A) { + __v16si __B = (__v16si)__A; + return __B[0]; +} +#endif + +// TODO(veluca): MSVC support for dynamic dispatch. +#if defined(__clang__) || defined(__GNUC__) + +#ifndef FJXL_ENABLE_AVX2 +#define FJXL_ENABLE_AVX2 1 +#endif + +#ifndef FJXL_ENABLE_AVX512 +// On clang-7 or earlier, and gcc-10 or earlier, AVX512 seems broken. +#if (defined(__clang__) && \ + (!defined(__apple_build_version__) && __clang_major__ > 7) || \ + (defined(__apple_build_version__) && \ + __apple_build_version__ > 10010046)) || \ + (defined(__GNUC__) && __GNUC__ > 10) +#define FJXL_ENABLE_AVX512 1 +#endif +#endif + +#endif + +#endif + +#ifndef FJXL_ENABLE_NEON +#define FJXL_ENABLE_NEON 0 +#endif + +#ifndef FJXL_ENABLE_AVX2 +#define FJXL_ENABLE_AVX2 0 +#endif + +#ifndef FJXL_ENABLE_AVX512 +#define FJXL_ENABLE_AVX512 0 +#endif + +namespace { +#if defined(_MSC_VER) && !defined(__clang__) +#define FJXL_INLINE __forceinline +FJXL_INLINE uint32_t FloorLog2(uint32_t v) { + unsigned long index; + _BitScanReverse(&index, v); + return index; +} +FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { + unsigned long index; + _BitScanForward(&index, v); + return index; +} +#else +#define FJXL_INLINE inline __attribute__((always_inline)) +FJXL_INLINE uint32_t FloorLog2(uint32_t v) { + return v ? 31 - __builtin_clz(v) : 0; +} +FJXL_INLINE uint32_t CtzNonZero(uint64_t v) { return __builtin_ctzll(v); } +#endif + +// Compiles to a memcpy on little-endian systems. 
+FJXL_INLINE void StoreLE64(uint8_t* tgt, uint64_t data) { +#if (!defined(__BYTE_ORDER__) || (__BYTE_ORDER__ != __ORDER_LITTLE_ENDIAN__)) + for (int i = 0; i < 8; i++) { + tgt[i] = (data >> (i * 8)) & 0xFF; + } +#else + memcpy(tgt, &data, 8); +#endif +} + +FJXL_INLINE size_t AddBits(uint32_t count, uint64_t bits, uint8_t* data_buf, + size_t& bits_in_buffer, uint64_t& bit_buffer) { + bit_buffer |= bits << bits_in_buffer; + bits_in_buffer += count; + StoreLE64(data_buf, bit_buffer); + size_t bytes_in_buffer = bits_in_buffer / 8; + bits_in_buffer -= bytes_in_buffer * 8; + bit_buffer >>= bytes_in_buffer * 8; + return bytes_in_buffer; +} + +struct BitWriter { + void Allocate(size_t maximum_bit_size) { + assert(data == nullptr); + // Leave some padding. + data.reset(static_cast<uint8_t*>(malloc(maximum_bit_size / 8 + 64))); + } + + void Write(uint32_t count, uint64_t bits) { + bytes_written += AddBits(count, bits, data.get() + bytes_written, + bits_in_buffer, buffer); + } + + void ZeroPadToByte() { + if (bits_in_buffer != 0) { + Write(8 - bits_in_buffer, 0); + } + } + + FJXL_INLINE void WriteMultiple(const uint64_t* nbits, const uint64_t* bits, + size_t n) { + // Necessary because Write() is only guaranteed to work with <=56 bits. + // Trying to SIMD-fy this code results in lower speed (and definitely less + // clarity). + { + for (size_t i = 0; i < n; i++) { + this->buffer |= bits[i] << this->bits_in_buffer; + memcpy(this->data.get() + this->bytes_written, &this->buffer, 8); + uint64_t shift = 64 - this->bits_in_buffer; + this->bits_in_buffer += nbits[i]; + // This `if` seems to be faster than using ternaries. + if (this->bits_in_buffer >= 64) { + uint64_t next_buffer = bits[i] >> shift; + this->buffer = next_buffer; + this->bits_in_buffer -= 64; + this->bytes_written += 8; + } + } + memcpy(this->data.get() + this->bytes_written, &this->buffer, 8); + size_t bytes_in_buffer = this->bits_in_buffer / 8; + this->bits_in_buffer -= bytes_in_buffer * 8; + this->buffer >>= bytes_in_buffer * 8; + this->bytes_written += bytes_in_buffer; + } + } + + std::unique_ptr<uint8_t[], void (*)(void*)> data = {nullptr, free}; + size_t bytes_written = 0; + size_t bits_in_buffer = 0; + uint64_t buffer = 0; +}; + +} // namespace + +extern "C" { + +struct JxlFastLosslessFrameState { + size_t width; + size_t height; + size_t nb_chans; + size_t bitdepth; + BitWriter header; + std::vector<std::array<BitWriter, 4>> group_data; + size_t current_bit_writer = 0; + size_t bit_writer_byte_pos = 0; + size_t bits_in_buffer = 0; + uint64_t bit_buffer = 0; +}; + +size_t JxlFastLosslessOutputSize(const JxlFastLosslessFrameState* frame) { + size_t total_size_groups = 0; + for (size_t i = 0; i < frame->group_data.size(); i++) { + size_t sz = 0; + for (size_t j = 0; j < frame->nb_chans; j++) { + const auto& writer = frame->group_data[i][j]; + sz += writer.bytes_written * 8 + writer.bits_in_buffer; + } + sz = (sz + 7) / 8; + total_size_groups += sz; + } + return frame->header.bytes_written + total_size_groups; +} + +size_t JxlFastLosslessMaxRequiredOutput( + const JxlFastLosslessFrameState* frame) { + return JxlFastLosslessOutputSize(frame) + 32; +} + +void JxlFastLosslessPrepareHeader(JxlFastLosslessFrameState* frame, + int add_image_header, int is_last) { + BitWriter* output = &frame->header; + output->Allocate(1000 + frame->group_data.size() * 32); + + std::vector<size_t> group_sizes(frame->group_data.size()); + for (size_t i = 0; i < frame->group_data.size(); i++) { + size_t sz = 0; + for (size_t j = 0; j < frame->nb_chans; j++) { + 
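The Write()/AddBits() path above keeps at most 7 pending bits in a 64-bit buffer and flushes whole bytes with a little-endian store, which is why a single Write() call is only safe for up to 56 bits (7 pending plus 56 new still fit in 64). A minimal usage sketch of the BitWriter defined above (values are illustrative):

  BitWriter w;
  w.Allocate(1024);   // capacity in bits
  w.Write(2, 0b00);   // 2-bit selector
  w.Write(9, 511);    // 9-bit payload, emitted LSB-first
  w.ZeroPadToByte();  // pads the 11 bits to 16; w.bytes_written == 2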
const auto& writer = frame->group_data[i][j]; + sz += writer.bytes_written * 8 + writer.bits_in_buffer; + } + sz = (sz + 7) / 8; + group_sizes[i] = sz; + } + + bool have_alpha = (frame->nb_chans == 2 || frame->nb_chans == 4); + + if (add_image_header) { + // Signature + output->Write(16, 0x0AFF); + + // Size header, hand-crafted. + // Not small + output->Write(1, 0); + + auto wsz = [output](size_t size) { + if (size - 1 < (1 << 9)) { + output->Write(2, 0b00); + output->Write(9, size - 1); + } else if (size - 1 < (1 << 13)) { + output->Write(2, 0b01); + output->Write(13, size - 1); + } else if (size - 1 < (1 << 18)) { + output->Write(2, 0b10); + output->Write(18, size - 1); + } else { + output->Write(2, 0b11); + output->Write(30, size - 1); + } + }; + + wsz(frame->height); + + // No special ratio. + output->Write(3, 0); + + wsz(frame->width); + + // Hand-crafted ImageMetadata. + output->Write(1, 0); // all_default + output->Write(1, 0); // extra_fields + output->Write(1, 0); // bit_depth.floating_point_sample + if (frame->bitdepth == 8) { + output->Write(2, 0b00); // bit_depth.bits_per_sample = 8 + } else if (frame->bitdepth == 10) { + output->Write(2, 0b01); // bit_depth.bits_per_sample = 10 + } else if (frame->bitdepth == 12) { + output->Write(2, 0b10); // bit_depth.bits_per_sample = 12 + } else { + output->Write(2, 0b11); // 1 + u(6) + output->Write(6, frame->bitdepth - 1); + } + if (frame->bitdepth <= 14) { + output->Write(1, 1); // 16-bit-buffer sufficient + } else { + output->Write(1, 0); // 16-bit-buffer NOT sufficient + } + if (have_alpha) { + output->Write(2, 0b01); // One extra channel + output->Write(1, 1); // ... all_default (ie. 8-bit alpha) + } else { + output->Write(2, 0b00); // No extra channel + } + output->Write(1, 0); // Not XYB + if (frame->nb_chans > 2) { + output->Write(1, 1); // color_encoding.all_default (sRGB) + } else { + output->Write(1, 0); // color_encoding.all_default false + output->Write(1, 0); // color_encoding.want_icc false + output->Write(2, 1); // grayscale + output->Write(2, 1); // D65 + output->Write(1, 0); // no gamma transfer function + output->Write(2, 0b10); // tf: 2 + u(4) + output->Write(4, 11); // tf of sRGB + output->Write(2, 1); // relative rendering intent + } + output->Write(2, 0b00); // No extensions. + + output->Write(1, 1); // all_default transform data + + // No ICC, no preview. Frame should start at byte boundery. + output->ZeroPadToByte(); + } + + // Handcrafted frame header. + output->Write(1, 0); // all_default + output->Write(2, 0b00); // regular frame + output->Write(1, 1); // modular + output->Write(2, 0b00); // default flags + output->Write(1, 0); // not YCbCr + output->Write(2, 0b00); // no upsampling + if (have_alpha) { + output->Write(2, 0b00); // no alpha upsampling + } + output->Write(2, 0b01); // default group size + output->Write(2, 0b00); // exactly one pass + output->Write(1, 0); // no custom size or origin + output->Write(2, 0b00); // kReplace blending mode + if (have_alpha) { + output->Write(2, 0b00); // kReplace blending mode for alpha channel + } + output->Write(1, is_last); // is_last + output->Write(2, 0b00); // a frame has no name + output->Write(1, 0); // loop filter is not all_default + output->Write(1, 0); // no gaborish + output->Write(2, 0); // 0 EPF iters + output->Write(2, 0b00); // No LF extensions + output->Write(2, 0b00); // No FH extensions + + output->Write(1, 0); // No TOC permutation + output->ZeroPadToByte(); // TOC is byte-aligned. 
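The wsz lambda above codes each dimension as a 2-bit selector followed by 9, 13, 18 or 30 payload bits holding size - 1. As a worked example for a hypothetical 1920x1080 image, both dimensions need the 13-bit branch:

  wsz(1080);            // 1079 >= 2^9: Write(2, 0b01); Write(13, 1079);
  output->Write(3, 0);  // no special aspect ratio
  wsz(1920);            // Write(2, 0b01); Write(13, 1919);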
+ for (size_t i = 0; i < frame->group_data.size(); i++) { + size_t sz = group_sizes[i]; + if (sz < (1 << 10)) { + output->Write(2, 0b00); + output->Write(10, sz); + } else if (sz - 1024 < (1 << 14)) { + output->Write(2, 0b01); + output->Write(14, sz - 1024); + } else if (sz - 17408 < (1 << 22)) { + output->Write(2, 0b10); + output->Write(22, sz - 17408); + } else { + output->Write(2, 0b11); + output->Write(30, sz - 4211712); + } + } + output->ZeroPadToByte(); // Groups are byte-aligned. +} + +#if FJXL_ENABLE_AVX512 +__attribute__((target("avx512vbmi2"))) static size_t AppendBytesWithBitOffset( + const uint8_t* data, size_t n, size_t bit_buffer_nbits, + unsigned char* output, uint64_t& bit_buffer) { + if (n < 128) { + return 0; + } + + size_t i = 0; + __m512i shift = _mm512_set1_epi64(64 - bit_buffer_nbits); + __m512i carry = _mm512_set1_epi64(bit_buffer << (64 - bit_buffer_nbits)); + + for (; i + 64 <= n; i += 64) { + __m512i current = _mm512_loadu_si512(data + i); + __m512i previous_u64 = _mm512_alignr_epi64(current, carry, 7); + carry = current; + __m512i out = _mm512_shrdv_epi64(previous_u64, current, shift); + _mm512_storeu_si512(output + i, out); + } + + bit_buffer = data[i - 1] >> (8 - bit_buffer_nbits); + + return i; +} +#endif + +size_t JxlFastLosslessWriteOutput(JxlFastLosslessFrameState* frame, + unsigned char* output, size_t output_size) { + assert(output_size >= 32); + unsigned char* initial_output = output; + size_t (*append_bytes_with_bit_offset)(const uint8_t*, size_t, size_t, + unsigned char*, uint64_t&) = nullptr; + +#if FJXL_ENABLE_AVX512 + if (__builtin_cpu_supports("avx512vbmi2")) { + append_bytes_with_bit_offset = AppendBytesWithBitOffset; + } +#endif + + while (true) { + size_t& cur = frame->current_bit_writer; + size_t& bw_pos = frame->bit_writer_byte_pos; + if (cur >= 1 + frame->group_data.size() * frame->nb_chans) { + return output - initial_output; + } + if (output_size <= 8) { + return output - initial_output; + } + size_t nbc = frame->nb_chans; + const BitWriter& writer = + cur == 0 ? frame->header + : frame->group_data[(cur - 1) / nbc][(cur - 1) % nbc]; + size_t full_byte_count = + std::min(output_size - 8, writer.bytes_written - bw_pos); + if (frame->bits_in_buffer == 0) { + memcpy(output, writer.data.get() + bw_pos, full_byte_count); + } else { + size_t i = 0; + if (append_bytes_with_bit_offset) { + i += append_bytes_with_bit_offset( + writer.data.get() + bw_pos, full_byte_count, frame->bits_in_buffer, + output, frame->bit_buffer); + } +#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) + // Copy 8 bytes at a time until we reach the border. 
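Each TOC entry written above follows the same selector pattern: 2 bits choose between 10, 14, 22 or 30 payload bits, with the offsets chained so the ranges are contiguous (17408 = 1024 + 2^14, 4211712 = 17408 + 2^22). A sketch for a hypothetical 5000-byte group:

  output->Write(2, 0b01);          // second range starts at 1024
  output->Write(14, 5000 - 1024);  // payload = 3976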
+ for (; i + 8 < full_byte_count; i += 8) { + uint64_t chunk; + memcpy(&chunk, writer.data.get() + bw_pos + i, 8); + uint64_t out = frame->bit_buffer | (chunk << frame->bits_in_buffer); + memcpy(output + i, &out, 8); + frame->bit_buffer = chunk >> (64 - frame->bits_in_buffer); + } +#endif + for (; i < full_byte_count; i++) { + AddBits(8, writer.data.get()[bw_pos + i], output + i, + frame->bits_in_buffer, frame->bit_buffer); + } + } + output += full_byte_count; + output_size -= full_byte_count; + bw_pos += full_byte_count; + if (bw_pos == writer.bytes_written) { + auto write = [&](size_t num, uint64_t bits) { + size_t n = AddBits(num, bits, output, frame->bits_in_buffer, + frame->bit_buffer); + output += n; + output_size -= n; + }; + if (writer.bits_in_buffer) { + write(writer.bits_in_buffer, writer.buffer); + } + bw_pos = 0; + cur++; + if ((cur - 1) % nbc == 0 && frame->bits_in_buffer != 0) { + write(8 - frame->bits_in_buffer, 0); + } + } + } +} + +void JxlFastLosslessFreeFrameState(JxlFastLosslessFrameState* frame) { + delete frame; +} + +} // extern "C" + +#endif + +#ifdef FJXL_SELF_INCLUDE + +namespace { + +constexpr size_t kNumRawSymbols = 19; +constexpr size_t kNumLZ77 = 33; +constexpr size_t kLZ77CacheSize = 32; + +constexpr size_t kLZ77Offset = 224; +constexpr size_t kLZ77MinLength = 7; + +void EncodeHybridUintLZ77(uint32_t value, uint32_t* token, uint32_t* nbits, + uint32_t* bits) { + // 400 config + uint32_t n = FloorLog2(value); + *token = value < 16 ? value : 16 + n - 4; + *nbits = value < 16 ? 0 : n; + *bits = value < 16 ? 0 : value - (1 << *nbits); +} + +struct PrefixCode { + uint8_t raw_nbits[kNumRawSymbols] = {}; + uint8_t raw_bits[kNumRawSymbols] = {}; + + alignas(64) uint8_t raw_nbits_simd[16] = {}; + alignas(64) uint8_t raw_bits_simd[16] = {}; + + uint8_t lz77_nbits[kNumLZ77] = {}; + uint16_t lz77_bits[kNumLZ77] = {}; + + uint64_t lz77_cache_bits[kLZ77CacheSize] = {}; + uint8_t lz77_cache_nbits[kLZ77CacheSize] = {}; + + static uint16_t BitReverse(size_t nbits, uint16_t bits) { + constexpr uint16_t kNibbleLookup[16] = { + 0b0000, 0b1000, 0b0100, 0b1100, 0b0010, 0b1010, 0b0110, 0b1110, + 0b0001, 0b1001, 0b0101, 0b1101, 0b0011, 0b1011, 0b0111, 0b1111, + }; + uint16_t rev16 = (kNibbleLookup[bits & 0xF] << 12) | + (kNibbleLookup[(bits >> 4) & 0xF] << 8) | + (kNibbleLookup[(bits >> 8) & 0xF] << 4) | + (kNibbleLookup[bits >> 12]); + return rev16 >> (16 - nbits); + } + + // Create the prefix codes given the code lengths. + // Supports the code lengths being split into two halves. 
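EncodeHybridUintLZ77() above stores values below 16 as raw tokens, and larger values as a token recording the bit length plus the value with its top bit dropped. A decode-side sketch of the inverse (not part of the file), with a worked trace:

  // value = 100: FloorLog2(100) = 6, so token = 16 + 6 - 4 = 18,
  //              nbits = 6, bits = 100 - 64 = 36.
  uint32_t DecodeHybridUintLZ77(uint32_t token, uint32_t bits) {
    if (token < 16) return token;  // raw value
    uint32_t n = token - 16 + 4;   // recovered FloorLog2(value)
    return (1u << n) + bits;       // re-attach the implicit top bit
  }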
+ static void ComputeCanonicalCode(const uint8_t* first_chunk_nbits, + uint8_t* first_chunk_bits, + size_t first_chunk_size, + const uint8_t* second_chunk_nbits, + uint16_t* second_chunk_bits, + size_t second_chunk_size) { + constexpr size_t kMaxCodeLength = 15; + uint8_t code_length_counts[kMaxCodeLength + 1] = {}; + for (size_t i = 0; i < first_chunk_size; i++) { + code_length_counts[first_chunk_nbits[i]]++; + assert(first_chunk_nbits[i] <= kMaxCodeLength); + assert(first_chunk_nbits[i] <= 8); + assert(first_chunk_nbits[i] > 0); + } + for (size_t i = 0; i < second_chunk_size; i++) { + code_length_counts[second_chunk_nbits[i]]++; + assert(second_chunk_nbits[i] <= kMaxCodeLength); + } + + uint16_t next_code[kMaxCodeLength + 1] = {}; + + uint16_t code = 0; + for (size_t i = 1; i < kMaxCodeLength + 1; i++) { + code = (code + code_length_counts[i - 1]) << 1; + next_code[i] = code; + } + + for (size_t i = 0; i < first_chunk_size; i++) { + first_chunk_bits[i] = + BitReverse(first_chunk_nbits[i], next_code[first_chunk_nbits[i]]++); + } + for (size_t i = 0; i < second_chunk_size; i++) { + second_chunk_bits[i] = + BitReverse(second_chunk_nbits[i], next_code[second_chunk_nbits[i]]++); + } + } + + template <typename T> + static void ComputeCodeLengthsNonZeroImpl(const uint64_t* freqs, size_t n, + size_t precision, T infty, + uint8_t* min_limit, + uint8_t* max_limit, + uint8_t* nbits) { + std::vector<T> dynp(((1U << precision) + 1) * (n + 1), infty); + auto d = [&](size_t sym, size_t off) -> T& { + return dynp[sym * ((1 << precision) + 1) + off]; + }; + d(0, 0) = 0; + for (size_t sym = 0; sym < n; sym++) { + for (T bits = min_limit[sym]; bits <= max_limit[sym]; bits++) { + size_t off_delta = 1U << (precision - bits); + for (size_t off = 0; off + off_delta <= (1U << precision); off++) { + d(sym + 1, off + off_delta) = + std::min(d(sym, off) + static_cast<T>(freqs[sym]) * bits, + d(sym + 1, off + off_delta)); + } + } + } + + size_t sym = n; + size_t off = 1U << precision; + + assert(d(sym, off) != infty); + + while (sym-- > 0) { + assert(off > 0); + for (size_t bits = min_limit[sym]; bits <= max_limit[sym]; bits++) { + size_t off_delta = 1U << (precision - bits); + if (off_delta <= off && + d(sym + 1, off) == d(sym, off - off_delta) + freqs[sym] * bits) { + off -= off_delta; + nbits[sym] = bits; + break; + } + } + } + } + + // Computes nbits[i] for i <= n, subject to min_limit[i] <= nbits[i] <= + // max_limit[i] and sum 2**-nbits[i] == 1, so to minimize sum(nbits[i] * + // freqs[i]). + static void ComputeCodeLengthsNonZero(const uint64_t* freqs, size_t n, + uint8_t* min_limit, uint8_t* max_limit, + uint8_t* nbits) { + size_t precision = 0; + size_t shortest_length = 255; + uint64_t freqsum = 0; + for (size_t i = 0; i < n; i++) { + assert(freqs[i] != 0); + freqsum += freqs[i]; + if (min_limit[i] < 1) min_limit[i] = 1; + assert(min_limit[i] <= max_limit[i]); + precision = std::max<size_t>(max_limit[i], precision); + shortest_length = std::min<size_t>(min_limit[i], shortest_length); + } + // If all the minimum limits are greater than 1, shift precision so that we + // behave as if the shortest was 1. 
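A small worked example of ComputeCanonicalCode() above, assuming code lengths {1, 2, 3, 3} for four symbols (hypothetical input): the length counts give next_code[1] = 0, next_code[2] = 0b10 and next_code[3] = 0b110, so the symbols receive the canonical codes 0, 10, 110 and 111; BitReverse() then flips each code (110 -> 011) so it can be emitted into the LSB-first bitstream:

  uint8_t nbits[4] = {1, 2, 3, 3};
  uint8_t bits[4];
  PrefixCode::ComputeCanonicalCode(nbits, bits, 4, nullptr, nullptr, 0);
  // bits == {0b0, 0b01, 0b011, 0b111}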
+ precision -= shortest_length - 1; + uint64_t infty = freqsum * precision; + if (infty < std::numeric_limits<uint32_t>::max() / 2) { + ComputeCodeLengthsNonZeroImpl(freqs, n, precision, + static_cast<uint32_t>(infty), min_limit, + max_limit, nbits); + } else { + ComputeCodeLengthsNonZeroImpl(freqs, n, precision, infty, min_limit, + max_limit, nbits); + } + } + + static constexpr size_t kMaxNumSymbols = + kNumRawSymbols + 1 < kNumLZ77 ? kNumLZ77 : kNumRawSymbols + 1; + static void ComputeCodeLengths(const uint64_t* freqs, size_t n, + const uint8_t* min_limit_in, + const uint8_t* max_limit_in, uint8_t* nbits) { + assert(n <= kMaxNumSymbols); + uint64_t compact_freqs[kMaxNumSymbols]; + uint8_t min_limit[kMaxNumSymbols]; + uint8_t max_limit[kMaxNumSymbols]; + size_t ni = 0; + for (size_t i = 0; i < n; i++) { + if (freqs[i]) { + compact_freqs[ni] = freqs[i]; + min_limit[ni] = min_limit_in[i]; + max_limit[ni] = max_limit_in[i]; + ni++; + } + } + uint8_t num_bits[kMaxNumSymbols] = {}; + ComputeCodeLengthsNonZero(compact_freqs, ni, min_limit, max_limit, + num_bits); + ni = 0; + for (size_t i = 0; i < n; i++) { + nbits[i] = 0; + if (freqs[i]) { + nbits[i] = num_bits[ni++]; + } + } + } + + // Invalid code, used to construct arrays. + PrefixCode() {} + + template <typename BitDepth> + PrefixCode(BitDepth, uint64_t* raw_counts, uint64_t* lz77_counts) { + // "merge" together all the lz77 counts in a single symbol for the level 1 + // table (containing just the raw symbols, up to length 7). + uint64_t level1_counts[kNumRawSymbols + 1]; + memcpy(level1_counts, raw_counts, kNumRawSymbols * sizeof(uint64_t)); + size_t numraw = kNumRawSymbols; + while (numraw > 0 && level1_counts[numraw - 1] == 0) numraw--; + + level1_counts[numraw] = 0; + for (size_t i = 0; i < kNumLZ77; i++) { + level1_counts[numraw] += lz77_counts[i]; + } + uint8_t level1_nbits[kNumRawSymbols + 1] = {}; + ComputeCodeLengths(level1_counts, numraw + 1, BitDepth::kMinRawLength, + BitDepth::kMaxRawLength, level1_nbits); + + uint8_t level2_nbits[kNumLZ77] = {}; + uint8_t min_lengths[kNumLZ77] = {}; + uint8_t l = 15 - level1_nbits[numraw]; + uint8_t max_lengths[kNumLZ77]; + for (size_t i = 0; i < kNumLZ77; i++) { + max_lengths[i] = l; + } + size_t num_lz77 = kNumLZ77; + while (num_lz77 > 0 && lz77_counts[num_lz77 - 1] == 0) num_lz77--; + ComputeCodeLengths(lz77_counts, num_lz77, min_lengths, max_lengths, + level2_nbits); + for (size_t i = 0; i < numraw; i++) { + raw_nbits[i] = level1_nbits[i]; + } + for (size_t i = 0; i < num_lz77; i++) { + lz77_nbits[i] = + level2_nbits[i] ? 
level1_nbits[numraw] + level2_nbits[i] : 0; + } + + ComputeCanonicalCode(raw_nbits, raw_bits, numraw, lz77_nbits, lz77_bits, + kNumLZ77); + BitDepth::PrepareForSimd(raw_nbits, raw_bits, numraw, raw_nbits_simd, + raw_bits_simd); + + // Prepare lz77 cache + for (size_t count = 0; count < kLZ77CacheSize; count++) { + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + lz77_cache_nbits[count] = lz77_nbits[token] + nbits + raw_nbits[0]; + lz77_cache_bits[count] = + (((bits << lz77_nbits[token]) | lz77_bits[token]) << raw_nbits[0]) | + raw_bits[0]; + } + } + + void WriteTo(BitWriter* writer) const { + uint64_t code_length_counts[18] = {}; + code_length_counts[17] = 3 + 2 * (kNumLZ77 - 1); + for (size_t i = 0; i < kNumRawSymbols; i++) { + code_length_counts[raw_nbits[i]]++; + } + for (size_t i = 0; i < kNumLZ77; i++) { + code_length_counts[lz77_nbits[i]]++; + } + uint8_t code_length_nbits[18] = {}; + uint8_t code_length_nbits_min[18] = {}; + uint8_t code_length_nbits_max[18] = { + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + }; + ComputeCodeLengths(code_length_counts, 18, code_length_nbits_min, + code_length_nbits_max, code_length_nbits); + writer->Write(2, 0b00); // HSKIP = 0, i.e. don't skip code lengths. + + // As per Brotli RFC. + uint8_t code_length_order[18] = {1, 2, 3, 4, 0, 5, 17, 6, 16, + 7, 8, 9, 10, 11, 12, 13, 14, 15}; + uint8_t code_length_length_nbits[] = {2, 4, 3, 2, 2, 4}; + uint8_t code_length_length_bits[] = {0, 7, 3, 2, 1, 15}; + + // Encode lengths of code lengths. + size_t num_code_lengths = 18; + while (code_length_nbits[code_length_order[num_code_lengths - 1]] == 0) { + num_code_lengths--; + } + for (size_t i = 0; i < num_code_lengths; i++) { + int symbol = code_length_nbits[code_length_order[i]]; + writer->Write(code_length_length_nbits[symbol], + code_length_length_bits[symbol]); + } + + // Compute the canonical codes for the codes that represent the lengths of + // the actual codes for data. + uint16_t code_length_bits[18] = {}; + ComputeCanonicalCode(nullptr, nullptr, 0, code_length_nbits, + code_length_bits, 18); + // Encode raw bit code lengths. + for (size_t i = 0; i < kNumRawSymbols; i++) { + writer->Write(code_length_nbits[raw_nbits[i]], + code_length_bits[raw_nbits[i]]); + } + size_t num_lz77 = kNumLZ77; + while (lz77_nbits[num_lz77 - 1] == 0) { + num_lz77--; + } + // Encode 0s until 224 (start of LZ77 symbols). This is in total 224-19 = + // 205. + static_assert(kLZ77Offset == 224, ""); + static_assert(kNumRawSymbols == 19, ""); + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b010); // 5 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b000); // (5-2)*8 + 3 = 27 + writer->Write(code_length_nbits[17], code_length_bits[17]); + writer->Write(3, 0b010); // (27-2)*8 + 5 = 205 + // Encode LZ77 symbols, with values 224+i. 
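The three code-17 writes above emit the 205 zero code lengths between the raw symbols and kLZ77Offset using the Brotli repeat-zero rule: the first 17 encodes a run of 3 + extra, and each chained 17 updates it as run = (run - 2) * 8 + 3 + extra. With the extra-bit values written above:

  // 0b010: run = 3 + 2            = 5
  // 0b000: run = (5 - 2) * 8 + 3  = 27
  // 0b010: run = (27 - 2) * 8 + 5 = 205 = kLZ77Offset - kNumRawSymbols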
+ for (size_t i = 0; i < num_lz77; i++) { + writer->Write(code_length_nbits[lz77_nbits[i]], + code_length_bits[lz77_nbits[i]]); + } + } +}; + +template <typename T> +struct VecPair { + T low; + T hi; +}; + +#ifdef FJXL_GENERIC_SIMD +#undef FJXL_GENERIC_SIMD +#endif + +#ifdef FJXL_AVX512 +#define FJXL_GENERIC_SIMD +struct SIMDVec32; +struct Mask32 { + __mmask16 mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + size_t CountPrefix() const { + return CtzNonZero(~uint64_t{_cvtmask16_u32(mask)}); + } +}; + +struct SIMDVec32 { + __m512i vec; + + static constexpr size_t kLanes = 16; + + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{_mm512_loadu_si512((__m512i*)data)}; + } + FJXL_INLINE void Store(uint32_t* data) { + _mm512_storeu_si512((__m512i*)data, vec); + } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{_mm512_set1_epi32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + return SIMDVec32{ + _mm512_sub_epi32(_mm512_set1_epi32(32), _mm512_lzcnt_epi32(vec))}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm512_sub_epi32(_mm512_max_epu32(vec, to_subtract.vec), + to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm512_sub_epi32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{_mm512_add_epi32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{_mm512_xor_epi32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{_mm512_cmpeq_epi32_mask(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{_mm512_cmpgt_epi32_mask(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + return SIMDVec32{_mm512_sllv_epi32(_mm512_set1_epi32(1), vec)}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{_mm512_srai_epi32(vec, i)}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + __mmask32 mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{_kand_mask32(mask, oth.mask)}; + } + size_t CountPrefix() const { + return CtzNonZero(~uint64_t{_cvtmask32_u32(mask)}); + } +}; + +struct SIMDVec16 { + __m512i vec; + + static constexpr size_t kLanes = 32; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{_mm512_loadu_si512((__m512i*)data)}; + } + FJXL_INLINE void Store(uint16_t* data) { + _mm512_storeu_si512((__m512i*)data, vec); + } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{_mm512_set1_epi16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + auto tmp = _mm512_packus_epi32(lo.vec, hi.vec); + alignas(64) uint64_t perm[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + return SIMDVec16{ + _mm512_permutex2var_epi64(tmp, _mm512_load_si512((__m512i*)perm), tmp)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + auto c16 = _mm512_set1_epi32(16); + auto c32 = _mm512_set1_epi32(32); + auto low16bit = _mm512_set1_epi32(0x0000FFFF); + auto lzhi = + _mm512_sub_epi32(c16, _mm512_min_epu32(c16, _mm512_lzcnt_epi32(vec))); + auto lzlo = _mm512_sub_epi32( + c32, _mm512_lzcnt_epi32(_mm512_and_si512(low16bit, vec))); + return SIMDVec16{_mm512_or_si512(lzlo, _mm512_slli_epi32(lzhi, 16))}; + } + + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + 
return SIMDVec16{_mm512_subs_epu16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm512_sub_epi16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_add_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_min_epu16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{_mm512_cmpeq_epi16_mask(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{_mm512_cmpgt_epi16_mask(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + return SIMDVec16{_mm512_sllv_epi16(_mm512_set1_epi16(1), vec)}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_or_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_xor_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_and_si512(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{_mm512_srai_epi16(_mm512_add_epi16(vec, oth.vec), 1)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{_mm512_or_si512(vec, _mm512_set1_epi16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + return SIMDVec16{_mm512_shuffle_epi8( + _mm512_broadcast_i32x4(_mm_loadu_si128((__m128i*)table)), vec)}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + auto lo = _mm512_unpacklo_epi16(low.vec, vec); + auto hi = _mm512_unpackhi_epi16(low.vec, vec); + alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + return {SIMDVec16{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm1), hi)}, + SIMDVec16{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm2), hi)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + auto lo = _mm512_unpacklo_epi16(vec, _mm512_setzero_si512()); + auto hi = _mm512_unpackhi_epi16(vec, _mm512_setzero_si512()); + alignas(64) uint64_t perm1[8] = {0, 1, 8, 9, 2, 3, 10, 11}; + alignas(64) uint64_t perm2[8] = {4, 5, 12, 13, 6, 7, 14, 15}; + return {SIMDVec32{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm1), hi)}, + SIMDVec32{_mm512_permutex2var_epi64( + lo, _mm512_load_si512((__m512i*)perm2), hi)}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{_mm512_srai_epi16(vec, i)}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + return {SIMDVec16{_mm512_cvtepu8_epi16(bytes)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + __m512i bytes = _mm512_loadu_si512((__m512i*)data); + __m512i gray = _mm512_and_si512(bytes, _mm512_set1_epi16(0xFF)); + __m512i alpha = _mm512_srli_epi16(bytes, 8); + return {SIMDVec16{gray}, SIMDVec16{alpha}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + __m512i bytes1 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i g_mask = _mm512_set1_epi32(0xFFFF); + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + __m512i g = 
_mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, g_mask), + _mm512_and_si512(bytes2, g_mask))); + __m512i a = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16), + _mm512_srli_epi32(bytes2, 16))); + return {SIMDVec16{g}, SIMDVec16{a}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = + _mm512_zextsi256_si512(_mm256_loadu_si256((__m256i*)(data + 64))); + + // 0x7A = element of upper half of second vector = 0 after lookup; still in + // the upper half once we add 1 or 2. + uint8_t z = 0x7A; + __m512i ridx = + _mm512_set_epi8(z, 93, z, 90, z, 87, z, 84, z, 81, z, 78, z, 75, z, 72, + z, 69, z, 66, z, 63, z, 60, z, 57, z, 54, z, 51, z, 48, + z, 45, z, 42, z, 39, z, 36, z, 33, z, 30, z, 27, z, 24, + z, 21, z, 18, z, 15, z, 12, z, 9, z, 6, z, 3, z, 0); + __m512i gidx = _mm512_add_epi8(ridx, _mm512_set1_epi8(1)); + __m512i bidx = _mm512_add_epi8(gidx, _mm512_set1_epi8(1)); + __m512i r = _mm512_permutex2var_epi8(bytes0, ridx, bytes1); + __m512i g = _mm512_permutex2var_epi8(bytes0, gidx, bytes1); + __m512i b = _mm512_permutex2var_epi8(bytes0, bidx, bytes1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128)); + + __m512i ridx_lo = _mm512_set_epi16(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63, 60, 57, + 54, 51, 48, 45, 42, 39, 36, 33, 30, 27, + 24, 21, 18, 15, 12, 9, 6, 3, 0); + // -1 is such that when adding 1 or 2, we get the correct index for + // green/blue. 
+ __m512i ridx_hi = + _mm512_set_epi16(29, 26, 23, 20, 17, 14, 11, 8, 5, 2, -1, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512i gidx_lo = _mm512_add_epi16(ridx_lo, _mm512_set1_epi16(1)); + __m512i gidx_hi = _mm512_add_epi16(ridx_hi, _mm512_set1_epi16(1)); + __m512i bidx_lo = _mm512_add_epi16(gidx_lo, _mm512_set1_epi16(1)); + __m512i bidx_hi = _mm512_add_epi16(gidx_hi, _mm512_set1_epi16(1)); + + __mmask32 rmask = _cvtu32_mask32(0b11111111110000000000000000000000); + __mmask32 gbmask = _cvtu32_mask32(0b11111111111000000000000000000000); + + __m512i rlo = _mm512_permutex2var_epi16(bytes0, ridx_lo, bytes1); + __m512i glo = _mm512_permutex2var_epi16(bytes0, gidx_lo, bytes1); + __m512i blo = _mm512_permutex2var_epi16(bytes0, bidx_lo, bytes1); + __m512i r = _mm512_mask_permutexvar_epi16(rlo, rmask, ridx_hi, bytes2); + __m512i g = _mm512_mask_permutexvar_epi16(glo, gbmask, gidx_hi, bytes2); + __m512i b = _mm512_mask_permutexvar_epi16(blo, gbmask, bidx_hi, bytes2); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + __m512i bytes1 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i rg_mask = _mm512_set1_epi32(0xFFFF); + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + __m512i rg = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_and_si512(bytes1, rg_mask), + _mm512_and_si512(bytes2, rg_mask))); + __m512i ba = _mm512_permutexvar_epi64( + permuteidx, _mm512_packus_epi32(_mm512_srli_epi32(bytes1, 16), + _mm512_srli_epi32(bytes2, 16))); + __m512i r = _mm512_and_si512(rg, _mm512_set1_epi16(0xFF)); + __m512i g = _mm512_srli_epi16(rg, 8); + __m512i b = _mm512_and_si512(ba, _mm512_set1_epi16(0xFF)); + __m512i a = _mm512_srli_epi16(ba, 8); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + __m512i bytes0 = _mm512_loadu_si512((__m512i*)data); + __m512i bytes1 = _mm512_loadu_si512((__m512i*)(data + 64)); + __m512i bytes2 = _mm512_loadu_si512((__m512i*)(data + 128)); + __m512i bytes3 = _mm512_loadu_si512((__m512i*)(data + 192)); + + auto pack32 = [](__m512i a, __m512i b) { + __m512i permuteidx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(permuteidx, _mm512_packus_epi32(a, b)); + }; + auto packlow32 = [&pack32](__m512i a, __m512i b) { + __m512i mask = _mm512_set1_epi32(0xFFFF); + return pack32(_mm512_and_si512(a, mask), _mm512_and_si512(b, mask)); + }; + auto packhi32 = [&pack32](__m512i a, __m512i b) { + return pack32(_mm512_srli_epi32(a, 16), _mm512_srli_epi32(b, 16)); + }; + + __m512i rb0 = packlow32(bytes0, bytes1); + __m512i rb1 = packlow32(bytes2, bytes3); + __m512i ga0 = packhi32(bytes0, bytes1); + __m512i ga1 = packhi32(bytes2, bytes3); + + __m512i r = packlow32(rb0, rb1); + __m512i g = packlow32(ga0, ga1); + __m512i b = packhi32(rb0, rb1); + __m512i a = packhi32(ga0, ga1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + + void SwapEndian() { + auto indices = _mm512_broadcast_i32x4( + _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14)); + vec = _mm512_shuffle_epi8(vec, indices); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{_mm512_mask_blend_epi16(mask, if_false.vec, if_true.vec)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& 
if_false) { + return SIMDVec32{_mm512_mask_blend_epi32(mask, if_false.vec, if_true.vec)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 8; + + __m512i nbits; + __m512i bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + _mm512_storeu_si512((__m512i*)nbits_out, nbits); + _mm512_storeu_si512((__m512i*)bits_out, bits); + } +}; + +struct Bits32 { + __m512i nbits; + __m512i bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + auto nbits_hi32 = _mm512_srli_epi64(nbits, 32); + auto nbits_lo32 = _mm512_and_si512(nbits, _mm512_set1_epi64(0xFFFFFFFF)); + auto bits_hi32 = _mm512_srli_epi64(bits, 32); + auto bits_lo32 = _mm512_and_si512(bits, _mm512_set1_epi64(0xFFFFFFFF)); + + auto nbits64 = _mm512_add_epi64(nbits_hi32, nbits_lo32); + auto bits64 = + _mm512_or_si512(_mm512_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = _mm512_or_si512(_mm512_sllv_epi32(bits, low.nbits), low.bits); + nbits = _mm512_add_epi32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint32_t kMask[32] = { + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint32_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 16 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } +}; + +struct Bits16 { + __m512i nbits; + __m512i bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + auto nbits_hi16 = _mm512_srli_epi32(nbits, 16); + auto nbits_lo16 = _mm512_and_si512(nbits, _mm512_set1_epi32(0xFFFF)); + auto bits_hi16 = _mm512_srli_epi32(bits, 16); + auto bits_lo16 = _mm512_and_si512(bits, _mm512_set1_epi32(0xFFFF)); + + auto nbits32 = _mm512_add_epi32(nbits_hi16, nbits_lo16); + auto bits32 = + _mm512_or_si512(_mm512_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + bits = _mm512_or_si512(_mm512_sllv_epi16(bits, low.nbits), low.bits); + nbits = _mm512_add_epi16(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 32); + constexpr uint16_t kMask[64] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 32); + constexpr uint16_t kMask[64] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 
+ 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + __m512i mask = _mm512_loadu_si512((__m512i*)(kMask + 32 - n)); + nbits = _mm512_and_si512(mask, nbits); + bits = _mm512_and_si512(mask, bits); + } +}; + +#endif + +#ifdef FJXL_AVX2 +#define FJXL_GENERIC_SIMD + +struct SIMDVec32; + +struct Mask32 { + __m256i mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + size_t CountPrefix() const { + return CtzNonZero(~static_cast<uint64_t>( + (uint8_t)_mm256_movemask_ps(_mm256_castsi256_ps(mask)))); + } +}; + +struct SIMDVec32 { + __m256i vec; + + static constexpr size_t kLanes = 8; + + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{_mm256_loadu_si256((__m256i*)data)}; + } + FJXL_INLINE void Store(uint32_t* data) { + _mm256_storeu_si256((__m256i*)data, vec); + } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{_mm256_set1_epi32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + // we know that each value has at most 20 bits, so we just need 5 nibbles + // and don't need to mask the fifth. However we do need to set the higher + // bytes to 0xFF, which will make table lookups return 0. + auto nibble0 = + _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble1 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 4), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble2 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 8), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble3 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi32(vec, 12), _mm256_set1_epi32(0xF)), + _mm256_set1_epi32(0xFFFFFF00)); + auto nibble4 = _mm256_or_si256(_mm256_srli_epi32(vec, 16), + _mm256_set1_epi32(0xFFFFFF00)); + + auto lut0 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4)); + auto lut1 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8)); + auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12)); + auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16)); + auto lut4 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 17, 18, 18, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20)); + + auto token0 = _mm256_shuffle_epi8(lut0, nibble0); + auto token1 = _mm256_shuffle_epi8(lut1, nibble1); + auto token2 = _mm256_shuffle_epi8(lut2, nibble2); + auto token3 = _mm256_shuffle_epi8(lut3, nibble3); + auto token4 = _mm256_shuffle_epi8(lut4, nibble4); + + auto token = + _mm256_max_epi32(_mm256_max_epi32(_mm256_max_epi32(token0, token1), + _mm256_max_epi32(token2, token3)), + token4); + return SIMDVec32{token}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm256_sub_epi32(_mm256_max_epu32(vec, to_subtract.vec), + to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{_mm256_sub_epi32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{_mm256_add_epi32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{_mm256_xor_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + 
return SIMDVec32{_mm256_sllv_epi32(_mm256_set1_epi32(1), vec)}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{_mm256_cmpeq_epi32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{_mm256_cmpgt_epi32(vec, oth.vec)}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{_mm256_srai_epi32(vec, i)}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + __m256i mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{_mm256_and_si256(mask, oth.mask)}; + } + size_t CountPrefix() const { + return CtzNonZero( + ~static_cast<uint64_t>((uint32_t)_mm256_movemask_epi8(mask))) / + 2; + } +}; + +struct SIMDVec16 { + __m256i vec; + + static constexpr size_t kLanes = 16; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{_mm256_loadu_si256((__m256i*)data)}; + } + FJXL_INLINE void Store(uint16_t* data) { + _mm256_storeu_si256((__m256i*)data, vec); + } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{_mm256_set1_epi16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + auto tmp = _mm256_packus_epi32(lo.vec, hi.vec); + return SIMDVec16{_mm256_permute4x64_epi64(tmp, 0b11011000)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + auto nibble0 = + _mm256_or_si256(_mm256_and_si256(vec, _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble1 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi16(vec, 4), _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble2 = _mm256_or_si256( + _mm256_and_si256(_mm256_srli_epi16(vec, 8), _mm256_set1_epi16(0xF)), + _mm256_set1_epi16(0xFF00)); + auto nibble3 = + _mm256_or_si256(_mm256_srli_epi16(vec, 12), _mm256_set1_epi16(0xFF00)); + + auto lut0 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4)); + auto lut1 = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 5, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8)); + auto lut2 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12)); + auto lut3 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 13, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16)); + + auto token0 = _mm256_shuffle_epi8(lut0, nibble0); + auto token1 = _mm256_shuffle_epi8(lut1, nibble1); + auto token2 = _mm256_shuffle_epi8(lut2, nibble2); + auto token3 = _mm256_shuffle_epi8(lut3, nibble3); + + auto token = _mm256_max_epi16(_mm256_max_epi16(token0, token1), + _mm256_max_epi16(token2, token3)); + return SIMDVec16{token}; + } + + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm256_subs_epu16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{_mm256_sub_epi16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_add_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_min_epu16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{_mm256_cmpeq_epi16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{_mm256_cmpgt_epi16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + auto pow2_lo_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 
1 << 5, 1 << 6, + 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0)); + auto pow2_hi_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 1 << 0, 1 << 1, 1 << 2, 1 << 3, + 1 << 4, 1 << 5, 1 << 6, 1u << 7)); + + auto masked = _mm256_or_si256(vec, _mm256_set1_epi16(0xFF00)); + + auto pow2_lo = _mm256_shuffle_epi8(pow2_lo_lut, masked); + auto pow2_hi = _mm256_shuffle_epi8(pow2_hi_lut, masked); + + auto pow2 = _mm256_or_si256(_mm256_slli_epi16(pow2_hi, 8), pow2_lo); + return SIMDVec16{pow2}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_or_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_xor_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_and_si256(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{_mm256_srai_epi16(_mm256_add_epi16(vec, oth.vec), 1)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{_mm256_or_si256(vec, _mm256_set1_epi16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + return SIMDVec16{_mm256_shuffle_epi8( + _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)table)), vec)}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + auto v02 = _mm256_unpacklo_epi16(low.vec, vec); + auto v13 = _mm256_unpackhi_epi16(low.vec, vec); + return {SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x20)}, + SIMDVec16{_mm256_permute2x128_si256(v02, v13, 0x31)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + auto v02 = _mm256_unpacklo_epi16(vec, _mm256_setzero_si256()); + auto v13 = _mm256_unpackhi_epi16(vec, _mm256_setzero_si256()); + return {SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x20)}, + SIMDVec32{_mm256_permute2x128_si256(v02, v13, 0x31)}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{_mm256_srai_epi16(vec, i)}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + __m128i bytes = _mm_loadu_si128((__m128i*)data); + return {SIMDVec16{_mm256_cvtepu8_epi16(bytes)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + __m256i gray = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF)); + __m256i alpha = _mm256_srli_epi16(bytes, 8); + return {SIMDVec16{gray}, SIMDVec16{alpha}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + __m256i bytes1 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i g_mask = _mm256_set1_epi32(0xFFFF); + __m256i g = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_and_si256(bytes1, g_mask), + _mm256_and_si256(bytes2, g_mask)), + 0b11011000); + __m256i a = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16), + _mm256_srli_epi32(bytes2, 16)), + 0b11011000); + return {SIMDVec16{g}, SIMDVec16{a}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + __m128i bytes0 = _mm_loadu_si128((__m128i*)data); + __m128i bytes1 = _mm_loadu_si128((__m128i*)(data + 16)); + __m128i bytes2 = _mm_loadu_si128((__m128i*)(data + 32)); + + __m128i idx = + _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13); + + __m128i r6b5g5_0 = _mm_shuffle_epi8(bytes0, idx); + 
__m128i g6r5b5_1 = _mm_shuffle_epi8(bytes1, idx); + __m128i b6g5r5_2 = _mm_shuffle_epi8(bytes2, idx); + + __m128i mask010 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0, 0, 0, 0, 0); + __m128i mask001 = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF); + + __m128i b2g2b1 = _mm_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001); + __m128i b2b0b1 = _mm_blendv_epi8(b2g2b1, r6b5g5_0, mask010); + + __m128i r0r1b1 = _mm_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010); + __m128i r0r1r2 = _mm_blendv_epi8(r0r1b1, b6g5r5_2, mask001); + + __m128i g1r1g0 = _mm_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001); + __m128i g1g2g0 = _mm_blendv_epi8(g1r1g0, b6g5r5_2, mask010); + + __m128i g0g1g2 = _mm_alignr_epi8(g1g2g0, g1g2g0, 11); + __m128i b0b1b2 = _mm_alignr_epi8(b2b0b1, b2b0b1, 6); + + return {SIMDVec16{_mm256_cvtepu8_epi16(r0r1r2)}, + SIMDVec16{_mm256_cvtepu8_epi16(g0g1g2)}, + SIMDVec16{_mm256_cvtepu8_epi16(b0b1b2)}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + auto load_and_split_lohi = [](const unsigned char* data) { + // LHLHLH... + __m256i bytes = _mm256_loadu_si256((__m256i*)data); + // L0L0L0... + __m256i lo = _mm256_and_si256(bytes, _mm256_set1_epi16(0xFF)); + // H0H0H0... + __m256i hi = _mm256_srli_epi16(bytes, 8); + // LLLLLLLLHHHHHHHHLLLLLLLLHHHHHHHH + __m256i packed = _mm256_packus_epi16(lo, hi); + return _mm256_permute4x64_epi64(packed, 0b11011000); + }; + __m256i bytes0 = load_and_split_lohi(data); + __m256i bytes1 = load_and_split_lohi(data + 32); + __m256i bytes2 = load_and_split_lohi(data + 64); + + __m256i idx = _mm256_broadcastsi128_si256( + _mm_setr_epi8(0, 3, 6, 9, 12, 15, 2, 5, 8, 11, 14, 1, 4, 7, 10, 13)); + + __m256i r6b5g5_0 = _mm256_shuffle_epi8(bytes0, idx); + __m256i g6r5b5_1 = _mm256_shuffle_epi8(bytes1, idx); + __m256i b6g5r5_2 = _mm256_shuffle_epi8(bytes2, idx); + + __m256i mask010 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0, 0)); + __m256i mask001 = _mm256_broadcastsi128_si256(_mm_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF)); + + __m256i b2g2b1 = _mm256_blendv_epi8(b6g5r5_2, g6r5b5_1, mask001); + __m256i b2b0b1 = _mm256_blendv_epi8(b2g2b1, r6b5g5_0, mask010); + + __m256i r0r1b1 = _mm256_blendv_epi8(r6b5g5_0, g6r5b5_1, mask010); + __m256i r0r1r2 = _mm256_blendv_epi8(r0r1b1, b6g5r5_2, mask001); + + __m256i g1r1g0 = _mm256_blendv_epi8(g6r5b5_1, r6b5g5_0, mask001); + __m256i g1g2g0 = _mm256_blendv_epi8(g1r1g0, b6g5r5_2, mask010); + + __m256i g0g1g2 = _mm256_alignr_epi8(g1g2g0, g1g2g0, 11); + __m256i b0b1b2 = _mm256_alignr_epi8(b2b0b1, b2b0b1, 6); + + // Now r0r1r2, g0g1g2, b0b1b2 have the low bytes of the RGB pixels in their + // lower half, and the high bytes in their upper half. 
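The LoadRGB8/LoadRGB16 shuffles above deinterleave 16 packed RGB pixels into separate 16-bit R, G and B vectors. A scalar model of LoadRGB8 (a sketch, not from the file):

  void LoadRGB8Scalar(const unsigned char* data,
                      uint16_t r[16], uint16_t g[16], uint16_t b[16]) {
    for (int i = 0; i < 16; i++) {
      r[i] = data[3 * i + 0];  // 48 input bytes -> 16 pixels per vector
      g[i] = data[3 * i + 1];
      b[i] = data[3 * i + 2];
    }
  }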
+ + auto combine_low_hi = [](__m256i v) { + __m128i low = _mm256_extracti128_si256(v, 0); + __m128i hi = _mm256_extracti128_si256(v, 1); + __m256i low16 = _mm256_cvtepu8_epi16(low); + __m256i hi16 = _mm256_cvtepu8_epi16(hi); + return _mm256_or_si256(_mm256_slli_epi16(hi16, 8), low16); + }; + + return {SIMDVec16{combine_low_hi(r0r1r2)}, + SIMDVec16{combine_low_hi(g0g1g2)}, + SIMDVec16{combine_low_hi(b0b1b2)}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + __m256i bytes1 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i rg_mask = _mm256_set1_epi32(0xFFFF); + __m256i rg = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_and_si256(bytes1, rg_mask), + _mm256_and_si256(bytes2, rg_mask)), + 0b11011000); + __m256i ba = _mm256_permute4x64_epi64( + _mm256_packus_epi32(_mm256_srli_epi32(bytes1, 16), + _mm256_srli_epi32(bytes2, 16)), + 0b11011000); + __m256i r = _mm256_and_si256(rg, _mm256_set1_epi16(0xFF)); + __m256i g = _mm256_srli_epi16(rg, 8); + __m256i b = _mm256_and_si256(ba, _mm256_set1_epi16(0xFF)); + __m256i a = _mm256_srli_epi16(ba, 8); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + __m256i bytes0 = _mm256_loadu_si256((__m256i*)data); + __m256i bytes1 = _mm256_loadu_si256((__m256i*)(data + 32)); + __m256i bytes2 = _mm256_loadu_si256((__m256i*)(data + 64)); + __m256i bytes3 = _mm256_loadu_si256((__m256i*)(data + 96)); + + auto pack32 = [](__m256i a, __m256i b) { + return _mm256_permute4x64_epi64(_mm256_packus_epi32(a, b), 0b11011000); + }; + auto packlow32 = [&pack32](__m256i a, __m256i b) { + __m256i mask = _mm256_set1_epi32(0xFFFF); + return pack32(_mm256_and_si256(a, mask), _mm256_and_si256(b, mask)); + }; + auto packhi32 = [&pack32](__m256i a, __m256i b) { + return pack32(_mm256_srli_epi32(a, 16), _mm256_srli_epi32(b, 16)); + }; + + __m256i rb0 = packlow32(bytes0, bytes1); + __m256i rb1 = packlow32(bytes2, bytes3); + __m256i ga0 = packhi32(bytes0, bytes1); + __m256i ga1 = packhi32(bytes2, bytes3); + + __m256i r = packlow32(rb0, rb1); + __m256i g = packlow32(ga0, ga1); + __m256i b = packhi32(rb0, rb1); + __m256i a = packhi32(ga0, ga1); + return {SIMDVec16{r}, SIMDVec16{g}, SIMDVec16{b}, SIMDVec16{a}}; + } + + void SwapEndian() { + auto indices = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14)); + vec = _mm256_shuffle_epi8(vec, indices); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& if_false) { + return SIMDVec32{_mm256_blendv_epi8(if_false.vec, if_true.vec, mask)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 4; + + __m256i nbits; + __m256i bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + _mm256_storeu_si256((__m256i*)nbits_out, nbits); + _mm256_storeu_si256((__m256i*)bits_out, bits); + } +}; + +struct Bits32 { + __m256i nbits; + __m256i bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + auto nbits_hi32 = _mm256_srli_epi64(nbits, 32); + auto nbits_lo32 = _mm256_and_si256(nbits, _mm256_set1_epi64x(0xFFFFFFFF)); + auto bits_hi32 = _mm256_srli_epi64(bits, 32); + auto bits_lo32 = _mm256_and_si256(bits, 
_mm256_set1_epi64x(0xFFFFFFFF)); + + auto nbits64 = _mm256_add_epi64(nbits_hi32, nbits_lo32); + auto bits64 = + _mm256_or_si256(_mm256_sllv_epi64(bits_hi32, nbits_lo32), bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = _mm256_or_si256(_mm256_sllv_epi32(bits, low.nbits), low.bits); + nbits = _mm256_add_epi32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint32_t kMask[16] = { + ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint32_t kMask[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, ~0u, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 8 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } +}; + +struct Bits16 { + __m256i nbits; + __m256i bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + auto nbits_hi16 = _mm256_srli_epi32(nbits, 16); + auto nbits_lo16 = _mm256_and_si256(nbits, _mm256_set1_epi32(0xFFFF)); + auto bits_hi16 = _mm256_srli_epi32(bits, 16); + auto bits_lo16 = _mm256_and_si256(bits, _mm256_set1_epi32(0xFFFF)); + + auto nbits32 = _mm256_add_epi32(nbits_hi16, nbits_lo16); + auto bits32 = + _mm256_or_si256(_mm256_sllv_epi32(bits_hi16, nbits_lo16), bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + auto pow2_lo_lut = _mm256_broadcastsi128_si256( + _mm_setr_epi8(1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, + 1u << 7, 0, 0, 0, 0, 0, 0, 0, 0)); + auto low_nbits_masked = + _mm256_or_si256(low.nbits, _mm256_set1_epi16(0xFF00)); + + auto bits_shifted = _mm256_mullo_epi16( + bits, _mm256_shuffle_epi8(pow2_lo_lut, low_nbits_masked)); + + nbits = _mm256_add_epi16(nbits, low.nbits); + bits = _mm256_or_si256(bits_shifted, low.bits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint16_t kMask[32] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } + + void Skip(size_t n) { + n = std::min<size_t>(n, 16); + constexpr uint16_t kMask[32] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + __m256i mask = _mm256_loadu_si256((__m256i*)(kMask + 16 - n)); + nbits = _mm256_and_si256(mask, nbits); + bits = _mm256_and_si256(mask, bits); + } +}; + +#endif + +#ifdef FJXL_NEON +#define FJXL_GENERIC_SIMD + +struct SIMDVec32; + +struct Mask32 { + uint32x4_t mask; + SIMDVec32 IfThenElse(const SIMDVec32& if_true, const SIMDVec32& if_false); + Mask32 And(const Mask32& oth) const { + return Mask32{vandq_u32(mask, oth.mask)}; + } + size_t CountPrefix() const { + uint32_t val_unset[4] = {0, 1, 2, 3}; + uint32_t val_set[4] = {4, 4, 4, 4}; + uint32x4_t val = vbslq_u32(mask, vld1q_u32(val_set), vld1q_u32(val_unset)); + return vminvq_u32(val); + } +}; + +struct SIMDVec32 { + uint32x4_t vec; + + static constexpr size_t kLanes = 4; 
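+  // Illustration of the CountPrefix idiom (Mask32 above, Mask16 below): every
+  // lane is replaced by its index where the mask is clear and by kLanes where
+  // it is set, so the horizontal minimum is the length of the leading run of
+  // set lanes. E.g. mask lanes {set, set, clear, set} become {4, 4, 2, 4} and
+  // CountPrefix() returns 2. PredictPixels() uses this to count leading zero
+  // residuals for the RLE path.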
+ + FJXL_INLINE static SIMDVec32 Load(const uint32_t* data) { + return SIMDVec32{vld1q_u32(data)}; + } + FJXL_INLINE void Store(uint32_t* data) { vst1q_u32(data, vec); } + FJXL_INLINE static SIMDVec32 Val(uint32_t v) { + return SIMDVec32{vdupq_n_u32(v)}; + } + FJXL_INLINE SIMDVec32 ValToToken() const { + return SIMDVec32{vsubq_u32(vdupq_n_u32(32), vclzq_u32(vec))}; + } + FJXL_INLINE SIMDVec32 SatSubU(const SIMDVec32& to_subtract) const { + return SIMDVec32{vqsubq_u32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Sub(const SIMDVec32& to_subtract) const { + return SIMDVec32{vsubq_u32(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec32 Add(const SIMDVec32& oth) const { + return SIMDVec32{vaddq_u32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Xor(const SIMDVec32& oth) const { + return SIMDVec32{veorq_u32(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec32 Pow2() const { + return SIMDVec32{vshlq_u32(vdupq_n_u32(1), vreinterpretq_s32_u32(vec))}; + } + FJXL_INLINE Mask32 Eq(const SIMDVec32& oth) const { + return Mask32{vceqq_u32(vec, oth.vec)}; + } + FJXL_INLINE Mask32 Gt(const SIMDVec32& oth) const { + return Mask32{ + vcgtq_s32(vreinterpretq_s32_u32(vec), vreinterpretq_s32_u32(oth.vec))}; + } + template <size_t i> + FJXL_INLINE SIMDVec32 SignedShiftRight() const { + return SIMDVec32{ + vreinterpretq_u32_s32(vshrq_n_s32(vreinterpretq_s32_u32(vec), i))}; + } +}; + +struct SIMDVec16; + +struct Mask16 { + uint16x8_t mask; + SIMDVec16 IfThenElse(const SIMDVec16& if_true, const SIMDVec16& if_false); + Mask16 And(const Mask16& oth) const { + return Mask16{vandq_u16(mask, oth.mask)}; + } + size_t CountPrefix() const { + uint16_t val_unset[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint16_t val_set[8] = {8, 8, 8, 8, 8, 8, 8, 8}; + uint16x8_t val = vbslq_u16(mask, vld1q_u16(val_set), vld1q_u16(val_unset)); + return vminvq_u16(val); + } +}; + +struct SIMDVec16 { + uint16x8_t vec; + + static constexpr size_t kLanes = 8; + + FJXL_INLINE static SIMDVec16 Load(const uint16_t* data) { + return SIMDVec16{vld1q_u16(data)}; + } + FJXL_INLINE void Store(uint16_t* data) { vst1q_u16(data, vec); } + FJXL_INLINE static SIMDVec16 Val(uint16_t v) { + return SIMDVec16{vdupq_n_u16(v)}; + } + FJXL_INLINE static SIMDVec16 FromTwo32(const SIMDVec32& lo, + const SIMDVec32& hi) { + return SIMDVec16{vmovn_high_u32(vmovn_u32(lo.vec), hi.vec)}; + } + + FJXL_INLINE SIMDVec16 ValToToken() const { + return SIMDVec16{vsubq_u16(vdupq_n_u16(16), vclzq_u16(vec))}; + } + FJXL_INLINE SIMDVec16 SatSubU(const SIMDVec16& to_subtract) const { + return SIMDVec16{vqsubq_u16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Sub(const SIMDVec16& to_subtract) const { + return SIMDVec16{vsubq_u16(vec, to_subtract.vec)}; + } + FJXL_INLINE SIMDVec16 Add(const SIMDVec16& oth) const { + return SIMDVec16{vaddq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Min(const SIMDVec16& oth) const { + return SIMDVec16{vminq_u16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Eq(const SIMDVec16& oth) const { + return Mask16{vceqq_u16(vec, oth.vec)}; + } + FJXL_INLINE Mask16 Gt(const SIMDVec16& oth) const { + return Mask16{ + vcgtq_s16(vreinterpretq_s16_u16(vec), vreinterpretq_s16_u16(oth.vec))}; + } + FJXL_INLINE SIMDVec16 Pow2() const { + return SIMDVec16{vshlq_u16(vdupq_n_u16(1), vreinterpretq_s16_u16(vec))}; + } + FJXL_INLINE SIMDVec16 Or(const SIMDVec16& oth) const { + return SIMDVec16{vorrq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 Xor(const SIMDVec16& oth) const { + return SIMDVec16{veorq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 And(const SIMDVec16& oth) const { 
+ return SIMDVec16{vandq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 HAdd(const SIMDVec16& oth) const { + return SIMDVec16{vhaddq_u16(vec, oth.vec)}; + } + FJXL_INLINE SIMDVec16 PrepareForU8Lookup() const { + return SIMDVec16{vorrq_u16(vec, vdupq_n_u16(0xFF00))}; + } + FJXL_INLINE SIMDVec16 U8Lookup(const uint8_t* table) const { + uint8x16_t tbl = vld1q_u8(table); + uint8x16_t indices = vreinterpretq_u8_u16(vec); + return SIMDVec16{vreinterpretq_u16_u8(vqtbl1q_u8(tbl, indices))}; + } + FJXL_INLINE VecPair<SIMDVec16> Interleave(const SIMDVec16& low) const { + return {SIMDVec16{vzip1q_u16(low.vec, vec)}, + SIMDVec16{vzip2q_u16(low.vec, vec)}}; + } + FJXL_INLINE VecPair<SIMDVec32> Upcast() const { + uint32x4_t lo = vmovl_u16(vget_low_u16(vec)); + uint32x4_t hi = vmovl_high_u16(vec); + return {SIMDVec32{lo}, SIMDVec32{hi}}; + } + template <size_t i> + FJXL_INLINE SIMDVec16 SignedShiftRight() const { + return SIMDVec16{ + vreinterpretq_u16_s16(vshrq_n_s16(vreinterpretq_s16_u16(vec), i))}; + } + + static std::array<SIMDVec16, 1> LoadG8(const unsigned char* data) { + uint8x8_t v = vld1_u8(data); + return {SIMDVec16{vmovl_u8(v)}}; + } + static std::array<SIMDVec16, 1> LoadG16(const unsigned char* data) { + return {Load((const uint16_t*)data)}; + } + + static std::array<SIMDVec16, 2> LoadGA8(const unsigned char* data) { + uint8x8x2_t v = vld2_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}}; + } + static std::array<SIMDVec16, 2> LoadGA16(const unsigned char* data) { + uint16x8x2_t v = vld2q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}}; + } + + static std::array<SIMDVec16, 3> LoadRGB8(const unsigned char* data) { + uint8x8x3_t v = vld3_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}, + SIMDVec16{vmovl_u8(v.val[2])}}; + } + static std::array<SIMDVec16, 3> LoadRGB16(const unsigned char* data) { + uint16x8x3_t v = vld3q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}}; + } + + static std::array<SIMDVec16, 4> LoadRGBA8(const unsigned char* data) { + uint8x8x4_t v = vld4_u8(data); + return {SIMDVec16{vmovl_u8(v.val[0])}, SIMDVec16{vmovl_u8(v.val[1])}, + SIMDVec16{vmovl_u8(v.val[2])}, SIMDVec16{vmovl_u8(v.val[3])}}; + } + static std::array<SIMDVec16, 4> LoadRGBA16(const unsigned char* data) { + uint16x8x4_t v = vld4q_u16((const uint16_t*)data); + return {SIMDVec16{v.val[0]}, SIMDVec16{v.val[1]}, SIMDVec16{v.val[2]}, + SIMDVec16{v.val[3]}}; + } + + void SwapEndian() { + vec = vreinterpretq_u16_u8(vrev16q_u8(vreinterpretq_u8_u16(vec))); + } +}; + +SIMDVec16 Mask16::IfThenElse(const SIMDVec16& if_true, + const SIMDVec16& if_false) { + return SIMDVec16{vbslq_u16(mask, if_true.vec, if_false.vec)}; +} + +SIMDVec32 Mask32::IfThenElse(const SIMDVec32& if_true, + const SIMDVec32& if_false) { + return SIMDVec32{vbslq_u32(mask, if_true.vec, if_false.vec)}; +} + +struct Bits64 { + static constexpr size_t kLanes = 2; + + uint64x2_t nbits; + uint64x2_t bits; + + FJXL_INLINE void Store(uint64_t* nbits_out, uint64_t* bits_out) { + vst1q_u64(nbits_out, nbits); + vst1q_u64(bits_out, bits); + } +}; + +struct Bits32 { + uint32x4_t nbits; + uint32x4_t bits; + + static Bits32 FromRaw(SIMDVec32 nbits, SIMDVec32 bits) { + return Bits32{nbits.vec, bits.vec}; + } + + Bits64 Merge() const { + // TODO(veluca): can probably be optimized. 
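+    // Illustration of what Merge() computes: adjacent 32-bit (nbits, bits)
+    // pairs are concatenated into one 64-bit pair, low lane first (the
+    // bitstream is LSB-first). E.g. (nbits_lo, bits_lo) = (5, 0b10101) and
+    // (nbits_hi, bits_hi) = (3, 0b110) merge into (8, 0b11010101): the high
+    // lane's bits are shifted left by nbits_lo and OR-ed onto the low lane's
+    // bits.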
+ uint64x2_t nbits_lo32 = + vandq_u64(vreinterpretq_u64_u32(nbits), vdupq_n_u64(0xFFFFFFFF)); + uint64x2_t bits_hi32 = + vshlq_u64(vshrq_n_u64(vreinterpretq_u64_u32(bits), 32), + vreinterpretq_s64_u64(nbits_lo32)); + uint64x2_t bits_lo32 = + vandq_u64(vreinterpretq_u64_u32(bits), vdupq_n_u64(0xFFFFFFFF)); + uint64x2_t nbits64 = + vsraq_n_u64(nbits_lo32, vreinterpretq_u64_u32(nbits), 32); + uint64x2_t bits64 = vorrq_u64(bits_hi32, bits_lo32); + return Bits64{nbits64, bits64}; + } + + void Interleave(const Bits32& low) { + bits = + vorrq_u32(vshlq_u32(bits, vreinterpretq_s32_u32(low.nbits)), low.bits); + nbits = vaddq_u32(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 4); + constexpr uint32_t kMask[8] = { + ~0u, ~0u, ~0u, ~0u, 0, 0, 0, 0, + }; + uint32x4_t mask = vld1q_u32(kMask + 4 - n); + nbits = vandq_u32(mask, nbits); + bits = vandq_u32(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 4); + constexpr uint32_t kMask[8] = { + 0, 0, 0, 0, ~0u, ~0u, ~0u, ~0u, + }; + uint32x4_t mask = vld1q_u32(kMask + 4 - n); + nbits = vandq_u32(mask, nbits); + bits = vandq_u32(mask, bits); + } +}; + +struct Bits16 { + uint16x8_t nbits; + uint16x8_t bits; + + static Bits16 FromRaw(SIMDVec16 nbits, SIMDVec16 bits) { + return Bits16{nbits.vec, bits.vec}; + } + + Bits32 Merge() const { + // TODO(veluca): can probably be optimized. + uint32x4_t nbits_lo16 = + vandq_u32(vreinterpretq_u32_u16(nbits), vdupq_n_u32(0xFFFF)); + uint32x4_t bits_hi16 = + vshlq_u32(vshrq_n_u32(vreinterpretq_u32_u16(bits), 16), + vreinterpretq_s32_u32(nbits_lo16)); + uint32x4_t bits_lo16 = + vandq_u32(vreinterpretq_u32_u16(bits), vdupq_n_u32(0xFFFF)); + uint32x4_t nbits32 = + vsraq_n_u32(nbits_lo16, vreinterpretq_u32_u16(nbits), 16); + uint32x4_t bits32 = vorrq_u32(bits_hi16, bits_lo16); + return Bits32{nbits32, bits32}; + } + + void Interleave(const Bits16& low) { + bits = + vorrq_u16(vshlq_u16(bits, vreinterpretq_s16_u16(low.nbits)), low.bits); + nbits = vaddq_u16(nbits, low.nbits); + } + + void ClipTo(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint16_t kMask[16] = { + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + 0, 0, 0, 0, 0, 0, 0, 0, + }; + uint16x8_t mask = vld1q_u16(kMask + 8 - n); + nbits = vandq_u16(mask, nbits); + bits = vandq_u16(mask, bits); + } + void Skip(size_t n) { + n = std::min<size_t>(n, 8); + constexpr uint16_t kMask[16] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, + }; + uint16x8_t mask = vld1q_u16(kMask + 8 - n); + nbits = vandq_u16(mask, nbits); + bits = vandq_u16(mask, bits); + } +}; + +#endif + +#ifdef FJXL_GENERIC_SIMD +constexpr size_t SIMDVec32::kLanes; +constexpr size_t SIMDVec16::kLanes; + +// Â Each of these functions will process SIMDVec16::kLanes worth of values. 
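+// For reference, the hybrid-uint-000 tokenization implemented by the
+// TokenizeSIMD variants below (and, in scalar form, by EncodeHybridUint000
+// further down): for v > 0, token = floor(log2(v)) + 1, nbits = token - 1 and
+// bits = v - 2^nbits; v == 0 maps to token 0 with no extra bits. Worked
+// example: v = 13 = 0b1101 -> token 4, nbits 3, bits 0b101.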
+ +FJXL_INLINE void TokenizeSIMD(const uint16_t* residuals, uint16_t* token_out, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 res = SIMDVec16::Load(residuals); + SIMDVec16 token = res.ValToToken(); + SIMDVec16 nbits = token.SatSubU(SIMDVec16::Val(1)); + SIMDVec16 bits = res.SatSubU(nbits.Pow2()); + token.Store(token_out); + nbits.Store(nbits_out); + bits.Store(bits_out); +} + +FJXL_INLINE void TokenizeSIMD(const uint32_t* residuals, uint16_t* token_out, + uint32_t* nbits_out, uint32_t* bits_out) { + static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, ""); + SIMDVec32 res_lo = SIMDVec32::Load(residuals); + SIMDVec32 res_hi = SIMDVec32::Load(residuals + SIMDVec32::kLanes); + SIMDVec32 token_lo = res_lo.ValToToken(); + SIMDVec32 token_hi = res_hi.ValToToken(); + SIMDVec32 nbits_lo = token_lo.SatSubU(SIMDVec32::Val(1)); + SIMDVec32 nbits_hi = token_hi.SatSubU(SIMDVec32::Val(1)); + SIMDVec32 bits_lo = res_lo.SatSubU(nbits_lo.Pow2()); + SIMDVec32 bits_hi = res_hi.SatSubU(nbits_hi.Pow2()); + SIMDVec16 token = SIMDVec16::FromTwo32(token_lo, token_hi); + token.Store(token_out); + nbits_lo.Store(nbits_out); + nbits_hi.Store(nbits_out + SIMDVec32::kLanes); + bits_lo.Store(bits_out); + bits_hi.Store(bits_out + SIMDVec32::kLanes); +} + +FJXL_INLINE void HuffmanSIMDUpTo13(const uint16_t* tokens, + const PrefixCode& code, uint16_t* nbits_out, + uint16_t* bits_out) { + SIMDVec16 tok = SIMDVec16::Load(tokens).PrepareForU8Lookup(); + tok.U8Lookup(code.raw_nbits_simd).Store(nbits_out); + tok.U8Lookup(code.raw_bits_simd).Store(bits_out); +} + +FJXL_INLINE void HuffmanSIMD14(const uint16_t* tokens, const PrefixCode& code, + uint16_t* nbits_out, uint16_t* bits_out) { + SIMDVec16 token_cap = SIMDVec16::Val(15); + SIMDVec16 tok = SIMDVec16::Load(tokens); + SIMDVec16 tok_index = tok.Min(token_cap).PrepareForU8Lookup(); + SIMDVec16 huff_bits_pre = tok_index.U8Lookup(code.raw_bits_simd); + // Set the highest bit when token == 16; the Huffman code is constructed in + // such a way that the code for token 15 is the same as the code for 16, + // except for the highest bit. + Mask16 needs_high_bit = tok.Eq(SIMDVec16::Val(16)); + SIMDVec16 huff_bits = needs_high_bit.IfThenElse( + huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre); + huff_bits.Store(bits_out); + tok_index.U8Lookup(code.raw_nbits_simd).Store(nbits_out); +} + +FJXL_INLINE void HuffmanSIMDAbove14(const uint16_t* tokens, + const PrefixCode& code, uint16_t* nbits_out, + uint16_t* bits_out) { + SIMDVec16 tok = SIMDVec16::Load(tokens); + // We assume `tok` fits in a *signed* 16-bit integer. + Mask16 above = tok.Gt(SIMDVec16::Val(12)); + // 13, 14 -> 13 + // 15, 16 -> 14 + // 17, 18 -> 15 + SIMDVec16 remap_tok = above.IfThenElse(tok.HAdd(SIMDVec16::Val(13)), tok); + SIMDVec16 tok_index = remap_tok.PrepareForU8Lookup(); + SIMDVec16 huff_bits_pre = tok_index.U8Lookup(code.raw_bits_simd); + // Set the highest bit when token == 14, 16, 18. 
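+  // Worked example: token 16 is > 12, so HAdd remaps it to (16 + 13) / 2 = 14
+  // and the table lookup yields the code of symbol 15 (stored at slot 14 by
+  // MoreThan14Bits::PrepareForSimd); since 16 is the even token of its pair,
+  // the 0x80 bit is OR-ed in below, giving the code for 16, which by
+  // construction differs from 15's code only in that bit.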
+ Mask16 needs_high_bit = above.And(tok.Eq(tok.And(SIMDVec16::Val(0xFFFE)))); + SIMDVec16 huff_bits = needs_high_bit.IfThenElse( + huff_bits_pre.Or(SIMDVec16::Val(128)), huff_bits_pre); + huff_bits.Store(bits_out); + tok_index.U8Lookup(code.raw_nbits_simd).Store(nbits_out); +} + +FJXL_INLINE void StoreSIMDUpTo8(const uint16_t* nbits_tok, + const uint16_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + Bits16 bits = + Bits16::FromRaw(SIMDVec16::Load(nbits_tok), SIMDVec16::Load(bits_tok)); + Bits16 huff_bits = + Bits16::FromRaw(SIMDVec16::Load(nbits_huff), SIMDVec16::Load(bits_huff)); + bits.Interleave(huff_bits); + bits.ClipTo(n); + bits.Skip(skip); + bits_out[0] = bits.Merge(); +} + +// Huffman and raw bits don't necessarily fit in a single u16 here. +FJXL_INLINE void StoreSIMDUpTo14(const uint16_t* nbits_tok, + const uint16_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + VecPair<SIMDVec16> bits = + SIMDVec16::Load(bits_tok).Interleave(SIMDVec16::Load(bits_huff)); + VecPair<SIMDVec16> nbits = + SIMDVec16::Load(nbits_tok).Interleave(SIMDVec16::Load(nbits_huff)); + Bits16 low = Bits16::FromRaw(nbits.low, bits.low); + Bits16 hi = Bits16::FromRaw(nbits.hi, bits.hi); + low.ClipTo(2 * n); + low.Skip(2 * skip); + hi.ClipTo(std::max(2 * n, SIMDVec16::kLanes) - SIMDVec16::kLanes); + hi.Skip(std::max(2 * skip, SIMDVec16::kLanes) - SIMDVec16::kLanes); + + bits_out[0] = low.Merge(); + bits_out[1] = hi.Merge(); +} + +FJXL_INLINE void StoreSIMDAbove14(const uint32_t* nbits_tok, + const uint32_t* bits_tok, + const uint16_t* nbits_huff, + const uint16_t* bits_huff, size_t n, + size_t skip, Bits32* bits_out) { + static_assert(SIMDVec16::kLanes == 2 * SIMDVec32::kLanes, ""); + Bits32 bits_low = + Bits32::FromRaw(SIMDVec32::Load(nbits_tok), SIMDVec32::Load(bits_tok)); + Bits32 bits_hi = + Bits32::FromRaw(SIMDVec32::Load(nbits_tok + SIMDVec32::kLanes), + SIMDVec32::Load(bits_tok + SIMDVec32::kLanes)); + + VecPair<SIMDVec32> huff_bits = SIMDVec16::Load(bits_huff).Upcast(); + VecPair<SIMDVec32> huff_nbits = SIMDVec16::Load(nbits_huff).Upcast(); + + Bits32 huff_low = Bits32::FromRaw(huff_nbits.low, huff_bits.low); + Bits32 huff_hi = Bits32::FromRaw(huff_nbits.hi, huff_bits.hi); + + bits_low.Interleave(huff_low); + bits_low.ClipTo(n); + bits_low.Skip(skip); + bits_out[0] = bits_low; + bits_hi.Interleave(huff_hi); + bits_hi.ClipTo(std::max(n, SIMDVec32::kLanes) - SIMDVec32::kLanes); + bits_hi.Skip(std::max(skip, SIMDVec32::kLanes) - SIMDVec32::kLanes); + bits_out[1] = bits_hi; +} + +#ifdef FJXL_AVX512 +FJXL_INLINE void StoreToWriterAVX512(const Bits32& bits32, BitWriter& output) { + __m512i bits = bits32.bits; + __m512i nbits = bits32.nbits; + + // Insert the leftover bits from the bit buffer at the bottom of the vector + // and extract the top of the vector. + uint64_t trail_bits = + _mm512_cvtsi512_si32(_mm512_alignr_epi32(bits, bits, 15)); + uint64_t trail_nbits = + _mm512_cvtsi512_si32(_mm512_alignr_epi32(nbits, nbits, 15)); + __m512i lead_bits = _mm512_set1_epi32(output.buffer); + __m512i lead_nbits = _mm512_set1_epi32(output.bits_in_buffer); + bits = _mm512_alignr_epi32(bits, lead_bits, 15); + nbits = _mm512_alignr_epi32(nbits, lead_nbits, 15); + + // Merge 32 -> 64 bits. 
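+  // After b.Merge() below the vector holds eight 64-bit (nbits, bits) lanes.
+  // `end` is then an inclusive prefix sum of nbits across lanes, computed with
+  // the sh1/sh2/sh4 shift-and-add steps. E.g. per-lane nbits
+  // {5, 3, 0, 7, 0, 0, 0, 0} gives end {5, 8, 8, 15, 15, 15, 15, 15},
+  // begin = end - nbits, and simd_nbits (the total bit count) is the last lane
+  // of `end`.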
+ Bits32 b{nbits, bits}; + Bits64 b64 = b.Merge(); + bits = b64.bits; + nbits = b64.nbits; + + __m512i zero = _mm512_setzero_si512(); + + auto sh1 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 7); }; + auto sh2 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 6); }; + auto sh4 = [zero](__m512i vec) { return _mm512_alignr_epi64(vec, zero, 4); }; + + // Compute first-past-end-bit-position. + __m512i end_interm0 = _mm512_add_epi64(nbits, sh1(nbits)); + __m512i end_interm1 = _mm512_add_epi64(end_interm0, sh2(end_interm0)); + __m512i end = _mm512_add_epi64(end_interm1, sh4(end_interm1)); + + uint64_t simd_nbits = _mm512_cvtsi512_si32(_mm512_alignr_epi64(end, end, 7)); + + // Compute begin-bit-position. + __m512i begin = _mm512_sub_epi64(end, nbits); + + // Index of the last bit in the chunk, or the end bit if nbits==0. + __m512i last = _mm512_mask_sub_epi64( + end, _mm512_cmpneq_epi64_mask(nbits, zero), end, _mm512_set1_epi64(1)); + + __m512i lane_offset_mask = _mm512_set1_epi64(63); + + // Starting position of the chunk that each lane will ultimately belong to. + __m512i chunk_start = _mm512_andnot_si512(lane_offset_mask, last); + + // For all lanes that contain bits belonging to two different 64-bit chunks, + // compute the number of bits that belong to the first chunk. + // total # of bits fit in a u16, so we can satsub_u16 here. + __m512i first_chunk_nbits = _mm512_subs_epu16(chunk_start, begin); + + // Move all the previous-chunk-bits to the previous lane. + __m512i negnbits = _mm512_sub_epi64(_mm512_set1_epi64(64), first_chunk_nbits); + __m512i first_chunk_bits = + _mm512_srlv_epi64(_mm512_sllv_epi64(bits, negnbits), negnbits); + __m512i first_chunk_bits_down = + _mm512_alignr_epi32(zero, first_chunk_bits, 2); + bits = _mm512_srlv_epi64(bits, first_chunk_nbits); + nbits = _mm512_sub_epi64(nbits, first_chunk_nbits); + bits = _mm512_or_si512(bits, _mm512_sllv_epi64(first_chunk_bits_down, nbits)); + begin = _mm512_add_epi64(begin, first_chunk_nbits); + + // We now know that every lane should give bits to only one chunk. We can + // shift the bits and then horizontally-or-reduce them within the same chunk. + __m512i offset = _mm512_and_si512(begin, lane_offset_mask); + __m512i aligned_bits = _mm512_sllv_epi64(bits, offset); + // h-or-reduce within same chunk + __m512i red0 = _mm512_mask_or_epi64( + aligned_bits, _mm512_cmpeq_epi64_mask(sh1(chunk_start), chunk_start), + sh1(aligned_bits), aligned_bits); + __m512i red1 = _mm512_mask_or_epi64( + red0, _mm512_cmpeq_epi64_mask(sh2(chunk_start), chunk_start), sh2(red0), + red0); + __m512i reduced = _mm512_mask_or_epi64( + red1, _mm512_cmpeq_epi64_mask(sh4(chunk_start), chunk_start), sh4(red1), + red1); + // Extract the highest lane that belongs to each chunk (the lane that ends up + // with the OR-ed value of all the other lanes of that chunk). + __m512i next_chunk_start = + _mm512_alignr_epi32(_mm512_set1_epi64(~0), chunk_start, 2); + __m512i result = _mm512_maskz_compress_epi64( + _mm512_cmpneq_epi64_mask(chunk_start, next_chunk_start), reduced); + + _mm512_storeu_si512((__m512i*)(output.data.get() + output.bytes_written), + result); + + // Update the bit writer and add the last 32-bit lane. + // Note that since trail_nbits was at most 32 to begin with, operating on + // trail_bits does not risk overflowing. + output.bytes_written += simd_nbits / 8; + // Here we are implicitly relying on the fact that simd_nbits < 512 to know + // that the byte of bitreader data we access is initialized. 
This is + // guaranteed because the remaining bits in the bitreader buffer are at most + // 7, so simd_nbits <= 505 always. + trail_bits = (trail_bits << (simd_nbits % 8)) + + output.data.get()[output.bytes_written]; + trail_nbits += simd_nbits % 8; + StoreLE64(output.data.get() + output.bytes_written, trail_bits); + size_t trail_bytes = trail_nbits / 8; + output.bits_in_buffer = trail_nbits % 8; + output.buffer = trail_bits >> (trail_bytes * 8); + output.bytes_written += trail_bytes; +} + +#endif + +template <size_t n> +FJXL_INLINE void StoreToWriter(const Bits32* bits, BitWriter& output) { +#ifdef FJXL_AVX512 + static_assert(n <= 2, ""); + StoreToWriterAVX512(bits[0], output); + if (n == 2) { + StoreToWriterAVX512(bits[1], output); + } + return; +#endif + static_assert(n <= 4, ""); + alignas(64) uint64_t nbits64[Bits64::kLanes * n]; + alignas(64) uint64_t bits64[Bits64::kLanes * n]; + bits[0].Merge().Store(nbits64, bits64); + if (n > 1) { + bits[1].Merge().Store(nbits64 + Bits64::kLanes, bits64 + Bits64::kLanes); + } + if (n > 2) { + bits[2].Merge().Store(nbits64 + 2 * Bits64::kLanes, + bits64 + 2 * Bits64::kLanes); + } + if (n > 3) { + bits[3].Merge().Store(nbits64 + 3 * Bits64::kLanes, + bits64 + 3 * Bits64::kLanes); + } + output.WriteMultiple(nbits64, bits64, Bits64::kLanes * n); +} + +namespace detail { +template <typename T> +struct IntegerTypes; + +template <> +struct IntegerTypes<SIMDVec16> { + using signed_ = int16_t; + using unsigned_ = uint16_t; +}; + +template <> +struct IntegerTypes<SIMDVec32> { + using signed_ = int32_t; + using unsigned_ = uint32_t; +}; + +template <typename T> +struct SIMDType; + +template <> +struct SIMDType<int16_t> { + using type = SIMDVec16; +}; + +template <> +struct SIMDType<int32_t> { + using type = SIMDVec32; +}; + +} // namespace detail + +template <typename T> +using signed_t = typename detail::IntegerTypes<T>::signed_; + +template <typename T> +using unsigned_t = typename detail::IntegerTypes<T>::unsigned_; + +template <typename T> +using simd_t = typename detail::SIMDType<T>::type; + +// This function will process exactly one vector worth of pixels. + +template <typename T> +size_t PredictPixels(const signed_t<T>* pixels, const signed_t<T>* pixels_left, + const signed_t<T>* pixels_top, + const signed_t<T>* pixels_topleft, + unsigned_t<T>* residuals) { + T px = T::Load((unsigned_t<T>*)pixels); + T left = T::Load((unsigned_t<T>*)pixels_left); + T top = T::Load((unsigned_t<T>*)pixels_top); + T topleft = T::Load((unsigned_t<T>*)pixels_topleft); + T ac = left.Sub(topleft); + T ab = left.Sub(top); + T bc = top.Sub(topleft); + T grad = ac.Add(top); + T d = ab.Xor(bc); + T zero = T::Val(0); + T clamp = zero.Gt(d).IfThenElse(top, left); + T s = ac.Xor(bc); + T pred = zero.Gt(s).IfThenElse(grad, clamp); + T res = px.Sub(pred); + T res_times_2 = res.Add(res); + res = zero.Gt(res).IfThenElse(T::Val(-1).Sub(res_times_2), res_times_2); + res.Store(residuals); + return res.Eq(T::Val(0)).CountPrefix(); +} + +#endif + +void EncodeHybridUint000(uint32_t value, uint32_t* token, uint32_t* nbits, + uint32_t* bits) { + uint32_t n = FloorLog2(value); + *token = value ? n + 1 : 0; + *nbits = value ? n : 0; + *bits = value ? value - (1 << n) : 0; +} + +#ifdef FJXL_AVX512 +constexpr static size_t kLogChunkSize = 5; +#elif defined(FJXL_AVX2) || defined(FJXL_NEON) +// Even if NEON only has 128-bit lanes, it is still significantly (~1.3x) faster +// to process two vectors at a time. 
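+// With kLogChunkSize = 4 a chunk is 16 samples: one 16-lane SIMDVec16 pass per
+// chunk on AVX2, and two 8-lane passes on NEON (the two-vectors-at-a-time
+// processing mentioned above).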
+constexpr static size_t kLogChunkSize = 4; +#else +constexpr static size_t kLogChunkSize = 3; +#endif + +constexpr static size_t kChunkSize = 1 << kLogChunkSize; + +template <typename Residual> +void GenericEncodeChunk(const Residual* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { + for (size_t ix = skip; ix < n; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + output.Write(code.raw_nbits[token] + nbits, + code.raw_bits[token] | bits << code.raw_nbits[token]); + } +} + +struct UpTo8Bits { + size_t bitdepth; + explicit UpTo8Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth <= 8); + } + // Here we can fit up to 9 extra bits + 7 Huffman bits in a u16; for all other + // symbols, we could actually go up to 8 Huffman bits as we have at most 8 + // extra bits; however, the SIMD bit merging logic for AVX2 assumes that no + // Huffman length is 8 or more, so we cap at 8 anyway. Last symbol is used for + // LZ77 lengths and has no limitations except allowing to represent 32 symbols + // in total. + static constexpr uint8_t kMinRawLength[12] = {}; + static constexpr uint8_t kMaxRawLength[12] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 16; } + static constexpr size_t kInputBytes = 1; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n <= 16); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + + static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { +#ifdef FJXL_GENERIC_SIMD + Bits32 bits32[kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDUpTo13(token, code, nbits_huff, bits_huff); + StoreSIMDUpTo8(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, bits32 + i / SIMDVec16::kLanes); + } + StoreToWriter<kChunkSize / SIMDVec16::kLanes>(bits32, output); + return; +#endif + GenericEncodeChunk(residuals, n, skip, code, output); + } + + size_t NumSymbols(bool doing_ycocg) const { + // values gain 1 bit for YCoCg, 1 bit for prediction. + // Maximum symbol is 1 + effective bit depth of residuals. + if (doing_ycocg) { + return bitdepth + 3; + } else { + return bitdepth + 2; + } + } +}; +constexpr uint8_t UpTo8Bits::kMinRawLength[]; +constexpr uint8_t UpTo8Bits::kMaxRawLength[]; + +struct From9To13Bits { + size_t bitdepth; + explicit From9To13Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth <= 13 && bitdepth >= 9); + } + // Last symbol is used for LZ77 lengths and has no limitations except allowing + // to represent 32 symbols in total. + // We cannot fit all the bits in a u16, so do not even try and use up to 8 + // bits per raw symbol. + // There are at most 16 raw symbols, so Huffman coding can be SIMDfied without + // any special tricks. 
+ static constexpr uint8_t kMinRawLength[17] = {}; + static constexpr uint8_t kMaxRawLength[17] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 21; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n <= 16); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + + static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { +#ifdef FJXL_GENERIC_SIMD + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDUpTo13(token, code, nbits_huff, bits_huff); + StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + return; +#endif + GenericEncodeChunk(residuals, n, skip, code, output); + } + + size_t NumSymbols(bool doing_ycocg) const { + // values gain 1 bit for YCoCg, 1 bit for prediction. + // Maximum symbol is 1 + effective bit depth of residuals. + if (doing_ycocg) { + return bitdepth + 3; + } else { + return bitdepth + 2; + } + } +}; +constexpr uint8_t From9To13Bits::kMinRawLength[]; +constexpr uint8_t From9To13Bits::kMaxRawLength[]; + +void CheckHuffmanBitsSIMD(int bits1, int nbits1, int bits2, int nbits2) { + assert(nbits1 == 8); + assert(nbits2 == 8); + assert(bits2 == (bits1 | 128)); +} + +struct Exactly14Bits { + explicit Exactly14Bits(size_t bitdepth) { assert(bitdepth == 14); } + // Force LZ77 symbols to have at least 8 bits, and raw symbols 15 and 16 to + // have exactly 8, and no other symbol to have 8 or more. This ensures that + // the representation for 15 and 16 is identical up to one bit. 
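+  // Concretely: the SIMD lookup table only has 16 slots, so HuffmanSIMD14
+  // clamps the token to 15 with Min(15) and, when the token is actually 16,
+  // OR-s 0x80 into the 8-bit code; the length constraints below make that
+  // exactly the bit in which the codes for 15 and 16 differ (checked by
+  // CheckHuffmanBitsSIMD in PrepareForSimd).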
+ static constexpr uint8_t kMinRawLength[18] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 7, + }; + static constexpr uint8_t kMaxRawLength[18] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 10, + }; + static constexpr size_t bitdepth = 14; + static size_t MaxEncodedBitsPerSample() { return 22; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int16_t; + using upixel_t = uint16_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n == 17); + CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]); + memcpy(nbits_simd, nbits, 16); + memcpy(bits_simd, bits, 16); + } + + static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { +#ifdef FJXL_GENERIC_SIMD + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint16_t bits[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMD14(token, code, nbits_huff, bits_huff); + StoreSIMDUpTo14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + return; +#endif + GenericEncodeChunk(residuals, n, skip, code, output); + } + + size_t NumSymbols(bool) const { return 17; } +}; +constexpr uint8_t Exactly14Bits::kMinRawLength[]; +constexpr uint8_t Exactly14Bits::kMaxRawLength[]; + +struct MoreThan14Bits { + size_t bitdepth; + explicit MoreThan14Bits(size_t bitdepth) : bitdepth(bitdepth) { + assert(bitdepth > 14); + assert(bitdepth <= 16); + } + // Force LZ77 symbols to have at least 8 bits, and raw symbols 13 to 18 to + // have exactly 8, and no other symbol to have 8 or more. This ensures that + // the representation for (13, 14), (15, 16), (17, 18) is identical up to one + // bit. 
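+  // Concretely: there are 19 symbols but only 16 SIMD table slots, so
+  // PrepareForSimd below keeps the codes of symbols 0..13 and stores the codes
+  // of 15 and 17 at slots 14 and 15; HuffmanSIMDAbove14 remaps tokens
+  // 13/14 -> 13, 15/16 -> 14, 17/18 -> 15 and sets the extra top bit for the
+  // even token of each pair, whose code differs from its partner's only in
+  // that bit.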
+ static constexpr uint8_t kMinRawLength[20] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 7, + }; + static constexpr uint8_t kMaxRawLength[20] = { + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10, + }; + static size_t MaxEncodedBitsPerSample() { return 24; } + static constexpr size_t kInputBytes = 2; + using pixel_t = int32_t; + using upixel_t = uint32_t; + + static void PrepareForSimd(const uint8_t* nbits, const uint8_t* bits, + size_t n, uint8_t* nbits_simd, + uint8_t* bits_simd) { + assert(n == 19); + CheckHuffmanBitsSIMD(bits[13], nbits[13], bits[14], nbits[14]); + CheckHuffmanBitsSIMD(bits[15], nbits[15], bits[16], nbits[16]); + CheckHuffmanBitsSIMD(bits[17], nbits[17], bits[18], nbits[18]); + for (size_t i = 0; i < 14; i++) { + nbits_simd[i] = nbits[i]; + bits_simd[i] = bits[i]; + } + nbits_simd[14] = nbits[15]; + bits_simd[14] = bits[15]; + nbits_simd[15] = nbits[17]; + bits_simd[15] = bits[17]; + } + + static void EncodeChunk(upixel_t* residuals, size_t n, size_t skip, + const PrefixCode& code, BitWriter& output) { +#ifdef FJXL_GENERIC_SIMD + Bits32 bits32[2 * kChunkSize / SIMDVec16::kLanes]; + alignas(64) uint32_t bits[SIMDVec16::kLanes]; + alignas(64) uint32_t nbits[SIMDVec16::kLanes]; + alignas(64) uint16_t bits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t nbits_huff[SIMDVec16::kLanes]; + alignas(64) uint16_t token[SIMDVec16::kLanes]; + for (size_t i = 0; i < kChunkSize; i += SIMDVec16::kLanes) { + TokenizeSIMD(residuals + i, token, nbits, bits); + HuffmanSIMDAbove14(token, code, nbits_huff, bits_huff); + StoreSIMDAbove14(nbits, bits, nbits_huff, bits_huff, std::max(n, i) - i, + std::max(skip, i) - i, + bits32 + 2 * i / SIMDVec16::kLanes); + } + StoreToWriter<2 * kChunkSize / SIMDVec16::kLanes>(bits32, output); + return; +#endif + GenericEncodeChunk(residuals, n, skip, code, output); + } + size_t NumSymbols(bool) const { return 19; } +}; +constexpr uint8_t MoreThan14Bits::kMinRawLength[]; +constexpr uint8_t MoreThan14Bits::kMaxRawLength[]; + +void PrepareDCGlobalCommon(bool is_single_group, size_t width, size_t height, + const PrefixCode code[4], BitWriter* output) { + output->Allocate(100000 + (is_single_group ? width * height * 16 : 0)); + // No patches, spline or noise. + output->Write(1, 1); // default DC dequantization factors (?) + output->Write(1, 1); // use global tree / histograms + output->Write(1, 0); // no lz77 for the tree + + output->Write(1, 1); // simple code for the tree's context map + output->Write(2, 0); // all contexts clustered together + output->Write(1, 1); // use prefix code for tree + output->Write(4, 0); // 000 hybrid uint + output->Write(6, 0b100011); // Alphabet size is 4 (var16) + output->Write(2, 1); // simple prefix code + output->Write(2, 3); // with 4 symbols + output->Write(2, 0); + output->Write(2, 1); + output->Write(2, 2); + output->Write(2, 3); + output->Write(1, 0); // First tree encoding option + // Huffman table + extra bits for the tree. + uint8_t symbol_bits[6] = {0b00, 0b10, 0b001, 0b101, 0b0011, 0b0111}; + uint8_t symbol_nbits[6] = {2, 2, 3, 3, 4, 4}; + // Write a tree with a leaf per channel, and gradient predictor for every + // leaf. 
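+  // Each value in the list below is emitted with the prefix code defined just
+  // above; e.g. v = 1 is written as the 2-bit code 0b10 and v = 5 as the 4-bit
+  // code 0b0111.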
+ for (auto v : {1, 2, 1, 4, 1, 0, 0, 5, 0, 0, 0, 0, 5, + 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0, 0, 0}) { + output->Write(symbol_nbits[v], symbol_bits[v]); + } + + output->Write(1, 1); // Enable lz77 for the main bitstream + output->Write(2, 0b00); // lz77 offset 224 + static_assert(kLZ77Offset == 224, ""); + output->Write(4, 0b1010); // lz77 min length 7 + // 400 hybrid uint config for lz77 + output->Write(4, 4); + output->Write(3, 0); + output->Write(3, 0); + + output->Write(1, 1); // simple code for the context map + output->Write(2, 3); // 3 bits per entry + output->Write(3, 4); // channel 3 + output->Write(3, 3); // channel 2 + output->Write(3, 2); // channel 1 + output->Write(3, 1); // channel 0 + output->Write(3, 0); // distance histogram first + + output->Write(1, 1); // use prefix codes + output->Write(4, 0); // 000 hybrid uint config for distances (only need 0) + for (size_t i = 0; i < 4; i++) { + output->Write(4, 0); // 000 hybrid uint config for symbols (only <= 10) + } + + // Distance alphabet size: + output->Write(5, 0b00001); // 2: just need 1 for RLE (i.e. distance 1) + // Symbol + LZ77 alphabet size: + for (size_t i = 0; i < 4; i++) { + output->Write(1, 1); // > 1 + output->Write(4, 8); // <= 512 + output->Write(8, 256); // == 512 + } + + // Distance histogram: + output->Write(2, 1); // simple prefix code + output->Write(2, 0); // with one symbol + output->Write(1, 1); // 1 + + // Symbol + lz77 histogram: + for (size_t i = 0; i < 4; i++) { + code[i].WriteTo(output); + } + + // Group header for global modular image. + output->Write(1, 1); // Global tree + output->Write(1, 1); // All default wp +} + +void PrepareDCGlobal(bool is_single_group, size_t width, size_t height, + size_t nb_chans, const PrefixCode code[4], + BitWriter* output) { + PrepareDCGlobalCommon(is_single_group, width, height, code, output); + if (nb_chans > 2) { + output->Write(2, 0b01); // 1 transform + output->Write(2, 0b00); // RCT + output->Write(5, 0b00000); // Starting from ch 0 + output->Write(2, 0b00); // YCoCg + } else { + output->Write(2, 0b00); // no transforms + } + if (!is_single_group) { + output->ZeroPadToByte(); + } +} + +template <typename BitDepth> +struct ChunkEncoder { + FJXL_INLINE static void EncodeRle(size_t count, const PrefixCode& code, + BitWriter& output) { + if (count == 0) return; + count -= kLZ77MinLength + 1; + if (count < kLZ77CacheSize) { + output.Write(code.lz77_cache_nbits[count], code.lz77_cache_bits[count]); + } else { + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + uint64_t wbits = bits; + wbits = (wbits << code.lz77_nbits[token]) | code.lz77_bits[token]; + wbits = (wbits << code.raw_nbits[0]) | code.raw_bits[0]; + output.Write(code.lz77_nbits[token] + nbits + code.raw_nbits[0], wbits); + } + } + + FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals, + size_t skip, size_t n) { + EncodeRle(run, *code, *output); + BitDepth::EncodeChunk(residuals, n, skip, *code, *output); + } + + inline void Finalize(size_t run) { EncodeRle(run, *code, *output); } + + const PrefixCode* code; + BitWriter* output; +}; + +template <typename BitDepth> +struct ChunkSampleCollector { + FJXL_INLINE void Rle(size_t count, uint64_t* lz77_counts) { + if (count == 0) return; + raw_counts[0] += 1; + count -= kLZ77MinLength + 1; + unsigned token, nbits, bits; + EncodeHybridUintLZ77(count, &token, &nbits, &bits); + lz77_counts[token]++; + } + + FJXL_INLINE void Chunk(size_t run, typename BitDepth::upixel_t* residuals, + size_t skip, size_t n) { + // 
Run is broken. Encode the run and encode the individual vector. + Rle(run, lz77_counts); + for (size_t ix = skip; ix < n; ix++) { + unsigned token, nbits, bits; + EncodeHybridUint000(residuals[ix], &token, &nbits, &bits); + raw_counts[token]++; + } + } + + // don't count final run since we don't know how long it really is + void Finalize(size_t run) {} + + uint64_t* raw_counts; + uint64_t* lz77_counts; +}; + +constexpr uint32_t PackSigned(int32_t value) { + return (static_cast<uint32_t>(value) << 1) ^ + ((static_cast<uint32_t>(~value) >> 31) - 1); +} + +template <typename T, typename BitDepth> +struct ChannelRowProcessor { + using upixel_t = typename BitDepth::upixel_t; + using pixel_t = typename BitDepth::pixel_t; + T* t; + void ProcessChunk(const pixel_t* row, const pixel_t* row_left, + const pixel_t* row_top, const pixel_t* row_topleft, + size_t n) { + alignas(64) upixel_t residuals[kChunkSize] = {}; + size_t prefix_size = 0; + size_t required_prefix_size = 0; +#ifdef FJXL_GENERIC_SIMD + constexpr size_t kNum = + sizeof(pixel_t) == 2 ? SIMDVec16::kLanes : SIMDVec32::kLanes; + for (size_t ix = 0; ix < kChunkSize; ix += kNum) { + size_t c = + PredictPixels<simd_t<pixel_t>>(row + ix, row_left + ix, row_top + ix, + row_topleft + ix, residuals + ix); + prefix_size = + prefix_size == required_prefix_size ? prefix_size + c : prefix_size; + required_prefix_size += kNum; + } +#else + for (size_t ix = 0; ix < kChunkSize; ix++) { + pixel_t px = row[ix]; + pixel_t left = row_left[ix]; + pixel_t top = row_top[ix]; + pixel_t topleft = row_topleft[ix]; + pixel_t ac = left - topleft; + pixel_t ab = left - top; + pixel_t bc = top - topleft; + pixel_t grad = static_cast<pixel_t>(static_cast<upixel_t>(ac) + + static_cast<upixel_t>(top)); + pixel_t d = ab ^ bc; + pixel_t clamp = d < 0 ? top : left; + pixel_t s = ac ^ bc; + pixel_t pred = s < 0 ? grad : clamp; + residuals[ix] = PackSigned(px - pred); + prefix_size = prefix_size == required_prefix_size + ? prefix_size + (residuals[ix] == 0) + : prefix_size; + required_prefix_size += 1; + } +#endif + prefix_size = std::min(n, prefix_size); + if (prefix_size == n && (run > 0 || prefix_size > kLZ77MinLength)) { + // Run continues, nothing to do. + run += prefix_size; + } else if (prefix_size + run > kLZ77MinLength) { + // Run is broken. Encode the run and encode the individual vector. + t->Chunk(run + prefix_size, residuals, prefix_size, n); + run = 0; + } else { + // There was no run to begin with. + t->Chunk(0, residuals, 0, n); + } + } + + void ProcessRow(const pixel_t* row, const pixel_t* row_left, + const pixel_t* row_top, const pixel_t* row_topleft, + size_t xs) { + for (size_t x = 0; x < xs; x += kChunkSize) { + ProcessChunk(row + x, row_left + x, row_top + x, row_topleft + x, + std::min(kChunkSize, xs - x)); + } + } + + void Finalize() { t->Finalize(run); } + // Invariant: run == 0 or run > kLZ77MinLength. 
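+  // Worked example of the run logic in ProcessChunk above: with run == 0 and a
+  // 16-sample chunk whose first 10 residuals are zero, prefix_size is 10;
+  // since 10 + 0 exceeds kLZ77MinLength (7, matching the lz77 min length
+  // signalled in PrepareDCGlobalCommon), the 10 zeros go through the RLE path
+  // (EncodeRle) and the remaining 6 samples are encoded literally, leaving
+  // `run` at 0. Only an all-zero chunk extends `run` instead of flushing it.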
+ size_t run = 0; +}; + +uint16_t LoadLE16(const unsigned char* ptr) { + return uint16_t{ptr[0]} | (uint16_t{ptr[1]} << 8); +} + +uint16_t SwapEndian(uint16_t in) { return (in >> 8) | (in << 8); } + +#ifdef FJXL_GENERIC_SIMD +void StorePixels(SIMDVec16 p, int16_t* dest) { p.Store((uint16_t*)dest); } + +void StorePixels(SIMDVec16 p, int32_t* dest) { + VecPair<SIMDVec32> p_up = p.Upcast(); + p_up.low.Store((uint32_t*)dest); + p_up.hi.Store((uint32_t*)dest + SIMDVec32::kLanes); +} +#endif + +template <typename pixel_t> +void FillRowG8(const unsigned char* rgba, size_t oxs, pixel_t* luma) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadG8(rgba + x); + StorePixels(rgb[0], luma + x); + } +#endif + for (; x < oxs; x++) { + luma[x] = rgba[x]; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowG16(const unsigned char* rgba, size_t oxs, pixel_t* luma) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadG16(rgba + 2 * x); + if (big_endian) { + rgb[0].SwapEndian(); + } + StorePixels(rgb[0], luma + x); + } +#endif + for (; x < oxs; x++) { + uint16_t val = LoadLE16(rgba + 2 * x); + if (big_endian) { + val = SwapEndian(val); + } + luma[x] = val; + } +} + +template <typename pixel_t> +void FillRowGA8(const unsigned char* rgba, size_t oxs, pixel_t* luma, + pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadGA8(rgba + 2 * x); + StorePixels(rgb[0], luma + x); + StorePixels(rgb[1], alpha + x); + } +#endif + for (; x < oxs; x++) { + luma[x] = rgba[2 * x]; + alpha[x] = rgba[2 * x + 1]; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowGA16(const unsigned char* rgba, size_t oxs, pixel_t* luma, + pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadGA16(rgba + 4 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + } + StorePixels(rgb[0], luma + x); + StorePixels(rgb[1], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t l = LoadLE16(rgba + 4 * x); + uint16_t a = LoadLE16(rgba + 4 * x + 2); + if (big_endian) { + l = SwapEndian(l); + a = SwapEndian(a); + } + luma[x] = l; + alpha[x] = a; + } +} + +template <typename pixel_t> +void StoreYCoCg(pixel_t r, pixel_t g, pixel_t b, pixel_t* y, pixel_t* co, + pixel_t* cg) { + *co = r - b; + pixel_t tmp = b + (*co >> 1); + *cg = g - tmp; + *y = tmp + (*cg >> 1); +} + +#ifdef FJXL_GENERIC_SIMD +void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int16_t* y, int16_t* co, + int16_t* cg) { + SIMDVec16 co_v = r.Sub(b); + SIMDVec16 tmp = b.Add(co_v.SignedShiftRight<1>()); + SIMDVec16 cg_v = g.Sub(tmp); + SIMDVec16 y_v = tmp.Add(cg_v.SignedShiftRight<1>()); + y_v.Store((uint16_t*)y); + co_v.Store((uint16_t*)co); + cg_v.Store((uint16_t*)cg); +} + +void StoreYCoCg(SIMDVec16 r, SIMDVec16 g, SIMDVec16 b, int32_t* y, int32_t* co, + int32_t* cg) { + VecPair<SIMDVec32> r_up = r.Upcast(); + VecPair<SIMDVec32> g_up = g.Upcast(); + VecPair<SIMDVec32> b_up = b.Upcast(); + SIMDVec32 co_lo_v = r_up.low.Sub(b_up.low); + SIMDVec32 tmp_lo = b_up.low.Add(co_lo_v.SignedShiftRight<1>()); + SIMDVec32 cg_lo_v = g_up.low.Sub(tmp_lo); + SIMDVec32 y_lo_v = tmp_lo.Add(cg_lo_v.SignedShiftRight<1>()); + SIMDVec32 co_hi_v = r_up.hi.Sub(b_up.hi); + SIMDVec32 tmp_hi = 
b_up.hi.Add(co_hi_v.SignedShiftRight<1>()); + SIMDVec32 cg_hi_v = g_up.hi.Sub(tmp_hi); + SIMDVec32 y_hi_v = tmp_hi.Add(cg_hi_v.SignedShiftRight<1>()); + y_lo_v.Store((uint32_t*)y); + co_lo_v.Store((uint32_t*)co); + cg_lo_v.Store((uint32_t*)cg); + y_hi_v.Store((uint32_t*)y + SIMDVec32::kLanes); + co_hi_v.Store((uint32_t*)co + SIMDVec32::kLanes); + cg_hi_v.Store((uint32_t*)cg + SIMDVec32::kLanes); +} +#endif + +template <typename pixel_t> +void FillRowRGB8(const unsigned char* rgba, size_t oxs, pixel_t* y, pixel_t* co, + pixel_t* cg) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGB8(rgba + 3 * x); + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = rgba[3 * x]; + uint16_t g = rgba[3 * x + 1]; + uint16_t b = rgba[3 * x + 2]; + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + } +} + +template <bool big_endian, typename pixel_t> +void FillRowRGB16(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGB16(rgba + 6 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + rgb[2].SwapEndian(); + } + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = LoadLE16(rgba + 6 * x); + uint16_t g = LoadLE16(rgba + 6 * x + 2); + uint16_t b = LoadLE16(rgba + 6 * x + 4); + if (big_endian) { + r = SwapEndian(r); + g = SwapEndian(g); + b = SwapEndian(b); + } + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + } +} + +template <typename pixel_t> +void FillRowRGBA8(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg, pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGBA8(rgba + 4 * x); + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + StorePixels(rgb[3], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = rgba[4 * x]; + uint16_t g = rgba[4 * x + 1]; + uint16_t b = rgba[4 * x + 2]; + uint16_t a = rgba[4 * x + 3]; + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + alpha[x] = a; + } +} + +template <bool big_endian, typename pixel_t> +void FillRowRGBA16(const unsigned char* rgba, size_t oxs, pixel_t* y, + pixel_t* co, pixel_t* cg, pixel_t* alpha) { + size_t x = 0; +#ifdef FJXL_GENERIC_SIMD + for (; x + SIMDVec16::kLanes <= oxs; x += SIMDVec16::kLanes) { + auto rgb = SIMDVec16::LoadRGBA16(rgba + 8 * x); + if (big_endian) { + rgb[0].SwapEndian(); + rgb[1].SwapEndian(); + rgb[2].SwapEndian(); + rgb[3].SwapEndian(); + } + StoreYCoCg(rgb[0], rgb[1], rgb[2], y + x, co + x, cg + x); + StorePixels(rgb[3], alpha + x); + } +#endif + for (; x < oxs; x++) { + uint16_t r = LoadLE16(rgba + 8 * x); + uint16_t g = LoadLE16(rgba + 8 * x + 2); + uint16_t b = LoadLE16(rgba + 8 * x + 4); + uint16_t a = LoadLE16(rgba + 8 * x + 6); + if (big_endian) { + r = SwapEndian(r); + g = SwapEndian(g); + b = SwapEndian(b); + a = SwapEndian(a); + } + StoreYCoCg<pixel_t>(r, g, b, y + x, co + x, cg + x); + alpha[x] = a; + } +} + +template <typename Processor, typename BitDepth> +void ProcessImageArea(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t yskip, size_t ys, size_t row_stride, + BitDepth bitdepth, size_t nb_chans, bool big_endian, + Processor* processors) { + constexpr 
size_t kPadding = 32; + + using pixel_t = typename BitDepth::pixel_t; + + constexpr size_t kAlign = 64; + constexpr size_t kAlignPixels = kAlign / sizeof(pixel_t); + + auto align = [=](pixel_t* ptr) { + size_t offset = reinterpret_cast<uintptr_t>(ptr) % kAlign; + if (offset) { + ptr += offset / sizeof(pixel_t); + } + return ptr; + }; + + constexpr size_t kNumPx = + (256 + kPadding * 2 + kAlignPixels + kAlignPixels - 1) / kAlignPixels * + kAlignPixels; + + std::vector<std::array<std::array<pixel_t, kNumPx>, 2>> group_data(nb_chans); + + for (size_t y = 0; y < ys; y++) { + const auto rgba_row = + rgba + row_stride * (y0 + y) + x0 * nb_chans * BitDepth::kInputBytes; + pixel_t* crow[4] = {}; + pixel_t* prow[4] = {}; + for (size_t i = 0; i < nb_chans; i++) { + crow[i] = align(&group_data[i][y & 1][kPadding]); + prow[i] = align(&group_data[i][(y - 1) & 1][kPadding]); + } + + // Pre-fill rows with YCoCg converted pixels. + if (nb_chans == 1) { + if (BitDepth::kInputBytes == 1) { + FillRowG8(rgba_row, xs, crow[0]); + } else if (big_endian) { + FillRowG16</*big_endian=*/true>(rgba_row, xs, crow[0]); + } else { + FillRowG16</*big_endian=*/false>(rgba_row, xs, crow[0]); + } + } else if (nb_chans == 2) { + if (BitDepth::kInputBytes == 1) { + FillRowGA8(rgba_row, xs, crow[0], crow[1]); + } else if (big_endian) { + FillRowGA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1]); + } else { + FillRowGA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1]); + } + } else if (nb_chans == 3) { + if (BitDepth::kInputBytes == 1) { + FillRowRGB8(rgba_row, xs, crow[0], crow[1], crow[2]); + } else if (big_endian) { + FillRowRGB16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1], + crow[2]); + } else { + FillRowRGB16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1], + crow[2]); + } + } else { + if (BitDepth::kInputBytes == 1) { + FillRowRGBA8(rgba_row, xs, crow[0], crow[1], crow[2], crow[3]); + } else if (big_endian) { + FillRowRGBA16</*big_endian=*/true>(rgba_row, xs, crow[0], crow[1], + crow[2], crow[3]); + } else { + FillRowRGBA16</*big_endian=*/false>(rgba_row, xs, crow[0], crow[1], + crow[2], crow[3]); + } + } + // Deal with x == 0. + for (size_t c = 0; c < nb_chans; c++) { + *(crow[c] - 1) = y > 0 ? *(prow[c]) : 0; + // Fix topleft. + *(prow[c] - 1) = y > 0 ? *(prow[c]) : 0; + } + if (y < yskip) continue; + for (size_t c = 0; c < nb_chans; c++) { + // Get pointers to px/left/top/topleft data to speedup loop. + const pixel_t* row = crow[c]; + const pixel_t* row_left = crow[c] - 1; + const pixel_t* row_top = y == 0 ? row_left : prow[c]; + const pixel_t* row_topleft = y == 0 ? row_left : prow[c] - 1; + + processors[c].ProcessRow(row, row_left, row_top, row_topleft, xs); + } + } + for (size_t c = 0; c < nb_chans; c++) { + processors[c].Finalize(); + } +} + +template <typename BitDepth> +void WriteACSection(const unsigned char* rgba, size_t x0, size_t y0, size_t xs, + size_t ys, size_t row_stride, bool is_single_group, + BitDepth bitdepth, size_t nb_chans, bool big_endian, + const PrefixCode code[4], + std::array<BitWriter, 4>& output) { + for (size_t i = 0; i < nb_chans; i++) { + if (is_single_group && i == 0) continue; + output[i].Allocate(xs * ys * bitdepth.MaxEncodedBitsPerSample() + 4); + } + if (!is_single_group) { + // Group header for modular image. + // When the image is single-group, the global modular image is the one + // that contains the pixel data, and there is no group header. 
+ output[0].Write(1, 1); // Global tree + output[0].Write(1, 1); // All default wp + output[0].Write(2, 0b00); // 0 transforms + } + + ChunkEncoder<BitDepth> encoders[4]; + ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth> row_encoders[4]; + for (size_t c = 0; c < nb_chans; c++) { + row_encoders[c].t = &encoders[c]; + encoders[c].output = &output[c]; + encoders[c].code = &code[c]; + } + ProcessImageArea<ChannelRowProcessor<ChunkEncoder<BitDepth>, BitDepth>>( + rgba, x0, y0, xs, 0, ys, row_stride, bitdepth, nb_chans, big_endian, + row_encoders); +} + +constexpr int kHashExp = 16; +constexpr uint32_t kHashSize = 1 << kHashExp; +constexpr uint32_t kHashMultiplier = 2654435761; +constexpr int kMaxColors = 512; + +// can be any function that returns a value in 0 .. kHashSize-1 +// has to map 0 to 0 +inline uint32_t pixel_hash(uint32_t p) { + return (p * kHashMultiplier) >> (32 - kHashExp); +} + +template <size_t nb_chans> +void FillRowPalette(const unsigned char* inrow, size_t xs, + const int16_t* lookup, int16_t* out) { + for (size_t x = 0; x < xs; x++) { + uint32_t p = 0; + memcpy(&p, inrow + x * nb_chans, nb_chans); + out[x] = lookup[pixel_hash(p)]; + } +} + +template <typename Processor> +void ProcessImageAreaPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t yskip, size_t ys, + size_t row_stride, const int16_t* lookup, + size_t nb_chans, Processor* processors) { + constexpr size_t kPadding = 32; + + std::vector<std::array<int16_t, 256 + kPadding * 2>> group_data(2); + Processor& row_encoder = processors[0]; + + for (size_t y = 0; y < ys; y++) { + // Pre-fill rows with palette converted pixels. + const unsigned char* inrow = rgba + row_stride * (y0 + y) + x0 * nb_chans; + int16_t* outrow = &group_data[y & 1][kPadding]; + if (nb_chans == 1) { + FillRowPalette<1>(inrow, xs, lookup, outrow); + } else if (nb_chans == 2) { + FillRowPalette<2>(inrow, xs, lookup, outrow); + } else if (nb_chans == 3) { + FillRowPalette<3>(inrow, xs, lookup, outrow); + } else if (nb_chans == 4) { + FillRowPalette<4>(inrow, xs, lookup, outrow); + } + // Deal with x == 0. + group_data[y & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Fix topleft. + group_data[(y - 1) & 1][kPadding - 1] = + y > 0 ? group_data[(y - 1) & 1][kPadding] : 0; + // Get pointers to px/left/top/topleft data to speedup loop. + const int16_t* row = &group_data[y & 1][kPadding]; + const int16_t* row_left = &group_data[y & 1][kPadding - 1]; + const int16_t* row_top = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding]; + const int16_t* row_topleft = + y == 0 ? row_left : &group_data[(y - 1) & 1][kPadding - 1]; + + row_encoder.ProcessRow(row, row_left, row_top, row_topleft, xs); + } + row_encoder.Finalize(); +} + +void WriteACSectionPalette(const unsigned char* rgba, size_t x0, size_t y0, + size_t xs, size_t ys, size_t row_stride, + bool is_single_group, const PrefixCode code[4], + const int16_t* lookup, size_t nb_chans, + BitWriter& output) { + if (!is_single_group) { + output.Allocate(16 * xs * ys + 4); + // Group header for modular image. + // When the image is single-group, the global modular image is the one + // that contains the pixel data, and there is no group header. 
+    output.Write(1, 1); // Global tree
+    output.Write(1, 1); // All default wp
+    output.Write(2, 0b00); // 0 transforms
+  }
+
+  ChunkEncoder<UpTo8Bits> encoder;
+  ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
+
+  row_encoder.t = &encoder;
+  encoder.output = &output;
+  encoder.code = &code[is_single_group ? 1 : 0];
+  ProcessImageAreaPalette<
+      ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits>>(
+      rgba, x0, y0, xs, 0, ys, row_stride, lookup, nb_chans, &row_encoder);
+}
+
+template <typename BitDepth>
+void CollectSamples(const unsigned char* rgba, size_t x0, size_t y0, size_t xs,
+                    size_t row_stride, size_t row_count,
+                    uint64_t raw_counts[4][kNumRawSymbols],
+                    uint64_t lz77_counts[4][kNumLZ77], bool is_single_group,
+                    bool palette, BitDepth bitdepth, size_t nb_chans,
+                    bool big_endian, const int16_t* lookup) {
+  if (palette) {
+    ChunkSampleCollector<UpTo8Bits> sample_collectors[4];
+    ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>
+        row_sample_collectors[4];
+    for (size_t c = 0; c < nb_chans; c++) {
+      row_sample_collectors[c].t = &sample_collectors[c];
+      sample_collectors[c].raw_counts = raw_counts[is_single_group ? 1 : 0];
+      sample_collectors[c].lz77_counts = lz77_counts[is_single_group ? 1 : 0];
+    }
+    ProcessImageAreaPalette<
+        ChannelRowProcessor<ChunkSampleCollector<UpTo8Bits>, UpTo8Bits>>(
+        rgba, x0, y0, xs, 1, 1 + row_count, row_stride, lookup, nb_chans,
+        row_sample_collectors);
+  } else {
+    ChunkSampleCollector<BitDepth> sample_collectors[4];
+    ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>
+        row_sample_collectors[4];
+    for (size_t c = 0; c < nb_chans; c++) {
+      row_sample_collectors[c].t = &sample_collectors[c];
+      sample_collectors[c].raw_counts = raw_counts[c];
+      sample_collectors[c].lz77_counts = lz77_counts[c];
+    }
+    ProcessImageArea<
+        ChannelRowProcessor<ChunkSampleCollector<BitDepth>, BitDepth>>(
+        rgba, x0, y0, xs, 1, 1 + row_count, row_stride, bitdepth, nb_chans,
+        big_endian, row_sample_collectors);
+  }
+}
+
+void PrepareDCGlobalPalette(bool is_single_group, size_t width, size_t height,
+                            const PrefixCode code[4],
+                            const std::vector<uint32_t>& palette,
+                            size_t pcolors, BitWriter* output) {
+  PrepareDCGlobalCommon(is_single_group, width, height, code, output);
+  output->Write(2, 0b01); // 1 transform
+  output->Write(2, 0b01); // Palette
+  output->Write(5, 0b00000); // Starting from ch 0
+  output->Write(2, 0b10); // 4-channel palette (RGBA)
+  // pcolors <= kMaxColors + kChunkSize - 1
+  static_assert(kMaxColors + kChunkSize < 1281,
+                "add code to signal larger palette sizes");
+  if (pcolors < 256) {
+    output->Write(2, 0b00);
+    output->Write(8, pcolors);
+  } else {
+    output->Write(2, 0b01);
+    output->Write(10, pcolors - 256);
+  }
+
+  output->Write(2, 0b00); // nb_deltas == 0
+  output->Write(4, 0); // Zero predictor for delta palette
+  // Encode palette
+  ChunkEncoder<UpTo8Bits> encoder;
+  ChannelRowProcessor<ChunkEncoder<UpTo8Bits>, UpTo8Bits> row_encoder;
+  row_encoder.t = &encoder;
+  encoder.output = output;
+  encoder.code = &code[0];
+  int16_t p[4][32 + 1024] = {};
+  uint8_t prgba[4];
+  size_t i = 0;
+  size_t have_zero = 0;
+  if (palette[pcolors - 1] == 0) have_zero = 1;
+  for (; i < pcolors; i++) {
+    memcpy(prgba, &palette[i], 4);
+    p[0][16 + i + have_zero] = prgba[0];
+    p[1][16 + i + have_zero] = prgba[1];
+    p[2][16 + i + have_zero] = prgba[2];
+    p[3][16 + i + have_zero] = prgba[3];
+  }
+  p[0][15] = 0;
+  row_encoder.ProcessRow(p[0] + 16, p[0] + 15, p[0] + 15, p[0] + 15, pcolors);
+  p[1][15] = p[0][16];
+  p[0][15] = p[0][16];
+  row_encoder.ProcessRow(p[1] + 16, p[1] + 15, p[0] + 16, p[0] + 15, pcolors);
+  p[2][15] = p[1][16];
+  p[1][15] = p[1][16];
+  row_encoder.ProcessRow(p[2] + 16, p[2] + 15, p[1] + 16, p[1] + 15, pcolors);
+  p[3][15] = p[2][16];
+  p[2][15] = p[2][16];
+  row_encoder.ProcessRow(p[3] + 16, p[3] + 15, p[2] + 16, p[2] + 15, pcolors);
+  row_encoder.Finalize();
+
+  if (!is_single_group) {
+    output->ZeroPadToByte();
+  }
+}
+
+template <typename BitDepth>
+JxlFastLosslessFrameState* LLEnc(const unsigned char* rgba, size_t width,
+                                 size_t stride, size_t height,
+                                 BitDepth bitdepth, size_t nb_chans,
+                                 bool big_endian, int effort,
+                                 void* runner_opaque,
+                                 FJxlParallelRunner runner) {
+  assert(width != 0);
+  assert(height != 0);
+  assert(stride >= nb_chans * BitDepth::kInputBytes * width);
+
+  // Count colors to try palette
+  std::vector<uint32_t> palette(kHashSize);
+  palette[0] = 1;
+  std::vector<int16_t> lookup(kHashSize);
+  lookup[0] = 0;
+  int pcolors = 0;
+  bool collided = effort < 2 || bitdepth.bitdepth != 8 ||
+                  nb_chans < 4; // todo: also do rgb palette
+  for (size_t y = 0; y < height && !collided; y++) {
+    const unsigned char* r = rgba + stride * y;
+    size_t x = 0;
+    if (nb_chans == 4) {
+      // this is just an unrolling of the next loop
+      for (; x + 7 < width; x += 8) {
+        uint32_t p[8], index[8];
+        memcpy(p, r + x * 4, 32);
+        for (int i = 0; i < 8; i++) index[i] = pixel_hash(p[i]);
+        for (int i = 0; i < 8; i++) {
+          uint32_t init_entry = index[i] ? 0 : 1;
+          if (init_entry != palette[index[i]] && p[i] != palette[index[i]]) {
+            collided = true;
+          }
+        }
+        for (int i = 0; i < 8; i++) palette[index[i]] = p[i];
+      }
+      for (; x < width; x++) {
+        uint32_t p;
+        memcpy(&p, r + x * 4, 4);
+        uint32_t index = pixel_hash(p);
+        uint32_t init_entry = index ? 0 : 1;
+        if (init_entry != palette[index] && p != palette[index]) {
+          collided = true;
+        }
+        palette[index] = p;
+      }
+    } else {
+      for (; x < width; x++) {
+        uint32_t p = 0;
+        memcpy(&p, r + x * nb_chans, nb_chans);
+        uint32_t index = pixel_hash(p);
+        uint32_t init_entry = index ? 0 : 1;
+        if (init_entry != palette[index] && p != palette[index]) {
+          collided = true;
+        }
+        palette[index] = p;
+      }
+    }
+  }
+
+  int nb_entries = 0;
+  if (!collided) {
+    if (palette[0] == 0) pcolors = 1;
+    if (palette[0] == 1) palette[0] = 0;
+    bool have_color = false;
+    uint8_t minG = 255, maxG = 0;
+    for (uint32_t k = 0; k < kHashSize; k++) {
+      if (palette[k] == 0) continue;
+      uint8_t p[4];
+      memcpy(p, &palette[k], 4);
+      // move entries to front so sort has less work
+      palette[nb_entries] = palette[k];
+      if (p[0] != p[1] || p[0] != p[2]) have_color = true;
+      if (p[1] < minG) minG = p[1];
+      if (p[1] > maxG) maxG = p[1];
+      nb_entries++;
+      // don't do palette if too many colors are needed
+      if (nb_entries + pcolors > kMaxColors) {
+        collided = true;
+        break;
+      }
+    }
+    if (!have_color) {
+      // don't do palette if it's just grayscale without many holes
+      if (maxG - minG < nb_entries * 1.4f) collided = true;
+    }
+  }
+  if (!collided) {
+    std::sort(
+        palette.begin(), palette.begin() + nb_entries,
+        [](uint32_t ap, uint32_t bp) {
+          if (ap == 0) return false;
+          if (bp == 0) return true;
+          uint8_t a[4], b[4];
+          memcpy(a, &ap, 4);
+          memcpy(b, &bp, 4);
+          float ay, by;
+          ay = (0.299f * a[0] + 0.587f * a[1] + 0.114f * a[2] + 0.01f) * a[3];
+          by = (0.299f * b[0] + 0.587f * b[1] + 0.114f * b[2] + 0.01f) * b[3];
+          return ay < by; // sort on alpha*luma
+        });
+    for (int k = 0; k < nb_entries; k++) {
+      if (palette[k] == 0) break;
+      lookup[pixel_hash(palette[k])] = pcolors++;
+    }
+  }
+
+  size_t num_groups_x = (width + 255) / 256;
+  size_t num_groups_y = (height + 255) / 256;
+  size_t num_dc_groups_x = (width + 2047) / 2048;
+  size_t num_dc_groups_y = (height + 2047) / 2048;
+
+  uint64_t raw_counts[4][kNumRawSymbols] = {};
+  uint64_t lz77_counts[4][kNumLZ77] = {};
+
+  bool onegroup = num_groups_x == 1 && num_groups_y == 1;
+
+  // sample the middle (effort * 2) rows of every group
+  for (size_t g = 0; g < num_groups_y * num_groups_x; g++) {
+    size_t xg = g % num_groups_x;
+    size_t yg = g / num_groups_x;
+    int y_offset = yg * 256;
+    int y_max = std::min<size_t>(height - yg * 256, 256);
+    int y_begin = y_offset + std::max<int>(0, y_max - 2 * effort) / 2;
+    int y_count =
+        std::min<int>(2 * effort * y_max / 256, y_offset + y_max - y_begin - 1);
+    int x_max =
+        std::min<size_t>(width - xg * 256, 256) / kChunkSize * kChunkSize;
+    CollectSamples(rgba, xg * 256, y_begin, x_max, stride, y_count, raw_counts,
+                   lz77_counts, onegroup, !collided, bitdepth, nb_chans,
+                   big_endian, lookup.data());
+  }
+
+  // TODO(veluca): can probably improve this and make it bitdepth-dependent.
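A worked example of the sampling heuristic above may help; this is an editorial sketch, not part of the patch, and simply evaluates the expressions for y_begin and y_count:

  // Assume effort = 2 and a full 256-row group, so y_max = 256:
  //   y_begin = y_offset + max(0, 256 - 4) / 2 = y_offset + 126
  //   y_count = min(2 * 2 * 256 / 256, y_offset + 256 - y_begin - 1)
  //           = min(4, 129) = 4
  // i.e. only the four middle rows of each 256x256 group are scanned to build
  // the raw/LZ77 histograms that seed the prefix codes; the base counts that
  // follow are then blended in as a prior.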
+  uint64_t base_raw_counts[kNumRawSymbols] = {
+      3843, 852, 1270, 1214, 1014, 727, 481, 300, 159, 51,
+      5,    1,   1,    1,    1,    1,   1,   1,   1};
+
+  bool doing_ycocg = nb_chans > 2 && collided;
+  for (size_t i = bitdepth.NumSymbols(doing_ycocg); i < kNumRawSymbols; i++) {
+    base_raw_counts[i] = 0;
+  }
+
+  for (size_t c = 0; c < 4; c++) {
+    for (size_t i = 0; i < kNumRawSymbols; i++) {
+      raw_counts[c][i] = (raw_counts[c][i] << 8) + base_raw_counts[i];
+    }
+  }
+
+  if (!collided) {
+    unsigned token, nbits, bits;
+    EncodeHybridUint000(PackSigned(pcolors - 1), &token, &nbits, &bits);
+    // ensure all palette indices can actually be encoded
+    for (size_t i = 0; i < token + 1; i++)
+      raw_counts[0][i] = std::max<uint64_t>(raw_counts[0][i], 1);
+    // these tokens are only used for the palette itself so they can get a bad
+    // code
+    for (size_t i = token + 1; i < 10; i++) raw_counts[0][i] = 1;
+  }
+
+  uint64_t base_lz77_counts[kNumLZ77] = {
+      29, 27, 25,  23, 21, 21, 19, 18, 21, 17, 16, 15, 15, 14,
+      13, 13, 137, 98, 61, 34, 1,  1,  1,  1,  1,  1,  1,  1,
+  };
+
+  for (size_t c = 0; c < 4; c++) {
+    for (size_t i = 0; i < kNumLZ77; i++) {
+      lz77_counts[c][i] = (lz77_counts[c][i] << 8) + base_lz77_counts[i];
+    }
+  }
+
+  alignas(64) PrefixCode hcode[4];
+  for (size_t i = 0; i < 4; i++) {
+    hcode[i] = PrefixCode(bitdepth, raw_counts[i], lz77_counts[i]);
+  }
+
+  size_t num_groups = onegroup ? 1
+                               : (2 + num_dc_groups_x * num_dc_groups_y +
+                                  num_groups_x * num_groups_y);
+
+  JxlFastLosslessFrameState* frame_state = new JxlFastLosslessFrameState();
+
+  frame_state->width = width;
+  frame_state->height = height;
+  frame_state->nb_chans = nb_chans;
+  frame_state->bitdepth = bitdepth.bitdepth;
+
+  frame_state->group_data = std::vector<std::array<BitWriter, 4>>(num_groups);
+  if (collided) {
+    PrepareDCGlobal(onegroup, width, height, nb_chans, hcode,
+                    &frame_state->group_data[0][0]);
+  } else {
+    PrepareDCGlobalPalette(onegroup, width, height, hcode, palette, pcolors,
+                           &frame_state->group_data[0][0]);
+  }
+
+  auto run_one = [&](size_t g) {
+    size_t xg = g % num_groups_x;
+    size_t yg = g / num_groups_x;
+    size_t group_id =
+        onegroup ? 0 : (2 + num_dc_groups_x * num_dc_groups_y + g);
+    size_t xs = std::min<size_t>(width - xg * 256, 256);
+    size_t ys = std::min<size_t>(height - yg * 256, 256);
+    size_t x0 = xg * 256;
+    size_t y0 = yg * 256;
+    auto& gd = frame_state->group_data[group_id];
+    if (collided) {
+      WriteACSection(rgba, x0, y0, xs, ys, stride, onegroup, bitdepth,
+                     nb_chans, big_endian, hcode, gd);
+    } else {
+      WriteACSectionPalette(rgba, x0, y0, xs, ys, stride, onegroup, hcode,
+                            lookup.data(), nb_chans, gd[0]);
+    }
+  };
+
+  runner(
+      runner_opaque, &run_one,
+      +[](void* r, size_t i) { (*reinterpret_cast<decltype(&run_one)>(r))(i); },
+      num_groups_x * num_groups_y);
+
+  return frame_state;
+}
+
+JxlFastLosslessFrameState* JxlFastLosslessEncodeImpl(
+    const unsigned char* rgba, size_t width, size_t stride, size_t height,
+    size_t nb_chans, size_t bitdepth, bool big_endian, int effort,
+    void* runner_opaque, FJxlParallelRunner runner) {
+  assert(bitdepth > 0);
+  assert(nb_chans <= 4);
+  assert(nb_chans != 0);
+  if (bitdepth <= 8) {
+    return LLEnc(rgba, width, stride, height, UpTo8Bits(bitdepth), nb_chans,
+                 big_endian, effort, runner_opaque, runner);
+  }
+  if (bitdepth <= 13) {
+    return LLEnc(rgba, width, stride, height, From9To13Bits(bitdepth), nb_chans,
+                 big_endian, effort, runner_opaque, runner);
+  }
+  if (bitdepth == 14) {
+    return LLEnc(rgba, width, stride, height, Exactly14Bits(bitdepth), nb_chans,
+                 big_endian, effort, runner_opaque, runner);
+  }
+  return LLEnc(rgba, width, stride, height, MoreThan14Bits(bitdepth), nb_chans,
+               big_endian, effort, runner_opaque, runner);
+}
+
+} // namespace
+
+#endif // FJXL_SELF_INCLUDE
+
+#ifndef FJXL_SELF_INCLUDE
+
+#define FJXL_SELF_INCLUDE
+
+// If we have NEON enabled, it is the default target.
+#if FJXL_ENABLE_NEON
+
+namespace default_implementation {
+#define FJXL_NEON
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_NEON
+} // namespace default_implementation
+
+#else // FJXL_ENABLE_NEON
+
+namespace default_implementation {
+#include "lib/jxl/enc_fast_lossless.cc"
+}
+
+#if FJXL_ENABLE_AVX2
+#ifdef __clang__
+#pragma clang attribute push(__attribute__((target("avx,avx2"))), \
+                             apply_to = function)
+// Causes spurious warnings on clang5.
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wmissing-braces"
+#elif defined(__GNUC__)
+#pragma GCC push_options
+// Seems to cause spurious errors on GCC8.
+#pragma GCC diagnostic ignored "-Wpsabi"
+#pragma GCC target "avx,avx2"
+#endif
+
+namespace AVX2 {
+#define FJXL_AVX2
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_AVX2
+} // namespace AVX2
+
+#ifdef __clang__
+#pragma clang attribute pop
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+#endif // FJXL_ENABLE_AVX2
+
+#if FJXL_ENABLE_AVX512
+#ifdef __clang__
+#pragma clang attribute push( \
+    __attribute__((target("avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"))), \
+    apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC target "avx512cd,avx512bw,avx512vl,avx512f,avx512vbmi"
+#endif
+
+namespace AVX512 {
+#define FJXL_AVX512
+#include "lib/jxl/enc_fast_lossless.cc"
+#undef FJXL_AVX512
+} // namespace AVX512
+
+#ifdef __clang__
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+#endif // FJXL_ENABLE_AVX512
+
+#endif
+
+extern "C" {
+
+size_t JxlFastLosslessEncode(const unsigned char* rgba, size_t width,
+                             size_t row_stride, size_t height, size_t nb_chans,
+                             size_t bitdepth, int big_endian, int effort,
+                             unsigned char** output, void* runner_opaque,
+                             FJxlParallelRunner runner) {
+  auto frame_state = JxlFastLosslessPrepareFrame(
+      rgba, width, row_stride, height, nb_chans, bitdepth, big_endian, effort,
+      runner_opaque, runner);
+  JxlFastLosslessPrepareHeader(frame_state, /*add_image_header=*/1,
+                               /*is_last=*/1);
+  size_t output_size = JxlFastLosslessMaxRequiredOutput(frame_state);
+  *output = (unsigned char*)malloc(output_size);
+  size_t written = 0;
+  size_t total = 0;
+  while ((written = JxlFastLosslessWriteOutput(frame_state, *output + total,
+                                               output_size - total)) != 0) {
+    total += written;
+  }
+  return total;
+}
+
+JxlFastLosslessFrameState* JxlFastLosslessPrepareFrame(
+    const unsigned char* rgba, size_t width, size_t row_stride, size_t height,
+    size_t nb_chans, size_t bitdepth, int big_endian, int effort,
+    void* runner_opaque, FJxlParallelRunner runner) {
+  auto trivial_runner =
+      +[](void*, void* opaque, void fun(void*, size_t), size_t count) {
+        for (size_t i = 0; i < count; i++) {
+          fun(opaque, i);
+        }
+      };
+
+  if (runner == nullptr) {
+    runner = trivial_runner;
+  }
+
+#if FJXL_ENABLE_AVX512
+  if (__builtin_cpu_supports("avx512cd") &&
+      __builtin_cpu_supports("avx512vbmi") &&
+      __builtin_cpu_supports("avx512bw") && __builtin_cpu_supports("avx512f") &&
+      __builtin_cpu_supports("avx512vl")) {
+    return AVX512::JxlFastLosslessEncodeImpl(rgba, width, row_stride, height,
+                                             nb_chans, bitdepth, big_endian,
+                                             effort, runner_opaque, runner);
+  }
+#endif
+#if FJXL_ENABLE_AVX2
+  if (__builtin_cpu_supports("avx2")) {
+    return AVX2::JxlFastLosslessEncodeImpl(rgba, width, row_stride, height,
+                                           nb_chans, bitdepth, big_endian,
+                                           effort, runner_opaque, runner);
+  }
+#endif
+
+  return default_implementation::JxlFastLosslessEncodeImpl(
+      rgba, width, row_stride, height, nb_chans, bitdepth, big_endian, effort,
+      runner_opaque, runner);
+}
+
+} // extern "C"
+
+#endif // FJXL_SELF_INCLUDE
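For reference, a minimal caller of the one-shot API above could look like the sketch below. This is editorial and not part of the patch: EncodeRGBA8 is a hypothetical helper, the include path assumes an in-tree build, and effort = 2 is an arbitrary choice; only JxlFastLosslessEncode and its parameters come from the code above.

#include <cstdlib>
#include <vector>

#include "lib/jxl/enc_fast_lossless.h"

// Encodes an interleaved 8-bit RGBA buffer of size width x height into a
// JPEG XL stream and returns it as a byte vector.
std::vector<unsigned char> EncodeRGBA8(const unsigned char* rgba, size_t width,
                                       size_t height) {
  unsigned char* out = nullptr;
  // nb_chans = 4 (RGBA), bitdepth = 8, little-endian input, effort 2, and no
  // parallel runner (a null runner makes the encoder fall back to its serial
  // trivial_runner).
  size_t size = JxlFastLosslessEncode(rgba, width, /*row_stride=*/width * 4,
                                      height, /*nb_chans=*/4, /*bitdepth=*/8,
                                      /*big_endian=*/0, /*effort=*/2, &out,
                                      /*runner_opaque=*/nullptr,
                                      /*runner=*/nullptr);
  std::vector<unsigned char> result(out, out + size);
  free(out);  // the output buffer is malloc()ed by JxlFastLosslessEncode
  return result;
}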